From: VGoncalo Date: Wed, 20 May 2026 15:34:47 +0000 (+0100) Subject: code executed in ue niia server X-Git-Url: https://vgcfreebox.myrthtech.pt/gitweb/ue-rnap-aerossol.git/commitdiff_plain/refs/heads/master code executed in ue niia server --- diff --git a/submission.csv b/submission.csv new file mode 100644 index 0000000..08b617c --- /dev/null +++ b/submission.csv @@ -0,0 +1,899 @@ +id,AOT +4001,0.093698755 +4002,0.0934955 +4003,0.09320606 +4004,0.09301098 +4005,0.09296049 +4006,0.09242134 +4007,0.092213765 +4008,0.09204534 +4009,0.09170568 +4010,0.091365635 +4011,0.0912025 +4012,0.0908639 +4013,0.090713754 +4014,0.090407714 +4015,0.09009881 +4016,0.08972016 +4017,0.089501694 +4018,0.089420155 +4019,0.0896641 +4020,0.08987181 +4021,0.09009196 +4022,0.091106474 +4023,0.09139861 +4024,0.09165385 +4025,0.08950886 +4026,0.08942166 +4027,0.08987242 +4028,0.09058423 +4029,0.09085217 +4030,0.091106474 +4031,0.091399446 +4032,0.09194684 +4033,0.09202148 +4034,0.09554833 +4035,0.09993254 +4036,0.10862975 +4037,0.119597435 +4038,0.13422403 +4039,0.14346598 +4040,0.15285078 +4041,0.15975782 +4042,0.16305384 +4043,0.16369657 +4044,0.1643016 +4045,0.16102274 +4046,0.15339942 +4047,0.1475589 +4048,0.14188504 +4049,0.13698417 +4050,0.12893839 +4051,0.12359901 +4052,0.1227466 +4053,0.11854093 +4054,0.119914845 +4055,0.12044503 +4056,0.12172927 +4057,0.12253411 +4058,0.123676434 +4059,0.12448305 +4060,0.12562183 +4061,0.12756898 +4062,0.12854733 +4063,0.13165706 +4064,0.13605525 +4065,0.13905393 +4066,0.14376584 +4067,0.14690687 +4068,0.14316912 +4069,0.1427359 +4070,0.14270066 +4071,0.14247605 +4072,0.13667251 +4073,0.12280482 +4074,0.10880937 +4075,0.10693039 +4076,0.10544062 +4077,0.101154044 +4078,0.100112334 +4079,0.098669946 +4080,0.09741275 +4081,0.09612964 +4082,0.09340565 +4083,0.09335306 +4084,0.09266491 +4085,0.092347816 +4086,0.09231754 +4087,0.092031464 +4088,0.09186144 +4089,0.090969756 +4090,0.090636775 +4091,0.09047721 +4092,0.09015469 +4093,0.08987893 
+4094,0.08950926 +4095,0.08987181 +4096,0.09035547 +4097,0.09058423 +4098,0.09085217 +4099,0.09139946 +4100,0.09165385 +4101,0.09202148 +4102,0.09265606 +4103,0.09462674 +4104,0.09637958 +4105,0.0978567 +4106,0.10093103 +4107,0.1034828 +4108,0.12660284 +4109,0.14346284 +4110,0.15771084 +4111,0.16209945 +4112,0.16230504 +4113,0.16369736 +4114,0.15864545 +4115,0.14755124 +4116,0.14188848 +4117,0.13987769 +4118,0.13500558 +4119,0.13260496 +4120,0.12894379 +4121,0.12596735 +4122,0.123600796 +4123,0.12274483 +4124,0.12040852 +4125,0.12044503 +4126,0.123676434 +4127,0.124481276 +4128,0.12854733 +4129,0.130316 +4130,0.13165706 +4131,0.13604124 +4132,0.13905393 +4133,0.1480599 +4134,0.14690687 +4135,0.14628594 +4136,0.14316937 +4137,0.14270024 +4138,0.1424752 +4139,0.14198543 +4140,0.13669325 +4141,0.13165143 +4142,0.12804489 +4143,0.12280987 +4144,0.11990428 +4145,0.110210165 +4146,0.10543738 +4147,0.100112334 +4148,0.098664135 +4149,0.09740625 +4150,0.09522845 +4151,0.09426376 +4152,0.093938306 +4153,0.09360297 +4154,0.09340565 +4155,0.09335306 +4156,0.09287633 +4157,0.09270127 +4158,0.09266491 +4159,0.09235026 +4160,0.09231675 +4161,0.0920331 +4162,0.09150565 +4163,0.09132826 +4164,0.09637958 +4165,0.097858325 +4166,0.09872876 +4167,0.09993072 +4168,0.10622549 +4169,0.1086462 +4170,0.11959238 +4171,0.12953232 +4172,0.14349425 +4173,0.14689872 +4174,0.15290122 +4175,0.16230606 +4176,0.16369759 +4177,0.16430633 +4178,0.16101481 +4179,0.16009386 +4180,0.15341873 +4181,0.1511376 +4182,0.1475589 +4183,0.1398744 +4184,0.12893572 +4185,0.12762451 +4186,0.1259709 +4187,0.12496559 +4188,0.12273599 +4189,0.121237755 +4190,0.11932494 +4191,0.11991066 +4192,0.12253411 +4193,0.12448482 +4194,0.12642488 +4195,0.13031325 +4196,0.13605525 +4197,0.14688331 +4198,0.14805862 +4199,0.14628619 +4200,0.14495264 +4201,0.14380072 +4202,0.14316937 +4203,0.14269942 +4204,0.14198373 +4205,0.13666014 +4206,0.13164203 +4207,0.1228223 +4208,0.11255805 +4209,0.11020802 +4210,0.10881786 
+4211,0.10692693 +4212,0.10543902 +4213,0.103829935 +4214,0.10276316 +4215,0.10114919 +4216,0.0986729 +4217,0.09741275 +4218,0.09612964 +4219,0.09522699 +4220,0.09426376 +4221,0.094193205 +4222,0.093936965 +4223,0.09360297 +4224,0.0934023 +4225,0.09335193 +4226,0.0931557 +4227,0.09203555 +4228,0.09187941 +4229,0.09103516 +4230,0.090707794 +4231,0.09040177 +4232,0.09025288 +4233,0.08949831 +4234,0.08948775 +4235,0.08952136 +4236,0.08984493 +4237,0.091181085 +4238,0.09279272 +4239,0.09533377 +4240,0.09611583 +4241,0.09725161 +4242,0.09812206 +4243,0.10459374 +4244,0.11822097 +4245,0.120958805 +4246,0.12807164 +4247,0.13258661 +4248,0.15048465 +4249,0.15451346 +4250,0.15901576 +4251,0.16060664 +4252,0.1616645 +4253,0.1556833 +4254,0.15416007 +4255,0.14624982 +4256,0.13786806 +4257,0.12969594 +4258,0.12839189 +4259,0.12647267 +4260,0.1254709 +4261,0.13668619 +4262,0.13542064 +4263,0.1350191 +4264,0.13681743 +4265,0.13668787 +4266,0.13888718 +4267,0.14441393 +4268,0.15047865 +4269,0.15519905 +4270,0.15525062 +4271,0.15521716 +4272,0.15341619 +4273,0.14887933 +4274,0.14706238 +4275,0.13779534 +4276,0.11651607 +4277,0.113120005 +4278,0.110670835 +4279,0.10736279 +4280,0.10239899 +4281,0.10423879 +4282,0.095999226 +4283,0.094906375 +4284,0.09335993 +4285,0.08871199 +4286,0.088600725 +4287,0.08841184 +4288,0.08813399 +4289,0.08783397 +4290,0.0874107 +4291,0.08702034 +4292,0.08693574 +4293,0.08661932 +4294,0.0867746 +4295,0.08682412 +4296,0.08690052 +4297,0.0869274 +4298,0.086977065 +4299,0.08702181 +4300,0.08709836 +4301,0.087148026 +4302,0.08729647 +4303,0.08735745 +4304,0.08758204 +4305,0.087793395 +4306,0.088135034 +4307,0.088360175 +4308,0.08872406 +4309,0.08959977 +4310,0.090902254 +4311,0.09224467 +4312,0.093508035 +4313,0.09540065 +4314,0.10101305 +4315,0.10371865 +4316,0.10541263 +4317,0.1138657 +4318,0.13315882 +4319,0.1350394 +4320,0.1358773 +4321,0.136393 +4322,0.1360662 +4323,0.13557144 +4324,0.13524379 +4325,0.13490152 +4326,0.13495357 +4327,0.13520996 
+4328,0.13542907 +4329,0.13598146 +4330,0.13620059 +4331,0.13653027 +4332,0.1367494 +4333,0.13705559 +4334,0.13683237 +4335,0.13690273 +4336,0.13634402 +4337,0.13612036 +4338,0.1364804 +4339,0.13682108 +4340,0.13712858 +4341,0.13746928 +4342,0.13777676 +4343,0.1384245 +4344,0.13900627 +4345,0.13965403 +4346,0.13794623 +4347,0.1363631 +4348,0.1347122 +4349,0.13446774 +4350,0.12980385 +4351,0.12842055 +4352,0.12710476 +4353,0.12412867 +4354,0.12247448 +4355,0.12059064 +4356,0.11953515 +4357,0.11765131 +4358,0.1159676 +4359,0.11405997 +4360,0.11130078 +4361,0.10989578 +4362,0.108287126 +4363,0.10549612 +4364,0.10425024 +4365,0.09900491 +4366,0.09599492 +4367,0.09489916 +4368,0.09335773 +4369,0.09258315 +4370,0.0908287 +4371,0.089148715 +4372,0.088893935 +4373,0.09474872 +4374,0.09603158 +4375,0.09822036 +4376,0.10005315 +4377,0.102681205 +4378,0.107201144 +4379,0.11295502 +4380,0.11950719 +4381,0.12201922 +4382,0.1259496 +4383,0.13520886 +4384,0.13570872 +4385,0.13655639 +4386,0.13622835 +4387,0.13573442 +4388,0.13496685 +4389,0.13532014 +4390,0.1364205 +4391,0.13663921 +4392,0.13679175 +4393,0.13623177 +4394,0.1365136 +4395,0.13716179 +4396,0.13750249 +4397,0.13774355 +4398,0.1380506 +4399,0.1383913 +4400,0.13869877 +4401,0.13903949 +4402,0.13968767 +4403,0.13657151 +4404,0.12912412 +4405,0.12037353 +4406,0.11788593 +4407,0.11681168 +4408,0.11489545 +4409,0.11323707 +4410,0.11148958 +4411,0.10898964 +4412,0.10617854 +4413,0.10487564 +4414,0.0984498 +4415,0.09113969 +4416,0.090264335 +4417,0.08977628 +4418,0.15050109 +4419,0.12644503 +4420,0.100467786 +4421,0.09985261 +4422,0.097856775 +4423,0.096616656 +4424,0.0954604 +4425,0.09426142 +4426,0.091863826 +4427,0.09635633 +4428,0.09857887 +4429,0.09901145 +4430,0.10038744 +4431,0.101805195 +4432,0.10354106 +4433,0.105244145 +4434,0.1071434 +4435,0.109144315 +4436,0.117303476 +4437,0.124700144 +4438,0.12788624 +4439,0.13293345 +4440,0.14490403 +4441,0.15800986 +4442,0.16355468 +4443,0.1717701 +4444,0.1700889 
+4445,0.17815286 +4446,0.18604816 +4447,0.18695906 +4448,0.1783751 +4449,0.16474979 +4450,0.15854515 +4451,0.1505038 +4452,0.12644503 +4453,0.10419005 +4454,0.09985261 +4455,0.098536804 +4456,0.097856775 +4457,0.097176805 +4458,0.0954604 +4459,0.09426139 +4460,0.09306255 +4461,0.089829326 +4462,0.09240878 +4463,0.09843035 +4464,0.09849815 +4465,0.09851043 +4466,0.09857887 +4467,0.10524556 +4468,0.117303476 +4469,0.13871709 +4470,0.14188999 +4471,0.14490403 +4472,0.14811878 +4473,0.15308161 +4474,0.1580076 +4475,0.1635543 +4476,0.16935329 +4477,0.15854515 +4478,0.15050246 +4479,0.1139144 +4480,0.10419005 +4481,0.09985261 +4482,0.097856805 +4483,0.097176865 +4484,0.09661654 +4485,0.09426133 +4486,0.09306249 +4487,0.09075008 +4488,0.08982949 +4489,0.094310805 +4490,0.11118433 +4491,0.12731902 +4492,0.13226257 +4493,0.13532464 +4494,0.14729546 +4495,0.15196161 +4496,0.15754925 +4497,0.16634257 +4498,0.16949223 +4499,0.17229176 +4500,0.19282962 +4501,0.2041332 +4502,0.18894657 +4503,0.17307839 +4504,0.16580717 +4505,0.15958546 +4506,0.10545562 +4507,0.10087633 +4508,0.099955246 +4509,0.09864871 +4510,0.09796885 +4511,0.09740858 +4512,0.09672861 +4513,0.09572412 +4514,0.09332608 +4515,0.09212713 +4516,0.090951264 +4517,0.090032846 +4518,0.09055714 +4519,0.103167936 +4520,0.104859665 +4521,0.10675855 +4522,0.113729864 +4523,0.107894555 +4524,0.105991915 +4525,0.10106027 +4526,0.09859429 +4527,0.09648207 +4528,0.09483586 +4529,0.093522295 +4530,0.09290032 +4531,0.09227835 +4532,0.09113927 +4533,0.09058902 +4534,0.090121254 +4535,0.08926915 +4536,0.08898629 +4537,0.08862211 +4538,0.08896677 +4539,0.08930293 +4540,0.08997981 +4541,0.09040442 +4542,0.09174712 +4543,0.092373505 +4544,0.09273626 +4545,0.093087226 +4546,0.09343819 +4547,0.094012424 +4548,0.09505521 +4549,0.09564626 +4550,0.095346734 +4551,0.09511289 +4552,0.0945161 +4553,0.09413271 +4554,0.0944543 +4555,0.09443031 +4556,0.09468113 +4557,0.09494825 +4558,0.09567504 +4559,0.09640892 +4560,0.09754254 
+4561,0.09824251 +4562,0.09895907 +4563,0.101051405 +4564,0.10268338 +4565,0.105115995 +4566,0.10680927 +4567,0.1085072 +4568,0.10964434 +4569,0.11179425 +4570,0.110781536 +4571,0.11033453 +4572,0.109334975 +4573,0.10834025 +4574,0.107892185 +4575,0.103136316 +4576,0.09859429 +4577,0.09648207 +4578,0.09483586 +4579,0.09352319 +4580,0.09290032 +4581,0.092275515 +4582,0.09168869 +4583,0.09113927 +4584,0.09058818 +4585,0.08926915 +4586,0.08898668 +4587,0.08846168 +4588,0.08930293 +4589,0.08997907 +4590,0.09237538 +4591,0.093087226 +4592,0.09343819 +4593,0.09447524 +4594,0.09505521 +4595,0.095641315 +4596,0.095346734 +4597,0.09511289 +4598,0.0945161 +4599,0.09413271 +4600,0.09396663 +4601,0.094196424 +4602,0.0944543 +4603,0.09468113 +4604,0.096039504 +4605,0.09640725 +4606,0.09895907 +4607,0.09978445 +4608,0.101051405 +4609,0.10186939 +4610,0.10268538 +4611,0.10400428 +4612,0.1096396 +4613,0.11133808 +4614,0.111789495 +4615,0.110781536 +4616,0.11033216 +4617,0.109334975 +4618,0.10834025 +4619,0.107894555 +4620,0.103136316 +4621,0.09859429 +4622,0.09483586 +4623,0.09290032 +4624,0.09227835 +4625,0.09168869 +4626,0.09113842 +4627,0.090122774 +4628,0.089776725 +4629,0.08898707 +4630,0.088461936 +4631,0.088752344 +4632,0.08930369 +4633,0.08997831 +4634,0.09237538 +4635,0.09273626 +4636,0.093087226 +4637,0.09343819 +4638,0.094012424 +4639,0.09505521 +4640,0.095346734 +4641,0.0941339 +4642,0.09396444 +4643,0.09396663 +4644,0.094196424 +4645,0.09470381 +4646,0.09468113 +4647,0.09494959 +4648,0.09567174 +4649,0.098244205 +4650,0.09895907 +4651,0.09978251 +4652,0.10186939 +4653,0.10268538 +4654,0.1040086 +4655,0.10512066 +4656,0.10680927 +4657,0.1085072 +4658,0.1096396 +4659,0.11134523 +4660,0.111789495 +4661,0.11077915 +4662,0.11033453 +4663,0.109334975 +4664,0.107892185 +4665,0.10598488 +4666,0.10106228 +4667,0.09859772 +4668,0.094831765 +4669,0.09289749 +4670,0.09227739 +4671,0.0905907 +4672,0.09012279 +4673,0.089776725 +4674,0.08926915 +4675,0.08898668 +4676,0.088810965 
+4677,0.088461936 +4678,0.08875272 +4679,0.08896598 +4680,0.08997907 +4681,0.09039082 +4682,0.09092903 +4683,0.091669604 +4684,0.09226565 +4685,0.09301533 +4686,0.09336914 +4687,0.09462696 +4688,0.097006604 +4689,0.09962 +4690,0.10088934 +4691,0.102521375 +4692,0.10368723 +4693,0.10532877 +4694,0.10985239 +4695,0.1105344 +4696,0.1095373 +4697,0.108092055 +4698,0.10657756 +4699,0.10368289 +4700,0.10156228 +4701,0.09904529 +4702,0.09689812 +4703,0.09519598 +4704,0.09306584 +4705,0.09244391 +4706,0.09183617 +4707,0.091285065 +4708,0.09073484 +4709,0.08937113 +4710,0.08886264 +4711,0.08860929 +4712,0.0885883 +4713,0.088923946 +4714,0.089476004 +4715,0.08989596 +4716,0.0904655 +4717,0.09124267 +4718,0.09268424 +4719,0.092806265 +4720,0.09336819 +4721,0.09350915 +4722,0.09372009 +4723,0.15603328 +4724,0.1564891 +4725,0.15610193 +4726,0.1340294 +4727,0.12081754 +4728,0.116007224 +4729,0.116678104 +4730,0.11772539 +4731,0.15861523 +4732,0.15821946 +4733,0.1415671 +4734,0.11476153 +4735,0.112388045 +4736,0.1092713 +4737,0.10442947 +4738,0.104265064 +4739,0.10416539 +4740,0.10409905 +4741,0.1039999 +4742,0.10548146 +4743,0.10649876 +4744,0.10990821 +4745,0.11289181 +4746,0.11556922 +4747,0.115361765 +4748,0.115774184 +4749,0.117459714 +4750,0.120900705 +4751,0.1302199 +4752,0.15056245 +4753,0.15162143 +4754,0.15203215 +4755,0.15476906 +4756,0.15567331 +4757,0.15715507 +4758,0.15665567 +4759,0.15245998 +4760,0.11710377 +4761,0.11640014 +4762,0.14047684 +4763,0.1148399 +4764,0.11480759 +4765,0.11061774 +4766,0.10945195 +4767,0.10831463 +4768,0.10714671 +4769,0.10601263 +4770,0.1043642 +4771,0.10426451 +4772,0.10416485 +4773,0.10443245 +4774,0.10548146 +4775,0.108731374 +4776,0.11289181 +4777,0.1155584 +4778,0.11551249 +4779,0.115361765 +4780,0.11577849 +4781,0.120900705 +4782,0.12506957 +4783,0.1302199 +4784,0.13767289 +4785,0.14534725 +4786,0.15056245 +4787,0.15162168 +4788,0.15203215 +4789,0.1534238 +4790,0.15476932 +4791,0.15567331 +4792,0.15657757 +4793,0.15715507 
+4794,0.15245618 +4795,0.12795451 +4796,0.1164027 +4797,0.11716525 +4798,0.11731081 +4799,0.117458925 +4800,0.12975325 +4801,0.115822166 +4802,0.1149282 +4803,0.11480759 +4804,0.11474417 +4805,0.11276162 +4806,0.106011 +4807,0.104879096 +4808,0.1043642 +4809,0.10416539 +4810,0.1039999 +4811,0.107566625 +4812,0.10873018 +4813,0.1099094 +4814,0.111290246 +4815,0.112892985 +4816,0.11449815 +4817,0.11558002 +4818,0.115512446 +4819,0.11543889 +4820,0.115361765 +4821,0.115774184 +4822,0.117461935 +4823,0.120903686 +4824,0.12506957 +4825,0.1302199 +4826,0.13767289 +4827,0.1453712 +4828,0.15056743 +4829,0.15162143 +4830,0.1534238 +4831,0.15567331 +4832,0.15715672 +4833,0.15664907 +4834,0.15245998 +4835,0.1446038 +4836,0.13583507 +4837,0.12792733 +4838,0.12189618 +4839,0.11709359 +4840,0.11640014 +4841,0.117308244 +4842,0.11745131 +4843,0.11759396 +4844,0.1584254 +4845,0.15982053 +4846,0.119740754 +4847,0.115820765 +4848,0.11476162 +4849,0.10487746 +4850,0.10649876 +4851,0.107566625 +4852,0.10873078 +4853,0.11288941 +4854,0.115580186 +4855,0.11556922 +4856,0.115774184 +4857,0.117460266 +4858,0.120903686 +4859,0.12507263 +4860,0.13022311 +4861,0.13767289 +4862,0.14534725 +4863,0.15056247 +4864,0.1534238 +4865,0.1547675 +4866,0.15567331 +4867,0.1565773 +4868,0.15715338 +4869,0.15665567 +4870,0.15245998 +4871,0.1446038 +4872,0.12792733 +4873,0.12189366 +4874,0.11710884 +4875,0.1164027 +4876,0.11716269 +4877,0.11731081 +4878,0.11975387 +4879,0.11492817 +4880,0.11483988 +4881,0.11480765 +4882,0.11476162 +4883,0.10445003 +4884,0.10408576 +4885,0.1040194 +4886,0.10526425 +4887,0.1073349 +4888,0.10850662 +4889,0.11096631 +4890,0.11256848 +4891,0.114185706 +4892,0.11552365 +4893,0.12459682 +4894,0.11627269 +4895,0.11503421 +4896,0.1108439 +4897,0.10967596 +4898,0.10854189 diff --git a/train-aot-multimodal-tifffile.py b/train-aot-multimodal-tifffile.py new file mode 100644 index 0000000..996422b --- /dev/null +++ b/train-aot-multimodal-tifffile.py @@ -0,0 +1,268 @@ +import os +import 
import os
import warnings

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import tifffile  # Replaced cv2 with tifffile for multi-channel support

