import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tifffile # Replaced cv2 with tifffile for multi-channel support
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# ---------------------------------------------------------
# 1. Configuration & Hyperparameters
# ---------------------------------------------------------
# --- Input / output locations ---
TRAIN_CSV = 'data/SENTINEL2-AEROSSOL-DATA/train.csv'
TEST_CSV = 'data/SENTINEL2-AEROSSOL-DATA/test.csv'
# Folder holding the .tiff images named in the CSVs; update to match your layout.
IMAGE_DIR = 'data/SENTINEL2-AEROSSOL-DATA/tiff-images/'
SUBMISSION_FILE = 'submission.csv'

# --- Optimisation settings ---
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
EPOCHS = 50

# Channel count of the input images; set to 20 after an OpenCV error log
# reported 20-channel files.
CHANNELS = 20

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---------------------------------------------------------
# 2. Custom Dataset Definition (Multi-modal)
# ---------------------------------------------------------
class SentinelAeronetDataset(Dataset):
    """Pair standardized tabular features with multi-channel ``.tiff`` images.

    Each item is ``(image, tabular, target)`` for training data, or
    ``(image, tabular, id)`` when ``is_test`` is true.
    """

    # Height/width of the all-zero placeholder returned for unreadable images.
    _FALLBACK_HW = (19, 19)

    def __init__(self, csv_file, image_dir, is_test=False, scaler_stats=None):
        """
        Loads both tabular data and multi-channel Sentinel-2 .tiff images.

        Args:
            csv_file: Path to a CSV with an ``img_name`` column and the
                elevation / ozone / NO2 columns (matched case-insensitively).
                Training CSVs also need the AOT column; test items read ``id``.
            image_dir: Directory the ``img_name`` entries are relative to.
            is_test: When true, no target is loaded and items carry the row id.
            scaler_stats: Optional ``(mean, std)`` pair, e.g. the value returned
                by ``get_scaler_stats`` on the training dataset. When omitted,
                the statistics are computed from this CSV.
        """
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.is_test = is_test

        # Resolve column names case-insensitively, falling back to the
        # conventional spelling when a column is absent.
        col_map = {c.lower(): c for c in self.data.columns}
        self.col_elev = col_map.get('elevation', 'Elevation')
        self.col_ozone = col_map.get('ozone', 'Ozone')
        self.col_no2 = col_map.get('no2', 'NO2')
        self.col_aot = col_map.get('aot', 'AOT')
        if not self.is_test:
            # Rows without a target cannot be trained on. Re-index afterwards
            # so that row labels and row positions agree again.
            self.data = self.data.dropna(subset=[self.col_aot]).reset_index(drop=True)

        # Impute missing tabular values with the per-column median of this CSV.
        feature_cols = [self.col_elev, self.col_ozone, self.col_no2]
        for col in feature_cols:
            self.data[col] = self.data[col].fillna(self.data[col].median())

        self.img_names = self.data['img_name'].values
        self.X_tab = self.data[feature_cols].values.astype(np.float32)

        # Normalization (StandardScaler logic)
        if scaler_stats is None:
            self.mean = self.X_tab.mean(axis=0)
            self.std = self.X_tab.std(axis=0)
        else:
            self.mean, self.std = scaler_stats

        # The epsilon keeps constant columns (std == 0) from dividing by zero.
        self.X_tab = (self.X_tab - self.mean) / (self.std + 1e-8)

        # Target variable (Only available in train)
        if not self.is_test:
            self.y = self.data[self.col_aot].values.astype(np.float32)

    def get_scaler_stats(self):
        """Return the ``(mean, std)`` pair used to standardize the tabular features."""
        return self.mean, self.std

    def __len__(self):
        """Return the number of rows kept from the CSV."""
        return len(self.data)

    @staticmethod
    def _to_chw(img, channels):
        """Return ``img`` in (channels, height, width) order.

        A 3-D array is transposed only when its last axis equals ``channels``
        (i.e. it is stored height x width x channels). Anything else, including
        an array that is already channel-first, is returned unchanged.
        """
        if img.ndim == 3 and img.shape[-1] == channels:
            return img.transpose(2, 0, 1)
        return img

    def _load_image(self, img_path):
        """Read one image as a float32 array, or a zero placeholder on failure."""
        try:
            # tifffile can read arbitrarily sized matrices (e.g. 20 channels)
            img = tifffile.imread(img_path).astype(np.float32)
        except Exception as exc:
            # Deliberate best effort: a missing or corrupted file must not abort
            # training, but the substitution should not go unnoticed either.
            warnings.warn(
                f"Could not read {img_path} ({exc}); using an all-zero image instead.",
                RuntimeWarning,
            )
            return np.zeros((CHANNELS, *self._FALLBACK_HW), dtype=np.float32)

        img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)
        # NOTE(review): 10000 is presumably the Sentinel-2 reflectance scale
        # factor -- confirm against the data description.
        img = img / 10000.0
        # PyTorch expects (Channels, Height, Width).
        return self._to_chw(img, CHANNELS)

    def __getitem__(self, idx):
        """Return the image tensor, tabular tensor and target (or id) for row ``idx``."""
        # 1. Load Tabular features
        tab_features = torch.tensor(self.X_tab[idx])

        # 2. Load Image using tifffile
        img_path = os.path.join(self.image_dir, self.img_names[idx])
        img_features = torch.tensor(self._load_image(img_path))

        if self.is_test:
            return img_features, tab_features, self.data['id'].iloc[idx]

        target = torch.tensor(self.y[idx])
        return img_features, tab_features, target

