# Source: vgcfreebox.myrthtech.pt Git - ue-rnap-aerossol.git, blob train-aot-multimodal-tifffile.py
# Code executed on the ue niia server.
# [ue-rnap-aerossol.git] / train-aot-multimodal-tifffile.py
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tifffile  # Replaced cv2 with tifffile for multi-channel support
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
10
11 # ---------------------------------------------------------
12 # 1. Configuration & Hyperparameters
13 # ---------------------------------------------------------
# Input / output locations (relative to the working directory).
TRAIN_CSV: str = 'data/SENTINEL2-AEROSSOL-DATA/train.csv'
TEST_CSV: str = 'data/SENTINEL2-AEROSSOL-DATA/test.csv'
# Folder holding the .tiff images referenced by the CSVs' img_name column.
IMAGE_DIR: str = 'data/SENTINEL2-AEROSSOL-DATA/tiff-images/'
SUBMISSION_FILE: str = 'submission.csv'

# Optimisation hyperparameters.
BATCH_SIZE: int = 32
LEARNING_RATE: float = 0.001
EPOCHS: int = 50

# Number of image channels; 20 matches the channel count reported when the
# images were previously read with OpenCV.
CHANNELS: int = 20

# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
24
25 # ---------------------------------------------------------
26 # 2. Custom Dataset Definition (Multi-modal)
27 # ---------------------------------------------------------
class SentinelAeronetDataset(Dataset):
    """Paired tabular features + multi-channel .tiff images for AOT regression.

    Each CSV row supplies three tabular features (elevation, ozone, NO2,
    matched case-insensitively), an ``img_name`` column naming a .tiff file
    inside ``image_dir`` and, for training data, an ``AOT`` target. Test data
    must instead provide an ``id`` column, which ``__getitem__`` returns in
    place of the target.

    Args:
        csv_file: Path of the CSV to load.
        image_dir: Directory that the ``img_name`` entries are relative to.
        is_test: When True no target is read and rows without AOT are kept.
        scaler_stats: Optional ``(mean, std)`` pair (e.g. from the training
            set's ``get_scaler_stats``) used to standardise the tabular
            features. When None the statistics are computed from this CSV.
    """

    # Spatial size of the blank image returned when a file cannot be read.
    _FALLBACK_SIZE = 19

    def __init__(self, csv_file, image_dir, is_test=False, scaler_stats=None):
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.is_test = is_test

        # Resolve column names case-insensitively; fall back to the canonical
        # spelling so a genuinely missing column raises a clear KeyError.
        col_map = {c.lower(): c for c in self.data.columns}
        self.col_elev = col_map.get('elevation', 'Elevation')
        self.col_ozone = col_map.get('ozone', 'Ozone')
        self.col_no2 = col_map.get('no2', 'NO2')
        self.col_aot = col_map.get('aot', 'AOT')
        tab_cols = [self.col_elev, self.col_ozone, self.col_no2]

        # Rows without a target cannot be trained on. reset_index gives a
        # fresh, contiguous frame so the column assignments below never act
        # on a view.
        if not self.is_test:
            self.data = self.data.dropna(subset=[self.col_aot]).reset_index(drop=True)

        # Impute missing tabular values with the median of *this* CSV.
        # NOTE(review): for the test set this uses test-set medians rather
        # than training medians; consider carrying training medians along
        # with scaler_stats.
        for col in tab_cols:
            self.data[col] = self.data[col].fillna(self.data[col].median())

        self.img_names = self.data['img_name'].values
        self.X_tab = self.data[tab_cols].values.astype(np.float32)

        # Z-score standardisation (StandardScaler logic).
        if scaler_stats is None:
            self.mean = self.X_tab.mean(axis=0)
            self.std = self.X_tab.std(axis=0)
        else:
            self.mean, self.std = scaler_stats
        # The epsilon keeps constant columns from dividing by zero.
        self.X_tab = (self.X_tab - self.mean) / (self.std + 1e-8)

        # Target variable (only available when not in test mode).
        if not self.is_test:
            self.y = self.data[self.col_aot].values.astype(np.float32)

    def get_scaler_stats(self):
        """Return the ``(mean, std)`` used to standardise the tabular features."""
        return self.mean, self.std

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _to_channels_first(img, channels):
        """Return ``img`` in (C, H, W) layout.

        A 3-D array whose first axis already has ``channels`` entries is
        treated as channels-first and returned unchanged. Otherwise, if its
        last axis has ``channels`` entries, it is treated as (H, W, C) and
        transposed. Any other array is returned unchanged.
        """
        # Check the first axis before the last: for square images whose side
        # equals one of the candidate sizes (e.g. 20 x 19 x 19), testing only
        # the last axis would wrongly transpose data already in (C, H, W).
        if img.ndim == 3 and img.shape[0] != channels and img.shape[-1] == channels:
            return img.transpose(2, 0, 1)
        return img

    def _load_image(self, img_path):
        """Read one .tiff as a float32 (C, H, W) array scaled by 1/10000.

        NaN and +/-inf values are replaced by 0. If reading or converting
        fails for any reason, a warning is emitted and an all-zero
        (CHANNELS, 19, 19) array is returned so one bad file does not abort
        training.
        """
        try:
            img = tifffile.imread(img_path).astype(np.float32)
            img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)
            img = img / 10000.0
            return self._to_channels_first(img, CHANNELS)
        except Exception as e:
            # Deliberate best-effort fallback. warnings deduplicates identical
            # messages, so each unreadable file is reported once rather than
            # every epoch.
            warnings.warn(f"Could not read {img_path}; using a blank image. Error: {e}")
            return np.zeros(
                (CHANNELS, self._FALLBACK_SIZE, self._FALLBACK_SIZE), dtype=np.float32
            )

    def __getitem__(self, idx):
        """Return ``(image, tabular, target)``, or ``(image, tabular, id)`` in test mode."""
        tab_features = torch.tensor(self.X_tab[idx])

        img_path = os.path.join(self.image_dir, self.img_names[idx])
        img_features = torch.tensor(self._load_image(img_path))

        if self.is_test:
            return img_features, tab_features, self.data['id'].iloc[idx]
        return img_features, tab_features, torch.tensor(self.y[idx])
