# Source: vgcfreebox.myrthtech.pt Git - ue-rnap-aerossol.git/blob - torch-randomdata-example-MNIST.py
# Code executed on the ue niia server.
# [ue-rnap-aerossol.git] / torch-randomdata-example-MNIST.py
1 # ----------------------------------------------
2 # 💡 Cell 1: Setup and Imports
3 # ----------------------------------------------
4
5 # 1. Import Core Libraries
6 import torch
7 import torch.nn as nn
8 import torch.optim as optim
9 from torch.utils.data import DataLoader
10 from torchvision import datasets, transforms
11 import matplotlib.pyplot as plt
12 import numpy as np
13
# 2. Device selection: use CUDA when PyTorch reports that it is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"✅ Successfully initialized. Using device: {device}")

# Optional: report the name of GPU index 0 (handy when debugging).
if device.type == "cuda":
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
22
23 # ----------------------------------------------
24 # 💡 Cell 2: Data Loading and Transformations
25 # ----------------------------------------------
26
# Batch size shared by the training and test loaders.
BATCH_SIZE = 64

# Preprocessing pipeline applied to every image:
#   1. ToTensor   - converts the image to a float tensor scaled to [0, 1].
#   2. Normalize  - applies (x - 0.5) / 0.5 on the single channel, which
#                   maps the [0, 1] range onto [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST train and test splits, stored under ./data (download=True is passed).
train_dataset = datasets.MNIST(
    root="./data", train=True, transform=transform, download=True
)
test_dataset = datasets.MNIST(
    root="./data", train=False, transform=transform, download=True
)

# DataLoaders handle batching; only the training data is shuffled.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("✅ Data loaded and prepared successfully!")
45
46 # ----------------------------------------------
47 # 💡 Cell 3: Model Definition and GPU Transfer
48 # ----------------------------------------------
49
50 # Define the CNN Model Architecture
class SimpleCNN(nn.Module):
    """Small CNN classifier: two conv stages followed by one linear layer.

    Each stage is ``Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool2d(2)``,
    so the spatial size is halved twice. The linear layer is sized for
    ``32 * 7 * 7`` features, i.e. it expects 28x28 inputs; other spatial
    sizes will not match the linear layer.

    Attribute names (``conv1``, ``conv2``, ``fc``) and their internal layout
    are part of the ``state_dict`` key format and must not be renamed.

    Args:
        num_classes: Number of output logits. Defaults to 10.
        in_channels: Number of channels in the input images. Defaults to 1.
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 1) -> None:
        super().__init__()
        # Stage 1: in_channels -> 16 feature maps, then halve H and W.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Stage 2: 16 -> 32 feature maps, then halve H and W again.
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Classifier head: 32 channels * 7 * 7 spatial positions -> logits.
        self.fc = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits of shape ``(batch, num_classes)``.

        Args:
            x: Input batch of shape ``(batch, in_channels, 28, 28)``.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten everything except the batch dimension. Unlike ``.view``,
        # ``torch.flatten`` does not require the tensor to be contiguous.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)
76
# Initialize the model
model = SimpleCNN()

# CRITICAL STEP: move the model's parameters to the selected device
# (the GPU when one was detected, otherwise they stay on the CPU).
model.to(device)

# Report where the weights actually are rather than always claiming "GPU".
if device.type == "cuda":
    print("✅ Model defined and weights transferred to the GPU!")
else:
    print(f"✅ Model defined; running on {device} (no GPU transfer performed).")
84
85
86 # ----------------------------------------------
87 # 💡 Cell 4: The Training Loop
88 # ----------------------------------------------
89
# Setup hyperparameters
NUM_EPOCHS = 10
LEARNING_RATE = 0.001

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Per-epoch history for plotting:
#   loss_history - mean of the per-batch losses of each epoch
#   acc_history  - training accuracy of each epoch, in percent
loss_history = []
acc_history = []

print("🚀 Starting Training...")

for epoch in range(NUM_EPOCHS):
    # Set model to training mode
    model.train()
    total_loss = 0.0
    num_correct = 0
    num_seen = 0

    for images, labels in train_loader:
        # CRITICAL: inputs and labels must live on the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        # 1. Zero gradients left over from the previous step
        optimizer.zero_grad()

        # 2. Forward pass
        outputs = model(images)

        # 3. Calculate loss
        loss = criterion(outputs, labels)

        # 4. Backward pass (autograd computes the gradients)
        loss.backward()

        # 5. Optimize (updates the weights)
        optimizer.step()

        total_loss += loss.item()

        # Running training accuracy: predictions come from the forward pass
        # above, i.e. from the weights *before* this step's update.
        predictions = outputs.detach().argmax(dim=1)
        num_correct += (predictions == labels).sum().item()
        num_seen += labels.size(0)

    # Average of the batch losses for this epoch
    avg_loss = total_loss / len(train_loader)
    loss_history.append(avg_loss)

    epoch_acc = 100 * num_correct / num_seen
    acc_history.append(epoch_acc)

    print(
        f"Epoch {epoch+1}/{NUM_EPOCHS} | Loss: {avg_loss:.4f}"
        f" | Train Acc: {epoch_acc:.2f}%"
    )

print("🎉 Training Complete!")
139
140
141 # ----------------------------------------------
142 # 💡 Cell 5: Testing and Visualization
143 # ----------------------------------------------
144
# Switch the model to evaluation mode before measuring test accuracy.
model.eval()
total = 0
correct = 0

# Gradients are not needed for evaluation, so tracking is disabled.
with torch.no_grad():
    for images, labels in test_loader:
        # CRITICAL: move the batch to the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)

        # Predicted class = index of the largest logit in each row.
        predicted = outputs.argmax(dim=1)

        correct += (predicted == labels).sum().item()
        total += labels.shape[0]

accuracy = 100 * correct / total
print(f"\n🌟 Final Test Accuracy: {accuracy:.2f}%")
166
167
# Visualization (most useful in an interactive / Jupyter environment)
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Plot 1: Loss curve
loss_ax.plot(loss_history, marker='o')
loss_ax.set_title("Training Loss Over Epochs")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")

# Plot 2: Accuracy per epoch. If acc_history was never filled, fall back to
# a flat placeholder line and show its label so the plot is not mistaken
# for real data.
# NOTE(review): assumes acc_history holds percentages - confirm against the
# training loop.
if acc_history:
    acc_ax.plot(acc_history, marker='o')
else:
    acc_ax.plot([0] * len(loss_history), label="Dummy Acc.")
    acc_ax.legend()
acc_ax.set_title("Model Performance")
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy (%)")

fig.tight_layout()
plt.show()

# Optional: save the model weights as they are at this point (the final
# state after training, not a "best" checkpoint - none is tracked).
torch.save(model.state_dict(), 'mnist_cnn_model.pth')
print("\nModel weights saved to 'mnist_cnn_model.pth'")