6 import torch
.optim
as optim
7 from torch
.utils
.data
import Dataset
, DataLoader
8 import matplotlib
.pyplot
as plt
9 import tifffile
# Replaced cv2 with tifffile for multi-channel support
# ---------------------------------------------------------
# 1. Configuration & Hyperparameters
# ---------------------------------------------------------
# Locations of the input CSVs, the image folder and the output file.
TRAIN_CSV = 'data/SENTINEL2-AEROSSOL-DATA/train.csv'
TEST_CSV = 'data/SENTINEL2-AEROSSOL-DATA/test.csv'
IMAGE_DIR = 'data/SENTINEL2-AEROSSOL-DATA/tiff-images/'  # Update this to the folder containing your .tiff images
SUBMISSION_FILE = 'submission.csv'

# NOTE(review): BATCH_SIZE, EPOCHS and LEARNING_RATE are used by the training
# and prediction code but are not defined in the visible part of this
# section -- confirm they are declared alongside these constants.

# Number of image channels fed to the CNN branch.
CHANNELS = 20  # Updated to 20 based on the OpenCV error log showing 20 channels

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
25 # ---------------------------------------------------------
26 # 2. Custom Dataset Definition (Multi-modal)
27 # ---------------------------------------------------------
class SentinelAeronetDataset(Dataset):
    """Multi-modal dataset pairing tabular features with a multi-channel image.

    Every CSV row provides three tabular features (elevation, ozone, NO2), an
    ``img_name`` that is joined with ``image_dir`` to locate the .tiff file
    and, for training data, the AOT target.
    """

    def __init__(self, csv_file, image_dir, is_test=False, scaler_stats=None):
        """
        Loads both tabular data and multi-channel Sentinel-2 .tiff images.

        Args:
            csv_file: Path or file-like object handed to ``pd.read_csv``.
            image_dir: Directory that is joined with each row's ``img_name``.
            is_test: When True no target is read and ``__getitem__`` returns
                the row's ``id`` instead of a target.
            scaler_stats: Optional ``(mean, std)`` pair used to normalise the
                tabular features. When None the statistics are computed from
                this CSV (use ``get_scaler_stats`` to reuse them elsewhere).
        """
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.is_test = is_test

        # Ensure column names match (handle potential capitalization mismatches)
        col_map = {c.lower(): c for c in self.data.columns}
        self.col_elev = col_map.get('elevation', 'Elevation')
        self.col_ozone = col_map.get('ozone', 'Ozone')
        self.col_no2 = col_map.get('no2', 'NO2')
        self.col_aot = col_map.get('aot', 'AOT')

        # Rows without a target cannot be trained on. Test data carries no
        # target column, so the filter only applies to training data.
        if not self.is_test:
            self.data = self.data.dropna(subset=[self.col_aot])

        # FIX: Pandas ChainedAssignmentError (Modern approach to fillna):
        # assign the filled column back instead of filling in place.
        for col in (self.col_elev, self.col_ozone, self.col_no2):
            self.data[col] = self.data[col].fillna(self.data[col].median())

        self.img_names = self.data['img_name'].values
        feature_cols = [self.col_elev, self.col_ozone, self.col_no2]
        self.X_tab = self.data[feature_cols].values.astype(np.float32)

        # Normalization (StandardScaler logic)
        if scaler_stats is None:
            self.mean = self.X_tab.mean(axis=0)
            self.std = self.X_tab.std(axis=0)
        else:
            self.mean, self.std = scaler_stats

        # The epsilon avoids division by zero for constant columns. The cast
        # keeps the features float32 even if the supplied stats are float64.
        self.X_tab = ((self.X_tab - self.mean) / (self.std + 1e-8)).astype(np.float32)

        # Target variable (Only available in train)
        if not self.is_test:
            self.y = self.data[self.col_aot].values.astype(np.float32)

    def get_scaler_stats(self):
        """Return the ``(mean, std)`` pair used to normalise the tabular features."""
        return self.mean, self.std

    def __len__(self):
        """Return the number of rows kept after filtering."""
        return len(self.data)

    @staticmethod
    def _to_channels_first(img, channels):
        """Return ``img`` as (C, H, W), transposing only a clear (H, W, C) layout.

        The channel count decides the layout. Testing the last axis against
        the spatial size as well would mis-transpose an image that is already
        channel-first, because its last axis is the image width.
        """
        if img.ndim == 3 and img.shape[0] != channels and img.shape[2] == channels:
            return img.transpose(2, 0, 1)
        return img

    def __getitem__(self, idx):
        """Return ``(image, tabular, target)``, or ``(image, tabular, id)`` in test mode."""
        # 1. Load Tabular features
        tab_features = torch.tensor(self.X_tab[idx])

        # 2. Load Image using tifffile
        img_path = os.path.join(self.image_dir, self.img_names[idx])

        try:
            # tifffile can read arbitrarily sized matrices (e.g. 20 channels)
            img = tifffile.imread(img_path)

            # Basic fallback if image is totally empty
            if img is None or img.size == 0:
                img = np.zeros((CHANNELS, 19, 19), dtype=np.float32)
            else:
                img = img.astype(np.float32)
                img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)
                # tifffile may return (H, W, C) or (C, H, W); PyTorch expects
                # (C, H, W).
                img = self._to_channels_first(img, CHANNELS)
        except Exception:
            # Deliberate best effort: if the file is missing or corrupted,
            # return a blank tensor so training doesn't crash.
            img = np.zeros((CHANNELS, 19, 19), dtype=np.float32)

        img_features = torch.tensor(img)

        if self.is_test:
            return img_features, tab_features, self.data['id'].iloc[idx]

        target = torch.tensor(self.y[idx])
        return img_features, tab_features, target