# ---------------------------------------------------------
# 1. Configuration & Hyperparameters
# ---------------------------------------------------------
TRAIN_CSV = 'data/SENTINEL2-AEROSSOL-DATA/train.csv'
TEST_CSV = 'data/SENTINEL2-AEROSSOL-DATA/test.csv'
IMAGE_DIR = 'data/SENTINEL2-AEROSSOL-DATA/tiff-images/'  # Folder containing the .tiff images
SUBMISSION_FILE = 'submission.csv'

BATCH_SIZE = 32
LEARNING_RATE = 0.001
EPOCHS = 50
CHANNELS = 20  # Bands per image (an earlier OpenCV error log reported 20 channels)
IMAGE_SIZE = 19  # Height/width of an image patch; also the size of the blank fallback image
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ---------------------------------------------------------
# 2. Custom Dataset Definition (Multi-modal)
# ---------------------------------------------------------
class SentinelAeronetDataset(Dataset):
    """Pairs standardized tabular features with multi-channel .tiff images.

    Training items are ``(image, tabular, target)``. Test items are
    ``(image, tabular, id)``. ``image`` is a float32 tensor in
    (channels, height, width) order when the file stores its bands either
    first or last. ``tabular`` holds the standardized Elevation, Ozone and
    NO2 values.
    """

    def __init__(self, csv_file, image_dir, is_test=False, scaler_stats=None):
        """Load the CSV, impute missing tabular values and standardize them.

        Args:
            csv_file: CSV with an ``img_name`` column, Elevation/Ozone/NO2
                columns (matched case-insensitively), an AOT column for
                training data and an ``id`` column for test data.
            image_dir: Directory that the ``img_name`` entries are relative to.
            is_test: When True, no target is read and items carry the row id.
            scaler_stats: Optional ``(mean, std)`` pair, normally taken from the
                training dataset's ``get_scaler_stats()``. When None, the
                statistics are computed from this CSV.
        """
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.is_test = is_test
        self._warned_unreadable = False  # report only the first unreadable image

        # Resolve column names case-insensitively, falling back to the canonical spelling.
        col_map = {c.lower(): c for c in self.data.columns}
        self.col_elev = col_map.get('elevation', 'Elevation')
        self.col_ozone = col_map.get('ozone', 'Ozone')
        self.col_no2 = col_map.get('no2', 'NO2')
        self.col_aot = col_map.get('aot', 'AOT')
        tab_cols = [self.col_elev, self.col_ozone, self.col_no2]

        if not self.is_test:
            # Rows without a target cannot be trained on. reset_index yields an
            # independent, positionally indexed frame before the columns below
            # are reassigned.
            self.data = self.data.dropna(subset=[self.col_aot]).reset_index(drop=True)

        # Whole-column assignment (not a chained inplace fillna) avoids pandas'
        # chained-assignment pitfalls.
        # NOTE(review): test data is imputed with its own medians, not the training
        # medians; consider carrying the train medians alongside scaler_stats.
        for col in tab_cols:
            self.data[col] = self.data[col].fillna(self.data[col].median())

        self.img_names = self.data['img_name'].values
        self.X_tab = self.data[tab_cols].values.astype(np.float32)

        # StandardScaler-style normalization; test data reuses the training statistics.
        if scaler_stats is None:
            self.mean = self.X_tab.mean(axis=0)
            self.std = self.X_tab.std(axis=0)
        else:
            self.mean, self.std = scaler_stats
        self.X_tab = (self.X_tab - self.mean) / (self.std + 1e-8)

        # The target is only available in the training data.
        if not self.is_test:
            self.y = self.data[self.col_aot].values.astype(np.float32)

    def get_scaler_stats(self):
        """Return the ``(mean, std)`` arrays used to standardize the tabular features."""
        return self.mean, self.std

    def __len__(self):
        return len(self.data)

    def _load_image(self, img_path):
        """Read one .tiff as a float32 (channels, height, width) array.

        A missing or corrupt file is replaced by a blank image so that one bad
        file does not abort training. The first such failure is reported.
        """
        try:
            img = tifffile.imread(img_path).astype(np.float32)
        except Exception as exc:  # best-effort: keep training on unreadable files
            if not getattr(self, '_warned_unreadable', False):
                warnings.warn(
                    f"Could not read {img_path} ({exc}); using a blank image. "
                    "Further unreadable images will be replaced silently."
                )
                self._warned_unreadable = True
            return np.zeros((CHANNELS, IMAGE_SIZE, IMAGE_SIZE), dtype=np.float32)

        img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)
        # Divide raw values by 10000 (presumably the Sentinel-2 reflectance
        # quantification value -- confirm against the data source).
        img = img / 10000.0

        # PyTorch expects (C, H, W), but tifffile may return (H, W, C). Transpose
        # only when the band axis is clearly last. Testing the last axis against
        # the spatial size (19) would wrongly transpose an already channel-first
        # (20, 19, 19) array into (19, 20, 19).
        if img.ndim == 3 and img.shape[0] != CHANNELS and img.shape[-1] == CHANNELS:
            img = img.transpose(2, 0, 1)
        return img

    def __getitem__(self, idx):
        tab_features = torch.tensor(self.X_tab[idx])

        img_path = os.path.join(self.image_dir, self.img_names[idx])
        img_features = torch.tensor(self._load_image(img_path))

        if self.is_test:
            return img_features, tab_features, self.data['id'].iloc[idx]
        target = torch.tensor(self.y[idx])
        return img_features, tab_features, target