# ---------------------------------------------------------
# 3. Multi-modal DNN Model Definition
# ---------------------------------------------------------
class MultiModalAOTPredictor(nn.Module):
    """Regress AOT from an image patch and three tabular features.

    A small CNN summarizes the image into 64 features, an MLP maps the tabular
    input to 16 features, and a fusion MLP turns the concatenated 80 features
    into one scalar prediction per sample.
    """

    def __init__(self, in_channels=CHANNELS):
        """Build the image, tabular and fusion branches.

        Args:
            in_channels: Number of channels in the input image tensor.
        """
        super().__init__()

        # --- Image Branch (CNN) ---
        # Input: `in_channels` planes; the size comments assume 19x19 pixels.
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2), # Output: 32 x 9 x 9 for a 19x19 input
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)), # Output: 64 x 1 x 1 for any spatial size
            nn.Flatten() # Output: 64
        )

        # --- Tabular Branch (MLP) ---
        # Input: 3 features (Elevation, Ozone, NO2)
        self.mlp = nn.Sequential(
            nn.Linear(3, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU()
        )

        # Fusion Branch
        # Combines 64 features from CNN + 16 features from MLP = 80 total
        self.fusion = nn.Sequential(
            nn.Linear(64 + 16, 64),
            nn.ReLU(),
            nn.Dropout(0.2), # Dropout to prevent overfitting
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1) # Predict AOT
        )

    def forward(self, img, tab):
        """Return one prediction per sample as a 1-D tensor of length ``batch``.

        Args:
            img: Batched image tensor of shape (batch, in_channels, H, W).
            tab: Batched tabular tensor of shape (batch, 3).
        """
        img_out = self.cnn(img)
        tab_out = self.mlp(tab)

        # Concatenate the two feature vectors along the feature axis (dim=1),
        # keeping the batch axis intact.
        combined = torch.cat((img_out, tab_out), dim=1)

        out = self.fusion(combined)
        # Drop only the trailing size-1 axis. A bare squeeze() would also
        # remove the batch axis when the batch holds a single sample, giving a
        # 0-d tensor that cannot be iterated or matched against a (1,) target.
        return out.squeeze(-1)