110
111 # ---------------------------------------------------------
112 # 3. Multi-modal DNN Model Definition
113 # ---------------------------------------------------------
class MultiModalAOTPredictor(nn.Module):
    """Regress AOT from a multi-channel image plus three tabular features.

    A small CNN encodes the image into 64 features and an MLP encodes the
    3 tabular features into 16. A fusion MLP then maps the concatenated
    80 features to one scalar prediction per sample.

    Args:
        in_channels: Channel count of the input images (input width of the
            first convolution). Defaults to the module-level CHANNELS.
    """

    def __init__(self, in_channels=CHANNELS):
        super(MultiModalAOTPredictor, self).__init__()

        # --- Image branch (CNN) ---
        # Shape comments assume a 19x19 input. The adaptive pool makes the
        # branch emit 64 features whatever the spatial size.
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # 32 x 9 x 9 (floor(19 / 2))
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),  # 64 x 1 x 1
            nn.Flatten(),  # 64
        )

        # --- Tabular branch (MLP) ---
        # Input: 3 features (elevation, ozone, NO2 in this script).
        self.mlp = nn.Sequential(
            nn.Linear(3, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
        )

        # --- Fusion head ---
        # 64 CNN features + 16 MLP features = 80 inputs -> 1 AOT prediction.
        self.fusion = nn.Sequential(
            nn.Linear(64 + 16, 64),
            nn.ReLU(),
            nn.Dropout(0.2),  # Dropout to reduce overfitting
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, img, tab):
        """Predict AOT for a batch.

        Args:
            img: Image batch, expected (B, in_channels, H, W).
            tab: Tabular batch, expected (B, 3).

        Returns:
            Tensor of shape (B,): one prediction per sample, including B == 1.
        """
        img_out = self.cnn(img)
        tab_out = self.mlp(tab)

        # Concatenate along the feature dimension (dim=1) -> (B, 80).
        combined = torch.cat((img_out, tab_out), dim=1)

        out = self.fusion(combined)  # (B, 1)
        # squeeze(-1) rather than squeeze(): a batch of one must stay (1,),
        # not collapse to a 0-d scalar that cannot be iterated downstream.
        return out.squeeze(-1)
159
160 # ---------------------------------------------------------
161 # 4. Training Function
162 # ---------------------------------------------------------
def train_model():
    """Train a MultiModalAOTPredictor on TRAIN_CSV and save its weights.

    Uses MSE loss, Adam (lr=LEARNING_RATE, weight_decay=1e-4) and gradient
    norm clipping at 1.0 for EPOCHS epochs. The final state dict is written
    to "aot_multimodal_model.pth".

    Returns:
        (model, epoch_losses, scaler_stats): the trained model, the mean
        per-sample training MSE for each epoch, and the training set's
        (mean, std) tabular statistics. The statistics are needed to
        normalise the test set identically.

    Raises:
        ValueError: if TRAIN_CSV contains no rows with an AOT target.
    """
    print(f"Using device: {DEVICE}")

    # 1. Setup datasets
    train_dataset = SentinelAeronetDataset(TRAIN_CSV, IMAGE_DIR, is_test=False)
    if len(train_dataset) == 0:
        # Without this the per-epoch average below divides by zero.
        raise ValueError(f"No labelled rows found in '{TRAIN_CSV}'.")
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Keep the scaler stats so the test set is normalised exactly like the train set.
    scaler_stats = train_dataset.get_scaler_stats()

    # 2. Initialise model
    model = MultiModalAOTPredictor(in_channels=CHANNELS).to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

    # 3. Training loop
    epoch_losses = []
    print("Starting Training...")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0

        for imgs, tabs, targets in train_loader:
            imgs, tabs, targets = imgs.to(DEVICE), tabs.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            # Match the target's shape explicitly: a model output squeezed to
            # a 0-d scalar for a batch of one would otherwise be broadcast by
            # MSELoss (with a size-mismatch warning).
            predictions = model(imgs, tabs).reshape(targets.shape)

            loss = criterion(predictions, targets)
            loss.backward()
            # Gradient clipping to guard against exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # loss is a per-sample mean; weight by batch size for an exact epoch mean.
            running_loss += loss.item() * imgs.size(0)

        epoch_loss = running_loss / len(train_dataset)
        epoch_losses.append(epoch_loss)

        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}], Loss (MSE): {epoch_loss:.6f}")

    print("Training complete!")
    torch.save(model.state_dict(), "aot_multimodal_model.pth")
    return model, epoch_losses, scaler_stats