# ---------------------------------------------------------
# 3. Multi-modal DNN Model Definition
# ---------------------------------------------------------
class MultiModalAOTPredictor(nn.Module):
    """Fuses a small CNN over the image with an MLP over tabular features to predict AOT."""

    def __init__(self, in_channels=CHANNELS):
        super().__init__()

        # --- Image Branch (CNN) ---
        # Input: `in_channels` (20), 19x19 pixels
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # Output: 32 x 9 x 9
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),  # Output: 64 x 1 x 1
            nn.Flatten(),  # Output: 64
        )

        # --- Tabular Branch (MLP) ---
        # Input: 3 features (Elevation, Ozone, NO2)
        self.mlp = nn.Sequential(
            nn.Linear(3, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
        )

        # --- Fusion Branch ---
        # 64 CNN features + 16 MLP features = 80 inputs
        self.fusion = nn.Sequential(
            nn.Linear(64 + 16, 64),
            nn.ReLU(),
            nn.Dropout(0.2),  # Dropout to reduce overfitting
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),  # Predict AOT
        )

    def forward(self, img, tab):
        """Return one AOT prediction per sample, with shape (batch,)."""
        img_out = self.cnn(img)  # (batch, 64)
        tab_out = self.mlp(tab)  # (batch, 16)

        # Concatenate along the feature dimension (dim=1), not the batch dimension.
        combined = torch.cat((img_out, tab_out), dim=1)

        # squeeze(-1) drops only the size-1 output axis. A bare squeeze() would
        # also drop the batch axis for a batch of one and return a 0-d tensor.
        return self.fusion(combined).squeeze(-1)


