|
import unittest |
|
|
|
import torch |
|
|
|
from src.data.utils import ( |
|
S2_BANDS, |
|
SPACE_TIME_BANDS, |
|
SPACE_TIME_BANDS_GROUPS_IDX, |
|
construct_galileo_input, |
|
) |
|
|
|
|
|
class TestDataUtils(unittest.TestCase):
    """Tests for the input-construction helpers imported from ``src.data.utils``."""

    def test_construct_galileo_input_s2(self):
        """Check the masks and values produced when only ``s2`` is supplied.

        For both ``normalize=True`` and ``normalize=False`` the test asserts:

        * ``space_mask``, ``time_mask`` and ``static_mask`` are 1 everywhere;
        * ``space_time_mask`` is 1 for every channel group whose key does not
          contain ``"S2"`` and 0 for every group whose key does;
        * when ``normalize`` is False, the S2 channels of ``space_time_x``
          are exactly the input tensor.

        NOTE(review): presumably a mask value of 1 means "masked / not
        provided" and 0 means "present" -- confirm against
        ``construct_galileo_input``.
        """
        t, h, w = 2, 4, 4
        s2 = torch.randn((t, h, w, len(S2_BANDS)))

        # These index lists depend only on module-level constants, so build
        # them once instead of on every loop iteration.
        s2_group_idx = [
            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" in key
        ]
        non_s2_group_idx = [
            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" not in key
        ]
        s2_band_idx = [idx for idx, band in enumerate(SPACE_TIME_BANDS) if band in S2_BANDS]

        for normalize in (True, False):
            # subTest labels a failure with the offending ``normalize`` value
            # and lets the remaining case still run.
            with self.subTest(normalize=normalize):
                masked_output = construct_galileo_input(s2=s2, normalize=normalize)

                # No space-only, time-only or static input was passed in.
                self.assertTrue((masked_output.space_mask == 1).all())
                self.assertTrue((masked_output.time_mask == 1).all())
                self.assertTrue((masked_output.static_mask == 1).all())

                # Only the S2 channel groups should carry a 0 in the mask.
                space_time_mask = masked_output.space_time_mask
                self.assertTrue((space_time_mask[:, :, :, non_s2_group_idx] == 1).all())
                self.assertTrue((space_time_mask[:, :, :, s2_group_idx] == 0).all())

                if not normalize:
                    # Without normalization the S2 values must pass through
                    # untouched, so an exact comparison is appropriate.
                    self.assertTrue(
                        torch.equal(masked_output.space_time_x[:, :, :, s2_band_idx], s2)
                    )
|
|