209
210 # ---------------------------------------------------------
211 # 5. Prediction Function for Submission
212 # ---------------------------------------------------------
def generate_submission(model, scaler_stats):
    """Predict AOT for every row of TEST_CSV and write SUBMISSION_FILE.

    Args:
        model: Trained MultiModalAOTPredictor (already on DEVICE).
        scaler_stats: (mean, std) tabular statistics from the training set,
            so test features are standardised exactly as in training.

    The output CSV has two columns, ``id`` and ``AOT``, in TEST_CSV row order.
    """
    print("\nGenerating predictions for test.csv...")
    model.eval()

    test_dataset = SentinelAeronetDataset(TEST_CSV, IMAGE_DIR, is_test=True, scaler_stats=scaler_stats)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    ids = []
    predictions = []

    with torch.no_grad():
        for imgs, tabs, batch_ids in test_loader:
            imgs, tabs = imgs.to(DEVICE), tabs.to(DEVICE)

            preds = model(imgs, tabs)

            # The default collate stacks numeric ids into a tensor but leaves
            # string ids as a plain list; accept both.
            if isinstance(batch_ids, torch.Tensor):
                ids.extend(batch_ids.numpy())
            else:
                ids.extend(batch_ids)
            # reshape(-1): a final batch of one may come back as a 0-d tensor,
            # whose numpy array cannot be iterated by extend().
            predictions.extend(preds.reshape(-1).cpu().numpy())

    # Create submission dataframe
    submission_df = pd.DataFrame({
        'id': ids,
        'AOT': predictions,
    })

    submission_df.to_csv(SUBMISSION_FILE, index=False)
    print(f"Submission saved to {SUBMISSION_FILE}")
240
241 # ---------------------------------------------------------
242 # 6. Execution
243 # ---------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: train, plot the loss curve, then (if available)
    # predict on the test set. Missing images only warn, because the dataset
    # substitutes blank images for unreadable files.
    if not os.path.exists(IMAGE_DIR):
        print(f"WARNING: Image directory '{IMAGE_DIR}' not found. Please update IMAGE_DIR or place images there.")

    if os.path.exists(TRAIN_CSV):
        # 1. Train
        trained_model, losses, scaler_stats = train_model()

        # Plot loss
        plt.figure(figsize=(8, 5))
        plt.plot(range(1, len(losses) + 1), losses, marker='o', color='r', label='Train MSE')
        plt.title("Multimodal AOT Model Training Loss")
        plt.xlabel("Epoch")
        plt.ylabel("MSE")
        plt.grid(True)
        plt.legend()  # the 'Train MSE' label is only shown with a legend
        plt.savefig("loss_plot.png")
        plt.close()  # release the figure; it is not shown interactively
        print("Saved loss plot to 'loss_plot.png'")

        # 2. Predict on test set
        if os.path.exists(TEST_CSV):
            generate_submission(trained_model, scaler_stats)
        else:
            print(f"Test file '{TEST_CSV}' not found. Skipping submission generation.")
    else:
        print(f"Training file '{TRAIN_CSV}' not found. Please ensure it is in the same directory.")