import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np


# ----------------------------------- BUILD MODEL ---------------------------------------------
# 1. Define the Model Class
class SimpleDNN(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    The output is the raw, unnormalized score vector (no softmax is applied
    here), which is the form ``nn.CrossEntropyLoss`` expects as its input.

    Args:
        input_size: Number of features in each input sample.
        hidden_size: Width of the hidden layer.
        output_size: Number of output scores (e.g. number of classes).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # The attribute names below become the state_dict keys
        # (layer1.*, layer2.*); keep them stable so saved checkpoints load.
        self.layer1 = nn.Linear(input_size, hidden_size)  # fully connected
        self.relu = nn.ReLU()  # parameter-free non-linearity
        self.layer2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a batch of inputs to per-class scores.

        ``nn.Linear`` acts on the last dimension, so ``x`` is expected to end
        in ``input_size`` and the result ends in ``output_size``.
        """
        hidden = self.relu(self.layer1(x))
        return self.layer2(hidden)

# Network dimensions. The input width corresponds to one flattened 28x28
# image (MNIST-style), i.e. 784 features per sample.
INPUT_SIZE = 28 * 28
HIDDEN_SIZE = 128
OUTPUT_SIZE = 10  # number of target classes
model = SimpleDNN(
    input_size=INPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    output_size=OUTPUT_SIZE,
)


# ----------------------------------- GPU INTEGRATION ---------------------------------------------
# 1. Pick the compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# 2. Move all model parameters/buffers onto that device. PyTorch recommends
#    doing this before constructing the optimizer for the model.
model = model.to(device)

# Training hyperparameters and objective.
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
# CrossEntropyLoss takes raw scores (logits) and integer class targets.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


# ----------------------------------- LOAD DATA ---------------------------------------------
# 1. DEFINE THE CUSTOM DATASET
class CustomDataset(Dataset):
    """In-memory dataset that serves (features, label) pairs as dictionaries.

    Each item is ``{'features': ..., 'labels': ...}`` so callers can read
    ``data['features']`` and ``data['labels']`` from a DataLoader batch.

    Args:
        features: All model inputs, one entry per sample (e.g. the 784 pixel
            values of a flattened image). Any array-like accepted by
            ``torch.tensor``, or an existing ``torch.Tensor``. Stored as an
            independent float32 copy.
        labels: One integer class index per sample, in the same kind of
            container. Stored as an independent int64 (``torch.long``) copy,
            the dtype ``nn.CrossEntropyLoss`` requires for class-index targets.

    Raises:
        ValueError: If ``features`` and ``labels`` contain a different number
            of samples.
    """

    def __init__(self, features, labels):
        self.features = self._to_tensor(features, torch.float32)
        self.labels = self._to_tensor(labels, torch.long)
        # Without this check, a shorter label set only fails later with an
        # IndexError mid-epoch, and a longer one is silently truncated.
        if len(self.features) != len(self.labels):
            raise ValueError(
                "features and labels must contain the same number of samples, "
                f"got {len(self.features)} and {len(self.labels)}"
            )

    @staticmethod
    def _to_tensor(data, dtype):
        """Return an independent copy of *data* as a tensor of *dtype*."""
        if isinstance(data, torch.Tensor):
            # torch.tensor(existing_tensor) emits a UserWarning; detaching and
            # cloning is PyTorch's recommended way to copy a tensor.
            return data.detach().clone().to(dtype)
        return torch.tensor(data, dtype=dtype)

    def __len__(self):
        """Return the total number of samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return sample *idx* and its label as a ``features``/``labels`` dict."""
        return {
            'features': self.features[idx],
            'labels': self.labels[idx]
        }

# 2. LOAD THE DATA (placeholder -- replace with real loading code).
#    100 random samples: features uniform in [0, 1), labels random integers
#    in [0, OUTPUT_SIZE).
DUMMY_FEATURES = np.random.random_sample((100, INPUT_SIZE)).astype(np.float32)
DUMMY_LABELS = np.random.randint(0, OUTPUT_SIZE, size=100).astype(np.int64)

# 3. Wrap the arrays in a Dataset, then in a DataLoader that batches them.
train_dataset = CustomDataset(features=DUMMY_FEATURES, labels=DUMMY_LABELS)

BATCH_SIZE = 64  # samples per optimizer step
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,  # reshuffle sample order every epoch
)


# ----------------------------------- TRAINING ---------------------------------------------
# --- The Training Loop ---
for epoch in range(NUM_EPOCHS):
    # Enable training-mode behaviour (dropout, batch-norm updates). SimpleDNN
    # has no such layers today, but this keeps the loop correct if added.
    model.train()

    # Sample-weighted running loss, kept on-device so accumulating it does
    # not force a host/GPU synchronisation on every batch.
    running_loss = torch.zeros((), device=device)
    num_samples = 0

    for data in train_loader:
        # 1. MOVE DATA TO THE MODEL'S DEVICE
        inputs = data['features'].to(device)
        labels = data['labels'].to(device)

        # 2. ZERO GRADIENTS
        # backward() accumulates into .grad, so clear the previous step's values.
        optimizer.zero_grad()

        # 3. FORWARD PASS
        outputs = model(inputs)

        # 4. CALCULATE LOSS
        loss = criterion(outputs, labels)

        # 5. BACKWARD PASS: autograd computes d(loss)/d(parameter). It runs on
        #    whichever device the tensors live on, like the forward pass.
        loss.backward()

        # 6. OPTIMIZER STEP (updates weights from the gradients)
        optimizer.step()

        # The criterion is assumed to use the default reduction='mean' (a
        # per-batch average). Weight it by batch size so a short final batch
        # does not skew the epoch figure.
        batch_size = labels.size(0)
        running_loss += loss.detach() * batch_size
        num_samples += batch_size

    if num_samples == 0:
        # Otherwise there is no loss to report and training silently does nothing.
        raise RuntimeError("train_loader produced no samples; cannot train on an empty dataset")

    # Report the mean loss over the whole epoch, not just the last batch.
    epoch_loss = running_loss.item() / num_samples
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {epoch_loss:.4f}")