# ---------------------------------------------------------
# 4. Training Function
# ---------------------------------------------------------
def train_model():
    """Train MultiModalAOTPredictor on TRAIN_CSV and save its weights.

    Returns:
        A ``(model, epoch_losses, scaler_stats)`` tuple: the trained model, the
        per-sample mean MSE of each epoch, and the training set's tabular
        ``(mean, std)`` for normalizing test data.
    """
    print(f"Using device: {DEVICE}")

    # 1. Setup Datasets
    train_dataset = SentinelAeronetDataset(TRAIN_CSV, IMAGE_DIR, is_test=False)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Save the scaler stats so the test set is normalized exactly like the train set.
    scaler_stats = train_dataset.get_scaler_stats()

    # 2. Initialize Model
    model = MultiModalAOTPredictor(in_channels=CHANNELS).to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

    # 3. Training Loop
    epoch_losses = []
    print("Starting Training...")

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0

        for imgs, tabs, targets in train_loader:
            imgs, tabs, targets = imgs.to(DEVICE), tabs.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            predictions = model(imgs, tabs)

            loss = criterion(predictions, targets)
            loss.backward()
            # Clip gradients to guard against exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Weight by batch size so the epoch loss is a per-sample mean.
            running_loss += loss.item() * imgs.size(0)

        epoch_loss = running_loss / len(train_dataset)
        epoch_losses.append(epoch_loss)

        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}], Loss (MSE): {epoch_loss:.6f}")

    print("Training complete!")
    torch.save(model.state_dict(), "aot_multimodal_model.pth")
    return model, epoch_losses, scaler_stats