# ---------------------------------------------------------
# 4. Training Function
# ---------------------------------------------------------
def train_model():
    """Train the multi-modal model on ``TRAIN_CSV`` and save its weights.

    The weights are written to ``aot_multimodal_model.pth`` after the last epoch.

    Returns:
        tuple: ``(model, epoch_losses, scaler_stats)`` where ``epoch_losses``
        holds the sample-weighted mean MSE of every epoch and ``scaler_stats``
        is the ``(mean, std)`` pair of the training tabular features.
    """
    print(f"Using device: {DEVICE}")

    # 1. Setup Datasets
    train_dataset = SentinelAeronetDataset(TRAIN_CSV, IMAGE_DIR, is_test=False)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Save scaler stats so the test set is normalized exactly like the train set
    scaler_stats = train_dataset.get_scaler_stats()

    # 2. Initialize Model
    model = MultiModalAOTPredictor(in_channels=CHANNELS).to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

    # 3. Training Loop
    epoch_losses = []
    print("Starting Training...")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0

        for imgs, tabs, targets in train_loader:
            imgs, tabs, targets = imgs.to(DEVICE), tabs.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            # Flatten to 1-D so the prediction shape matches `targets` even
            # when the final batch contains a single sample; otherwise MSELoss
            # sees a 0-d input against a (1,) target and warns about
            # broadcasting.
            predictions = model(imgs, tabs).reshape(-1)

            loss = criterion(predictions, targets)
            loss.backward()
            # Gradient clipping to prevent exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Weight by batch size so the epoch average is per sample, not per batch.
            running_loss += loss.item() * imgs.size(0)

        epoch_loss = running_loss / len(train_dataset)
        epoch_losses.append(epoch_loss)

        # Report the first epoch and every fifth one after that.
        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}], Loss (MSE): {epoch_loss:.6f}")

    print("Training complete!")
    torch.save(model.state_dict(), "aot_multimodal_model.pth")
    return model, epoch_losses, scaler_stats

# ---------------------------------------------------------
# 5. Prediction Function for Submission
# ---------------------------------------------------------
def generate_submission(model, scaler_stats):
    """Predict AOT for every row of ``TEST_CSV`` and write ``SUBMISSION_FILE``.

    Args:
        model: Trained model, already placed on ``DEVICE``.
        scaler_stats: ``(mean, std)`` pair from the training dataset, so the
            test features are standardized exactly like the training ones.
    """
    print("\nGenerating predictions for test.csv...")
    model.eval()

    test_dataset = SentinelAeronetDataset(TEST_CSV, IMAGE_DIR, is_test=True, scaler_stats=scaler_stats)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    ids = []
    predictions = []

    with torch.no_grad():
        for imgs, tabs, batch_ids in test_loader:
            imgs, tabs = imgs.to(DEVICE), tabs.to(DEVICE)

            # Flatten so a final batch of one sample still yields a 1-D array;
            # a 0-d array cannot be passed to list.extend.
            preds = model(imgs, tabs).reshape(-1)

            # Numeric ids are collated into a tensor, while string ids arrive
            # as a plain list; accept both.
            if torch.is_tensor(batch_ids):
                ids.extend(batch_ids.numpy())
            else:
                ids.extend(batch_ids)
            predictions.extend(preds.cpu().numpy())

    # Create submission dataframe
    submission_df = pd.DataFrame({
        'id': ids,
        'AOT': predictions
    })

    submission_df.to_csv(SUBMISSION_FILE, index=False)
    print(f"Submission saved to {SUBMISSION_FILE}")

# ---------------------------------------------------------
# 6. Execution
# ---------------------------------------------------------
if __name__ == "__main__":
    # A missing image folder is only a warning: the dataset substitutes blank
    # images for files it cannot read, so the run can still proceed.
    if not os.path.exists(IMAGE_DIR):
        print(f"WARNING: Image directory '{IMAGE_DIR}' not found. Please update IMAGE_DIR or place images there.")

    if os.path.exists(TRAIN_CSV):
        # 1. Train
        trained_model, losses, scaler_stats = train_model()

        # Plot Loss
        plt.figure(figsize=(8, 5))
        plt.plot(range(1, len(losses) + 1), losses, marker='o', color='r', label='Train MSE')
        plt.title("Multimodal AOT Model Training Loss")
        plt.xlabel("Epoch")
        plt.ylabel("MSE")
        plt.grid(True)
        # Draw the legend; without it the label passed to plot() never shows.
        plt.legend()
        plt.savefig("loss_plot.png")
        print("Saved loss plot to 'loss_plot.png'")

        # 2. Predict on Test set
        if os.path.exists(TEST_CSV):
            generate_submission(trained_model, scaler_stats)
        else:
            print(f"Test file '{TEST_CSV}' not found. Skipping submission generation.")
    else:
        print(f"Training file '{TRAIN_CSV}' not found. Please ensure it is in the same directory.")