# ----------------------------------------------
# 💡 Cell 1: Setup and Imports
# ----------------------------------------------

# 1. Import Core Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# 2. Device Configuration (the most crucial check!)
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"✅ Successfully initialized. Using device: {device}")

# Optional: report the name of GPU index 0 (useful when debugging).
if device.type == "cuda":
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

# ----------------------------------------------
# 💡 Cell 2: Data Loading and Transformations
# ----------------------------------------------

# Preprocessing pipeline:
#   ToTensor  -> converts each PIL image to a float tensor with values in [0, 1]
#   Normalize -> applies (x - 0.5) / 0.5 per channel, rescaling values to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Download MNIST into ./data on first run, then load the training split.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# The DataLoader groups samples into batches; shuffle=True draws a new random
# order each time the loader is iterated (i.e. every epoch).
BATCH_SIZE = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Same preprocessing for the test split; no shuffling so evaluation order is stable.
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("✅ Data loaded and prepared successfully!")

# ----------------------------------------------
# 💡 Cell 3: Model Definition and GPU Transfer
# ----------------------------------------------

# Define the CNN Model Architecture
class SimpleCNN(nn.Module):
    """Two-stage convolutional classifier for single-channel images.

    Each stage is Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool2d(2),
    so the padding keeps the spatial size and the pool halves it. The
    classifier head expects a 32 x 7 x 7 feature map, which corresponds to
    28x28 inputs (28 -> 14 -> 7), e.g. MNIST. The output is 10 raw logits per
    sample (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 16 channels; the pool halves height and width.
        self.conv1 = self._conv_stage(1, 16)
        # Stage 2: 16 -> 32 channels; the pool halves height and width again.
        self.conv2 = self._conv_stage(16, 32)
        # Linear head: flattened 32 * 7 * 7 features -> 10 class scores.
        self.fc = nn.Linear(32 * 7 * 7, 10)

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one Conv2d(3x3, padding 1) -> ReLU -> MaxPool2d(2) stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def forward(self, x):
        """Return a (batch, 10) tensor of unnormalized class scores for *x*."""
        features = self.conv2(self.conv1(x))
        batch_size = features.shape[0]
        # Merge the channel and spatial dimensions into one vector per sample.
        flat = features.view(batch_size, -1)
        return self.fc(flat)

# Initialize the model (parameters are created on PyTorch's default device,
# normally the CPU).
model = SimpleCNN()

# Move every parameter and buffer of the model to the selected device.
# For nn.Module, .to() modifies the module in place, so the return value
# does not need to be reassigned.
model.to(device)

# Report the actual target device instead of always claiming "GPU".
print(f"✅ Model defined and weights transferred to {device}!")


# ----------------------------------------------
# 💡 Cell 4: The Training Loop
# ----------------------------------------------

# Setup hyperparameters
NUM_EPOCHS = 10
LEARNING_RATE = 0.001

# Loss function and optimizer. CrossEntropyLoss takes the model's raw logits
# and integer class labels, and returns the mean loss over the batch.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Per-epoch history for plotting: mean training loss and training accuracy (%).
loss_history = []
acc_history = []

print("🚀 Starting Training...")

for epoch in range(NUM_EPOCHS):
    # Training mode (affects layers such as dropout/batch-norm; keeps the
    # loop correct if such layers are added to the model later).
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for images, labels in train_loader:
        # Inputs and labels must be on the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        # 1. Clear gradients accumulated by the previous step.
        optimizer.zero_grad()

        # 2. Forward pass.
        outputs = model(images)

        # 3. Batch-mean cross-entropy loss.
        loss = criterion(outputs, labels)

        # 4. Backward pass: autograd computes gradients on whichever device
        #    the tensors live on (GPU or CPU).
        loss.backward()

        # 5. Update the weights.
        optimizer.step()

        # criterion returns the batch *mean*; weight it by batch size so the
        # epoch average stays exact when the last batch is smaller.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_samples += batch_size

    # Per-sample averages over the epoch (max(..., 1) guards an empty loader).
    avg_loss = total_loss / max(total_samples, 1)
    train_acc = 100.0 * total_correct / max(total_samples, 1)
    loss_history.append(avg_loss)
    acc_history.append(train_acc)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}%")

print("🎉 Training Complete!")


# ----------------------------------------------
# 💡 Cell 5: Testing and Visualization
# ----------------------------------------------

# Evaluation mode: switches layers such as dropout/batch-norm to inference
# behaviour (no such layers in SimpleCNN today, but correct in general).
model.eval()
correct = 0
total = 0

# no_grad disables autograd graph construction, saving memory and time.
with torch.no_grad():
    for images, labels in test_loader:
        # Batches must be on the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)

        # Predicted class = index of the largest logit for each sample.
        predicted = outputs.argmax(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"\n🌟 Final Test Accuracy: {accuracy:.2f}%")


# Visualization (best viewed inline in a Jupyter environment)
plt.figure(figsize=(12, 5))

# Plot 1: mean training loss per epoch (x-axis is 1-based to match the log).
plt.subplot(1, 2, 1)
plt.plot(range(1, len(loss_history) + 1), loss_history, marker='o')
plt.title("Training Loss Over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Loss")

# Plot 2: accuracy. Per-epoch training accuracy is drawn only if the training
# loop recorded it in acc_history. The final test accuracy computed above is
# always shown as a horizontal reference line (previously this panel plotted
# a dummy all-zero series).
plt.subplot(1, 2, 2)
if acc_history:
    plt.plot(range(1, len(acc_history) + 1), acc_history, marker='o', label="Train Acc.")
plt.axhline(accuracy, color='tab:red', linestyle='--', label=f"Test Acc. ({accuracy:.2f}%)")
plt.title("Model Performance")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.legend()

plt.tight_layout()
plt.show()

# Save the final (last-epoch) weights; no best-checkpoint selection is performed.
torch.save(model.state_dict(), 'mnist_cnn_model.pth')
print("\nModel weights saved to 'mnist_cnn_model.pth'")