111 # ---------------------------------------------------------
112 # 3. Multi-modal DNN Model Definition
113 # ---------------------------------------------------------
114 class MultiModalAOTPredictor(nn
.Module
):
115 def __init__(self
, in_channels
=CHANNELS
):
116 super(MultiModalAOTPredictor
, self
).__init
__()
118 # --- Image Branch (CNN) ---
119 # Input: `in_channels` (20), 19x19 pixels
120 self
.cnn
= nn
.Sequential(
121 nn
.Conv2d(in_channels
=in_channels
, out_channels
=32, kernel_size
=3, padding
=1),
123 nn
.MaxPool2d(kernel_size
=2), # Output: 32 x 9 x 9
124 nn
.Conv2d(in_channels
=32, out_channels
=64, kernel_size
=3, padding
=1),
126 nn
.AdaptiveAvgPool2d((1, 1)), # Output: 64 x 1 x 1
127 nn
.Flatten() # Output: 64
130 # --- Tabular Branch (MLP) ---
131 # Input: 3 features (Elevation, Ozone, NO2)
132 self
.mlp
= nn
.Sequential(
140 # Combines 64 features from CNN + 16 features from MLP = 80 total
141 self
.fusion
= nn
.Sequential(
142 nn
.Linear(64 + 16, 64),
144 nn
.Dropout(0.2), # Dropout to prevent overfitting
147 nn
.Linear(32, 1) # Predict AOT
150 def forward(self
, img
, tab
):
151 img_out
= self
.cnn(img
)
152 tab_out
= self
.mlp(tab
)
        # Concatenate image and tabular features along the feature dimension (dim=1); the batch dimension is unchanged
155 combined
= torch
.cat((img_out
, tab_out
), dim
=1)
157 out
= self
.fusion(combined
)
160 # ---------------------------------------------------------
161 # 4. Training Function
162 # ---------------------------------------------------------
def train_model():
    """Train the multi-modal AOT model on TRAIN_CSV and save its weights.

    Returns:
        A ``(model, epoch_losses, scaler_stats)`` tuple: the trained model,
        the mean training MSE of every epoch, and the tabular normalisation
        statistics of the training set.
    """
    print(f"Using device: {DEVICE}")

    train_dataset = SentinelAeronetDataset(TRAIN_CSV, IMAGE_DIR, is_test=False)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Save scaler stats so the test set is normalized exactly like the train set
    scaler_stats = train_dataset.get_scaler_stats()

    # 2. Initialize Model
    model = MultiModalAOTPredictor(in_channels=CHANNELS).to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

    epoch_losses = []
    print("Starting Training...")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0

        for imgs, tabs, targets in train_loader:
            imgs, tabs, targets = imgs.to(DEVICE), tabs.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            predictions = model(imgs, tabs)

            # Flatten both sides so MSELoss never broadcasts a (B, 1) output
            # against (B,) targets; this is a no-op when shapes already agree.
            loss = criterion(predictions.reshape(-1), targets.reshape(-1))
            loss.backward()

            # --- NEW: Gradient Clipping to prevent "Exploding Gradients" ---
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Weight the batch mean by batch size so the epoch value is a
            # per-sample mean even when the last batch is smaller.
            running_loss += loss.item() * imgs.size(0)

        epoch_loss = running_loss / len(train_dataset)
        epoch_losses.append(epoch_loss)

        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}], Loss (MSE): {epoch_loss:.6f}")

    print("Training complete!")
    torch.save(model.state_dict(), "aot_multimodal_model.pth")
    return model, epoch_losses, scaler_stats
210 # ---------------------------------------------------------
211 # 5. Prediction Function for Submission
212 # ---------------------------------------------------------
213 def generate_submission(model
, scaler_stats
):
214 print("\nGenerating predictions for test.csv...")
217 test_dataset
= SentinelAeronetDataset(TEST_CSV
, IMAGE_DIR
, is_test
=True, scaler_stats
=scaler_stats
)
218 test_loader
= DataLoader(test_dataset
, batch_size
=BATCH_SIZE
, shuffle
=False)
223 with torch
.no_grad():
224 for imgs
, tabs
, batch_ids
in test_loader
:
225 imgs
, tabs
= imgs
.to(DEVICE
), tabs
.to(DEVICE
)
227 preds
= model(imgs
, tabs
)
229 ids
.extend(batch_ids
.numpy())
230 predictions
.extend(preds
.cpu().numpy())
232 # Create submission dataframe
233 submission_df
= pd
.DataFrame({
238 submission_df
.to_csv(SUBMISSION_FILE
, index
=False)
239 print(f
"Submission saved to {SUBMISSION_FILE}")
241 # ---------------------------------------------------------
243 # ---------------------------------------------------------
if __name__ == "__main__":
    # A missing image folder is only reported; the run continues.
    if not os.path.exists(IMAGE_DIR):
        print(f"WARNING: Image directory '{IMAGE_DIR}' not found. Please update IMAGE_DIR or place images there.")

    if not os.path.exists(TRAIN_CSV):
        print(f"Training file '{TRAIN_CSV}' not found. Please ensure it is in the same directory.")
    else:
        # 1. Train and keep the scaler stats for the test set.
        trained_model, losses, scaler_stats = train_model()

        # Plot the per-epoch training loss, numbering epochs from 1.
        # NOTE(review): no axis-label or legend call is visible in this
        # section; the `label` below is only drawn if plt.legend() is
        # called -- confirm.
        epoch_axis = range(1, len(losses) + 1)
        plt.figure(figsize=(8, 5))
        plt.plot(epoch_axis, losses, marker='o', color='r', label='Train MSE')
        plt.title("Multimodal AOT Model Training Loss")
        plt.savefig("loss_plot.png")
        print("Saved loss plot to 'loss_plot.png'")

        # 2. Predict on Test set
        if not os.path.exists(TEST_CSV):
            print(f"Test file '{TEST_CSV}' not found. Skipping submission generation.")
        else:
            generate_submission(trained_model, scaler_stats)