+import os
+import pandas as pd
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+import matplotlib.pyplot as plt
+import tifffile # Replaced cv2 with tifffile for multi-channel support
+
+# ---------------------------------------------------------
+# 1. Configuration & Hyperparameters
+# ---------------------------------------------------------
# Dataset locations. All inputs live under one root folder; adjust
# _DATA_ROOT (or IMAGE_DIR alone) to match your local layout.
_DATA_ROOT = 'data/SENTINEL2-AEROSSOL-DATA'
TRAIN_CSV = f'{_DATA_ROOT}/train.csv'
TEST_CSV = f'{_DATA_ROOT}/test.csv'
IMAGE_DIR = f'{_DATA_ROOT}/tiff-images/'  # folder containing the .tiff images
SUBMISSION_FILE = 'submission.csv'

# Optimisation hyperparameters.
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
EPOCHS = 50

# Channel count of the input images; set to 20 because an earlier OpenCV
# error log reported 20 channels for these files.
CHANNELS = 20

# Prefer the GPU when one is available, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+# ---------------------------------------------------------
+# 2. Custom Dataset Definition (Multi-modal)
+# ---------------------------------------------------------
class SentinelAeronetDataset(Dataset):
    """Multi-modal dataset pairing .tiff image patches with tabular features.

    Each CSV row names an image (``img_name`` column) and carries three
    tabular features (elevation, ozone, NO2).  Training rows additionally
    carry the AOT target; test rows carry an ``id`` column instead.

    Samples are returned as ``(image, tabular, target)`` in training mode and
    ``(image, tabular, id)`` in test mode.
    """

    def __init__(self, csv_file, image_dir, is_test=False, scaler_stats=None,
                 channels=None, patch_size=19):
        """Load the CSV, impute missing tabular values and standardise them.

        Args:
            csv_file: Path of the CSV file to read.
            image_dir: Directory that contains the images named in the CSV.
            is_test: When True no target is read and ``id`` is returned
                with each sample instead.
            scaler_stats: Optional ``(mean, std)`` pair used to standardise
                the tabular features.  When None the statistics are computed
                from this CSV (use that for the training set and pass the
                result of :meth:`get_scaler_stats` for the test set).
            channels: Expected number of image channels.  Defaults to the
                module-level ``CHANNELS`` constant.
            patch_size: Height/width of the all-zero placeholder returned
                for unreadable images.

        Raises:
            KeyError: If a required column is missing from the CSV.
        """
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.is_test = is_test
        # Resolved lazily so the class default always follows the module config.
        self.channels = CHANNELS if channels is None else channels
        self.patch_size = patch_size
        self._warned_unreadable = False

        # Resolve column names case-insensitively, falling back to the
        # canonical spelling when no column matches.
        col_map = {c.lower(): c for c in self.data.columns}
        self.col_elev = col_map.get('elevation', 'Elevation')
        self.col_ozone = col_map.get('ozone', 'Ozone')
        self.col_no2 = col_map.get('no2', 'NO2')
        self.col_aot = col_map.get('aot', 'AOT')

        # Rows without a target cannot be used for supervised training.
        if not self.is_test:
            self.data = self.data.dropna(subset=[self.col_aot])

        # Median-impute missing tabular values (assignment form avoids
        # pandas chained-assignment problems).
        feature_cols = [self.col_elev, self.col_ozone, self.col_no2]
        for col in feature_cols:
            self.data[col] = self.data[col].fillna(self.data[col].median())

        self.img_names = self.data['img_name'].values
        self.X_tab = self.data[feature_cols].values.astype(np.float32)

        # Standardise with either freshly computed or caller-supplied stats.
        if scaler_stats is None:
            self.mean = self.X_tab.mean(axis=0)
            self.std = self.X_tab.std(axis=0)
        else:
            self.mean, self.std = scaler_stats

        # The epsilon keeps constant columns from dividing by zero.
        self.X_tab = (self.X_tab - self.mean) / (self.std + 1e-8)

        if self.is_test:
            # Cache ids once instead of doing a pandas lookup per sample.
            self.ids = self.data['id'].values
        else:
            self.y = self.data[self.col_aot].values.astype(np.float32)

    def get_scaler_stats(self):
        """Return the ``(mean, std)`` pair used to standardise the features."""
        return self.mean, self.std

    def __len__(self):
        """Return the number of usable rows."""
        return len(self.data)

    @staticmethod
    def _to_channels_first(img, channels):
        """Return ``img`` as (C, H, W), transposing only a true (H, W, C) input.

        The layout is decided by comparing axis sizes with ``channels``: an
        array whose first axis already has ``channels`` entries is left
        untouched, so a channels-first patch whose width happens to equal
        some other "channel-like" number is never scrambled.
        """
        if img.ndim == 3 and img.shape[0] != channels and img.shape[2] == channels:
            return img.transpose(2, 0, 1)
        return img

    def _load_image(self, img_name):
        """Read one image as a float32 array, or zeros if it is unreadable."""
        img_path = os.path.join(self.image_dir, img_name)
        try:
            img = tifffile.imread(img_path).astype(np.float32)
            img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)
            img = img / 10000.0
        except Exception as exc:
            # Deliberate best effort: a missing or corrupted file must not
            # abort training.  Warn once so a wrong image directory does not
            # silently produce a model trained on blank inputs.
            if not self._warned_unreadable:
                import warnings

                warnings.warn(
                    f"Could not read image '{img_path}' ({exc!r}); "
                    "substituting an all-zero patch. Further read failures "
                    "for this dataset will not be reported.",
                    RuntimeWarning,
                )
                self._warned_unreadable = True
            return np.zeros(
                (self.channels, self.patch_size, self.patch_size),
                dtype=np.float32,
            )
        return self._to_channels_first(img, self.channels)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` (see the class docstring for layout)."""
        tab_features = torch.tensor(self.X_tab[idx])
        img_features = torch.tensor(self._load_image(self.img_names[idx]))

        if self.is_test:
            return img_features, tab_features, self.ids[idx]
        return img_features, tab_features, torch.tensor(self.y[idx])
+
+# ---------------------------------------------------------
+# 3. Multi-modal DNN Model Definition
+# ---------------------------------------------------------
class MultiModalAOTPredictor(nn.Module):
    """Two-branch regressor fusing a small CNN (image) with an MLP (tabular).

    The image branch reduces a multi-channel patch to a 64-dim vector, the
    tabular branch maps 3 features to a 16-dim vector, and the fusion head
    regresses a single value from their concatenation.
    """

    def __init__(self, in_channels=None):
        """Build the three sub-networks.

        Args:
            in_channels: Number of input image channels.  Defaults to the
                module-level ``CHANNELS`` constant (looked up at call time).
        """
        super().__init__()
        if in_channels is None:
            in_channels = CHANNELS

        # --- Image branch (CNN) ---
        # For a 19x19 input: conv keeps 19x19, max-pool floors it to 9x9,
        # and the adaptive average pool collapses any spatial size to 1x1.
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),  # -> (batch, 64)
        )

        # --- Tabular branch (MLP) ---
        # Input: 3 features (Elevation, Ozone, NO2) -> (batch, 16)
        self.mlp = nn.Sequential(
            nn.Linear(3, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
        )

        # --- Fusion head ---
        # 64 CNN features + 16 MLP features = 80 inputs -> 1 output
        self.fusion = nn.Sequential(
            nn.Linear(64 + 16, 64),
            nn.ReLU(),
            nn.Dropout(0.2),  # regularisation against overfitting
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, img, tab):
        """Predict one value per sample.

        Args:
            img: Image batch of shape (batch, in_channels, H, W).
            tab: Tabular batch of shape (batch, 3).

        Returns:
            Tensor of shape (batch,).
        """
        img_out = self.cnn(img)
        tab_out = self.mlp(tab)

        # Join the two embeddings along the feature axis (dim=1), not the batch.
        combined = torch.cat((img_out, tab_out), dim=1)

        # Drop only the trailing size-1 axis. A bare squeeze() would also
        # remove the batch axis for a single-sample batch and return a
        # 0-d tensor, which breaks the loss and per-sample iteration.
        return self.fusion(combined).squeeze(-1)
+
+# ---------------------------------------------------------
+# 4. Training Function
+# ---------------------------------------------------------
def train_model():
    """Train the multi-modal model on TRAIN_CSV and save its weights.

    Returns:
        A ``(model, epoch_losses, scaler_stats)`` tuple: the trained model,
        the per-epoch mean training MSE, and the tabular ``(mean, std)``
        statistics that must be reused to normalise the test set.

    Raises:
        ValueError: If the training CSV yields no usable rows.
    """
    print(f"Using device: {DEVICE}")

    # 1. Setup datasets
    train_dataset = SentinelAeronetDataset(TRAIN_CSV, IMAGE_DIR, is_test=False)
    if len(train_dataset) == 0:
        # Fail early with a clear message instead of an opaque sampler or
        # division-by-zero error further down.
        raise ValueError(
            f"No training rows with a valid target found in '{TRAIN_CSV}'."
        )
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Keep the scaler stats so the test set is normalised exactly like train.
    scaler_stats = train_dataset.get_scaler_stats()

    # 2. Initialise model, loss and optimiser
    model = MultiModalAOTPredictor(in_channels=CHANNELS).to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

    # 3. Training loop
    epoch_losses = []
    print("Starting Training...")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0

        for imgs, tabs, targets in train_loader:
            imgs, tabs, targets = imgs.to(DEVICE), tabs.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            # Flatten both sides to (batch,) so a one-sample final batch
            # cannot trigger shape broadcasting inside the loss.
            predictions = model(imgs, tabs).reshape(-1)
            loss = criterion(predictions, targets.reshape(-1))
            loss.backward()
            # Clip the gradient norm to guard against exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # criterion returns the batch mean; weight by batch size so the
            # epoch figure is a true per-sample average.
            running_loss += loss.item() * imgs.size(0)

        epoch_loss = running_loss / len(train_dataset)
        epoch_losses.append(epoch_loss)

        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}], Loss (MSE): {epoch_loss:.6f}")

    print("Training complete!")
    torch.save(model.state_dict(), "aot_multimodal_model.pth")
    return model, epoch_losses, scaler_stats
+
+# ---------------------------------------------------------
+# 5. Prediction Function for Submission
+# ---------------------------------------------------------
def generate_submission(model, scaler_stats):
    """Predict on TEST_CSV and write the results to SUBMISSION_FILE.

    Args:
        model: Trained model; it is switched to eval mode here.
        scaler_stats: ``(mean, std)`` pair from the training dataset, used to
            normalise the test features the same way.
    """
    print("\nGenerating predictions for test.csv...")
    model.eval()

    test_dataset = SentinelAeronetDataset(TEST_CSV, IMAGE_DIR, is_test=True, scaler_stats=scaler_stats)
    # shuffle=False keeps predictions aligned with the CSV row order.
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    ids = []
    predictions = []

    with torch.no_grad():
        for imgs, tabs, batch_ids in test_loader:
            imgs, tabs = imgs.to(DEVICE), tabs.to(DEVICE)

            # reshape(-1) guarantees a 1-d result even when the final batch
            # holds a single sample and the model returns a 0-d tensor.
            preds = model(imgs, tabs).reshape(-1)

            # Numeric ids are collated into a tensor; string ids arrive as a
            # plain sequence and have no .numpy().
            if torch.is_tensor(batch_ids):
                ids.extend(batch_ids.cpu().numpy())
            else:
                ids.extend(batch_ids)
            predictions.extend(preds.cpu().numpy())

    # Create submission dataframe
    submission_df = pd.DataFrame({
        'id': ids,
        'AOT': predictions,
    })

    submission_df.to_csv(SUBMISSION_FILE, index=False)
    print(f"Submission saved to {SUBMISSION_FILE}")
+
+# ---------------------------------------------------------
+# 6. Execution
+# ---------------------------------------------------------
def _save_loss_plot(losses, path="loss_plot.png"):
    """Plot the per-epoch training loss and save the figure to ``path``."""
    fig = plt.figure(figsize=(8, 5))
    plt.plot(range(1, len(losses) + 1), losses, marker='o', color='r', label='Train MSE')
    plt.title("Multimodal AOT Model Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("MSE")
    plt.grid(True)
    plt.legend()  # the curve carries a label, so actually display it
    plt.savefig(path)
    plt.close(fig)  # release the figure once it is on disk
    print(f"Saved loss plot to '{path}'")


def main():
    """Train the model, save the loss curve and, if possible, a submission."""
    if not os.path.exists(IMAGE_DIR):
        print(f"WARNING: Image directory '{IMAGE_DIR}' not found. Please update IMAGE_DIR or place images there.")

    # Nothing can be done without training data.
    if not os.path.exists(TRAIN_CSV):
        print(f"Training file '{TRAIN_CSV}' not found. Please ensure it is in the same directory.")
        return

    # 1. Train
    trained_model, losses, scaler_stats = train_model()
    _save_loss_plot(losses)

    # 2. Predict on the test set
    if os.path.exists(TEST_CSV):
        generate_submission(trained_model, scaler_stats)
    else:
        print(f"Test file '{TEST_CSV}' not found. Skipping submission generation.")


if __name__ == "__main__":
    main()
\ No newline at end of file