# ---------------------------------------------------------
# 5. Prediction Function for Submission
# ---------------------------------------------------------
def generate_submission(model, scaler_stats):
    """Predict AOT for every row of TEST_CSV and write ``id,AOT`` rows to SUBMISSION_FILE.

    Args:
        model: Trained MultiModalAOTPredictor; it is switched to eval mode here.
        scaler_stats: ``(mean, std)`` from the training dataset, so the test
            tabular features are standardized with the training statistics.
    """
    print("\nGenerating predictions for test.csv...")
    model.eval()

    test_dataset = SentinelAeronetDataset(TEST_CSV, IMAGE_DIR, is_test=True, scaler_stats=scaler_stats)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    ids = []
    predictions = []

    with torch.no_grad():
        for imgs, tabs, batch_ids in test_loader:
            preds = model(imgs.to(DEVICE), tabs.to(DEVICE))

            # Numeric ids are collated into a tensor, but non-numeric ids into a
            # plain list; accept both.
            ids.extend(batch_ids.tolist() if torch.is_tensor(batch_ids) else batch_ids)
            # reshape(-1) keeps the batch 1-D even if the model squeezed a batch
            # of one to a 0-d tensor, which extend() cannot iterate.
            predictions.extend(preds.reshape(-1).cpu().numpy())

    # Create submission dataframe
    submission_df = pd.DataFrame({
        'id': ids,
        'AOT': predictions,
    })

    submission_df.to_csv(SUBMISSION_FILE, index=False)
    print(f"Submission saved to {SUBMISSION_FILE}")


# ---------------------------------------------------------
# 6. Execution
# ---------------------------------------------------------
if __name__ == "__main__":
    if not os.path.exists(IMAGE_DIR):
        print(f"WARNING: Image directory '{IMAGE_DIR}' not found. Please update IMAGE_DIR or place images there.")

    if os.path.exists(TRAIN_CSV):
        # 1. Train
        trained_model, losses, scaler_stats = train_model()

        # Plot the per-epoch training loss.
        plt.figure(figsize=(8, 5))
        plt.plot(range(1, len(losses) + 1), losses, marker='o', color='r', label='Train MSE')
        plt.title("Multimodal AOT Model Training Loss")
        plt.xlabel("Epoch")
        plt.ylabel("MSE")
        plt.grid(True)
        plt.legend()  # the 'Train MSE' label is only shown with a legend
        plt.savefig("loss_plot.png")
        plt.close()  # release the figure once it is saved
        print("Saved loss plot to 'loss_plot.png'")

        # 2. Predict on Test set
        if os.path.exists(TEST_CSV):
            generate_submission(trained_model, scaler_stats)
        else:
            print(f"Test file '{TEST_CSV}' not found. Skipping submission generation.")
    else:
        print(f"Training file '{TRAIN_CSV}' not found. Please ensure it is in the same directory.")