import json
import tempfile
import unittest
from pathlib import Path

import torch
from einops import repeat

from single_file_galileo import Encoder as SingleFileEncoder
from src.data import (
    SPACE_BAND_GROUPS_IDX,
    SPACE_TIME_BANDS_GROUPS_IDX,
    STATIC_BAND_GROUPS_IDX,
    TIME_BAND_GROUPS_IDX,
    Dataset,
)
from src.data.config import CONFIG_FILENAME, ENCODER_FILENAME
from src.data.dataset import DatasetOutput
from src.galileo import Decoder, Encoder
from src.masking import (
    MASKING_MODES,
    MaskingFunctions,
    batch_mask_space,
    batch_mask_time,
    batch_subset_mask_galileo,
)
from src.utils import device, load_check_config

DATA_FOLDER = Path(__file__).parents[1] / "data"
TIFS_FOLDER = DATA_FOLDER / "tifs"
TEST_MODEL_FOLDER = Path(__file__).parents[0] / "141"


class TestGalileo(unittest.TestCase):
    @staticmethod
    def to_tensor_with_batch_d(input: DatasetOutput):
        return (
            torch.from_numpy(input.space_time_x).float().unsqueeze(0),
            torch.from_numpy(input.space_x).float().unsqueeze(0),
            torch.from_numpy(input.time_x).float().unsqueeze(0),
            torch.from_numpy(input.static_x).float().unsqueeze(0),
            torch.from_numpy(input.months).long().unsqueeze(0),
        )

    def test_end_to_end(self):
        self._end_to_end_run(16, 8)

    def test_end_to_end_different_inputs_per_dim_than_default(self):
        self._end_to_end_run(16, 4)
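
    # Shared driver for the two end-to-end tests above: mask a sample from the test
    # TIFFs, run the encoder and decoder, then check output shapes, the absence of
    # NaNs, and that gradients flow through the summed decoder outputs.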
    def _end_to_end_run(self, embedding_size, patch_size):
        image_size = patch_size * 4
        num_timesteps = 3
        encoder = Encoder(embedding_size=embedding_size, num_heads=1)
        decoder = Decoder(
            encoder_embedding_size=embedding_size,
            decoder_embedding_size=embedding_size,
            num_heads=1,
        )
        ds = Dataset(TIFS_FOLDER, False)
        for i in range(len(ds)):
            s_t_x, sp_x, t_x, st_x, months = self.to_tensor_with_batch_d(ds[i])
            masked_output = batch_subset_mask_galileo(
                s_t_x,
                sp_x,
                t_x,
                st_x,
                months,
                encode_ratio=0.25,
                decode_ratio=0.25,
                patch_size=patch_size,
                image_size=image_size,
                num_timesteps=num_timesteps,
                augmentation_strategies=None,
                masking_probabilities=[1] * len(MASKING_MODES),
                masking_function=MaskingFunctions.SPACE,
                max_unmasking_channels=4,
            )

            with torch.autocast(device_type=device.type, dtype=torch.float16):
                encoder_output = encoder(
                    masked_output.space_time_x,
                    masked_output.space_x,
                    masked_output.time_x,
                    masked_output.static_x,
                    masked_output.space_time_mask,
                    masked_output.space_mask,
                    masked_output.time_mask,
                    masked_output.static_mask,
                    masked_output.months.long(),
                    patch_size=patch_size,
                )
                output = decoder(*encoder_output)

            with torch.no_grad():
                t_s_t, t_sp, t_t, t_st, _, _, _, _ = encoder.apply_linear_projection(
                    masked_output.space_time_x,
                    masked_output.space_x,
                    masked_output.time_x,
                    masked_output.static_x,
                    ~(masked_output.space_time_mask == 2),
                    ~(masked_output.space_mask == 2),
                    ~(masked_output.time_mask == 2),
                    ~(masked_output.static_mask == 2),
                    patch_size,
                )
                t_s_t = encoder.blocks[0].norm1(t_s_t)
                t_sp = encoder.blocks[0].norm1(t_sp)
                t_t = encoder.blocks[0].norm1(t_t)
                t_st = encoder.blocks[0].norm1(t_st)
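
            # Mask value 2 appears to mark tokens the decoder must reconstruct (see
            # test_split_x_y below): even those positions should come out of the
            # projection and the first block's norm without NaNs.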
            self.assertFalse(
                torch.isnan(
                    t_s_t[masked_output.space_time_mask[:, 0::patch_size, 0::patch_size] == 2]
                ).any()
            )
            self.assertFalse(
                torch.isnan(
                    t_sp[masked_output.space_mask[:, 0::patch_size, 0::patch_size] == 2]
                ).any()
            )
            self.assertFalse(torch.isnan(t_t[masked_output.time_mask == 2]).any())
            self.assertFalse(torch.isnan(t_st[masked_output.static_mask == 2]).any())
            self.assertTrue(
                list(encoder_output[0].shape)
                == [
                    1,
                    image_size // patch_size,
                    image_size // patch_size,
                    num_timesteps,
                    len(SPACE_TIME_BANDS_GROUPS_IDX),
                    embedding_size,
                ]
            )
            self.assertTrue(
                list(encoder_output[1].shape)
                == [
                    1,
                    image_size // patch_size,
                    image_size // patch_size,
                    len(SPACE_BAND_GROUPS_IDX),
                    embedding_size,
                ]
            )
            self.assertTrue(
                list(encoder_output[2].shape)
                == [
                    1,
                    num_timesteps,
                    len(TIME_BAND_GROUPS_IDX),
                    embedding_size,
                ]
            )
            self.assertTrue(
                list(encoder_output[3].shape)
                == [
                    1,
                    len(STATIC_BAND_GROUPS_IDX),
                    embedding_size,
                ]
            )
            self.assertFalse(
                torch.isnan(
                    encoder_output[0][
                        masked_output.space_time_mask[:, 0::patch_size, 0::patch_size] == 0
                    ]
                ).any()
            )
            self.assertFalse(
                torch.isnan(
                    encoder_output[1][
                        masked_output.space_mask[:, 0::patch_size, 0::patch_size] == 0
                    ]
                ).any()
            )
            self.assertFalse(torch.isnan(encoder_output[2][masked_output.time_mask == 0]).any())
            self.assertFalse(torch.isnan(encoder_output[3][masked_output.static_mask == 0]).any())

            self.assertTrue(
                list(output[0].shape)
                == [
                    1,
                    image_size // patch_size,
                    image_size // patch_size,
                    num_timesteps,
                    len(SPACE_TIME_BANDS_GROUPS_IDX),
                    embedding_size,
                ]
            )
            self.assertTrue(
                list(output[1].shape)
                == [
                    1,
                    image_size // patch_size,
                    image_size // patch_size,
                    len(SPACE_BAND_GROUPS_IDX),
                    embedding_size,
                ]
            )
            self.assertTrue(
                list(output[2].shape)
                == [1, num_timesteps, len(TIME_BAND_GROUPS_IDX), embedding_size]
            )
            self.assertTrue(
                list(output[3].shape) == [1, len(STATIC_BAND_GROUPS_IDX), embedding_size]
            )

            self.assertFalse(
                torch.isnan(
                    output[0][masked_output.space_time_mask[:, 0::patch_size, 0::patch_size] == 2]
                ).any()
            )
            self.assertFalse(
                torch.isnan(
                    output[1][masked_output.space_mask[:, 0::patch_size, 0::patch_size] == 2]
                ).any()
            )
            self.assertFalse(torch.isnan(output[2][masked_output.time_mask == 2]).any())
            self.assertFalse(torch.isnan(output[3][masked_output.static_mask == 2]).any())

            summed_output = sum([torch.sum(o) for o in output])
            summed_output.backward()
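
    # add_masks should swap in the decoder's mask embedding (presumably zero at
    # initialisation) wherever the mask is 2 and leave every other position
    # untouched, which is what the all-zero / all-one checks below verify.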
    def test_decoder_add_masks(self):
        embedding_size = 16
        decoder = Decoder(
            encoder_embedding_size=embedding_size,
            decoder_embedding_size=embedding_size,
            num_heads=1,
        )
        b, h, w, t = 5, 6, 7, 8
        s_t_x = torch.ones(b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX), embedding_size)
        s_t_m = torch.zeros(b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX))
        s_t_m[:, :, :, 0] = 2
        s_t_m[:, :, :, 1] = 1

        sp_x = torch.ones(b, h, w, len(SPACE_BAND_GROUPS_IDX), embedding_size)
        sp_m = torch.zeros(b, h, w, len(SPACE_BAND_GROUPS_IDX))
        sp_m[:, 0] = 2
        sp_m[:, 1] = 1

        t_x = torch.ones(b, t, len(TIME_BAND_GROUPS_IDX), embedding_size)
        t_m = torch.zeros(b, t, len(TIME_BAND_GROUPS_IDX))
        t_m[:, 0] = 2
        t_m[:, 1] = 1

        st_x = torch.ones(b, len(STATIC_BAND_GROUPS_IDX), embedding_size)
        st_m = torch.zeros(b, len(STATIC_BAND_GROUPS_IDX))
        st_m[:, 0] = 2
        st_m[:, 1] = 1

        with torch.no_grad():
            o = decoder.add_masks(s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m)

        self.assertTrue((o[0][:, :, :, 0] == 0).all())
        self.assertTrue((o[0][:, :, :, 1:] == 1).all())
        self.assertTrue((o[1][:, 0] == 0).all())
        self.assertTrue((o[1][:, 1:] == 1).all())
        self.assertTrue((o[2][:, 0] == 0).all())
        self.assertTrue((o[2][:, 1:] == 1).all())
        self.assertTrue((o[3][:, 0] == 0).all())
        self.assertTrue((o[3][:, 1:] == 1).all())
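
    # average_tokens should pool only the unmasked tokens: the masked positions are
    # zeroed out below, so including them would drag the mean under 1.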
    def test_mean_of_tokens(self):
        b, t, d, h, w, s_t_c_g, sp_c_g, t_c_g, st_c_g = 1, 2, 8, 3, 3, 5, 6, 2, 4
        s_t_x = torch.ones((b, h, w, t, s_t_c_g, d))
        sp_x = torch.ones((b, h, w, sp_c_g, d))
        t_x = torch.ones((b, t, t_c_g, d))
        st_x = torch.ones((b, st_c_g, d))

        s_t_m = torch.zeros((b, h, w, t, s_t_c_g))
        s_t_m[:, :, 0, :] = 1
        s_t_m[:, :, :, 0] = 1
        s_t_x[:, :, 0, :] = 0
        s_t_x[:, :, :, 0] = 0

        sp_m = torch.zeros((b, h, w, sp_c_g))
        sp_m[:, -1, :] = 1
        sp_x[:, -1, :] = 0

        t_m = torch.zeros((b, t, t_c_g))
        t_m[:, 0] = 1
        t_x[:, 0] = 0

        st_m = torch.zeros((b, st_c_g))
        st_m[:, -1] = 1
        st_x[:, -1] = 0

        mean = Encoder.average_tokens(s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m)
        self.assertEqual(mean.shape, (b, d))
        self.assertTrue((mean == 1).all())
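
    # remove_masked_tokens gathers the unmasked tokens (padding rows that keep fewer
    # of them), and add_removed_tokens should undo it exactly.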
    def test_mask_and_unmask_tokens(self):
        b, d = 2, 2
        x = torch.tensor([[0, 1, 0], [1, 0, 1]]).float()
        x = repeat(x, "b n -> b n d", d=d)
        mask = torch.tensor([[1, 0, 1], [0, 1, 0]]).float()

        out_x, indices, updated_mask = Encoder.remove_masked_tokens(x, mask)
        self.assertEqual(out_x.dtype, x.dtype)
        self.assertEqual(updated_mask.dtype, mask.dtype)
        self.assertEqual(out_x.shape, (b, 2, d))

        self.assertTrue(torch.equal(out_x[1], torch.ones_like(out_x[1])))

        self.assertEqual(indices[0, 0], 1)
        self.assertTrue(torch.equal(indices[1, :2], torch.tensor([0, 2])))
        self.assertEqual(updated_mask.shape, (b, 2))
        self.assertTrue(torch.equal(updated_mask, torch.Tensor([[0, 1], [0, 0]])))

        final_x, final_mask = Encoder.add_removed_tokens(out_x, indices, updated_mask)
        self.assertEqual(final_x.dtype, x.dtype)
        self.assertEqual(final_mask.dtype, mask.dtype)
        self.assertTrue(torch.equal(final_x, x))
        self.assertTrue(torch.equal(final_mask, mask))
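
    # combine_x_y appears to be the inverse of split_x_y (exercised below): it scatters
    # the decoder tokens (x, mask == 2) and the encoded tokens (y, mask == 0) back to
    # their original positions, zero-filling slots that were dropped (mask == 1).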
    def test_combine_x_y(self):
        x = torch.tensor([[14, 15, 16], [15, 16, 1]]).unsqueeze(-1)
        y = torch.tensor([[5, 6, 7, 8], [4, 5, 6, 7]]).unsqueeze(-1)
        x_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
        y_mask = torch.tensor([[1, 1, 1, 1], [0, 1, 1, 1]])
        indices = torch.tensor([[6, 7, 8, 4, 5, 0, 1, 2, 3], [7, 8, 3, 4, 5, 6, 0, 1, 2]])

        tokens = Decoder.combine_x_y(x, y, x_mask, y_mask, indices)
        self.assertTrue(
            torch.equal(
                tokens,
                torch.tensor(
                    [[5, 6, 7, 8, 0, 0, 14, 15, 16], [5, 6, 7, 0, 0, 0, 0, 15, 16]]
                ).unsqueeze(-1),
            )
        )

    def test_split_x_y(self):
        tokens = torch.tensor(
            [[5, 6, 7, 8, 2, 13, 14, 15, 16], [5, 6, 7, 1, 2, 3, 4, 15, 16]]
        ).unsqueeze(-1)
        mask = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2], [0, 0, 0, 1, 1, 1, 1, 2, 2]])

        x, y, x_mask, y_mask, _ = Decoder.split_x_y(tokens, mask)
        self.assertTrue(torch.equal(x, torch.tensor([[14, 15, 16], [15, 16, 1]]).unsqueeze(-1)))
        self.assertTrue(torch.equal(y, torch.tensor([[5, 6, 7, 8], [4, 5, 6, 7]]).unsqueeze(-1)))
        self.assertTrue(torch.equal(x_mask, torch.tensor([[1, 1, 1], [1, 1, 0]])))
        self.assertTrue(torch.equal(y_mask, torch.tensor([[1, 1, 1, 1], [0, 1, 1, 1]])))
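
    # Round trip: split_x_y followed by combine_x_y should reproduce the input tokens,
    # except that positions masked with 1 (dropped on both paths) come back as zeros.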
    def test_x_y_there_and_back_again(self):
        tokens = torch.tensor(
            [[5, 6, 7, 8, 2, 13, 14, 15, 16], [5, 6, 7, 1, 2, 3, 4, 15, 16]]
        ).unsqueeze(-1)
        mask = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2], [0, 0, 0, 1, 1, 1, 1, 2, 2]])
        x, y, x_mask, y_mask, indices = Decoder.split_x_y(tokens, mask)
        new_tokens = Decoder.combine_x_y(x, y, x_mask, y_mask, indices)
        tokens[mask == 1] = 0
        self.assertTrue(torch.equal(tokens, new_tokens))

    def test_load_from_device(self):
        config = load_check_config("nano.json")
        original_encoder = Encoder(**config["model"]["encoder"])

        with tempfile.TemporaryDirectory() as tempdir:
            torch.save(original_encoder.state_dict(), Path(tempdir) / ENCODER_FILENAME)
            with (Path(tempdir) / CONFIG_FILENAME).open("w") as f:
                json.dump(config, f)

            new_encoder = Encoder.load_from_folder(Path(tempdir))

            for key, val in new_encoder.state_dict().items():
                self.assertTrue(torch.equal(val, original_encoder.state_dict()[key]))
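
    # With decoder_mode=[("static", "LS")], only a single static token should survive
    # on the decoder side of split_x_y, whichever masking function built the batch.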
    def test_decoder_and_mask_static(self):
        patch_size = 4
        ratio = 0.25

        ds = Dataset(TIFS_FOLDER, False)
        tensor_batch = self.to_tensor_with_batch_d(ds[0])
        self.assertTrue(tensor_batch[0].shape[1] == tensor_batch[0].shape[2])
        for f in [batch_mask_time, batch_mask_space]:
            masked_output = f(
                *tensor_batch,
                encode_ratio=ratio,
                decode_ratio=ratio,
                mode=[("space", "DW")],
                decoder_mode=[("static", "LS")],
                patch_size=patch_size,
            )

            encoder = Encoder(embedding_size=32, num_heads=1)
            decoder = Decoder(
                encoder_embedding_size=32,
                decoder_embedding_size=32,
                num_heads=1,
            )
            encoder_output = encoder(
                masked_output.space_time_x,
                masked_output.space_x,
                masked_output.time_x,
                masked_output.static_x,
                masked_output.space_time_mask,
                masked_output.space_mask,
                masked_output.time_mask,
                masked_output.static_mask,
                masked_output.months.long(),
                patch_size=patch_size,
            )
            s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, _ = encoder_output
            x, m = decoder.collapse_and_combine_hwtc(
                s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
            )
            x, _, _, _, _ = decoder.split_x_y(x, m)
            self.assertTrue(x.shape[1] == 1, x.shape)

    def test_token_exit_cfgs_single_exit_equivalency(self):
        self._token_exit_cfgs_single_exit_equivalency(0)
        self._token_exit_cfgs_single_exit_equivalency(6)
        self._token_exit_cfgs_single_exit_equivalency(12)
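
    # When every modality in token_exit_cfg exits at the same depth, the encoder output
    # should be identical to a plain exit_after=depth call with no per-token config.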
    @torch.no_grad()
    def _token_exit_cfgs_single_exit_equivalency(self, depth):
        embedding_size, patch_size = 16, 1
        image_size = patch_size * 4
        num_timesteps = 3
        encoder = Encoder(embedding_size=embedding_size, num_heads=1, depth=12)
        encoder.eval()
        ds = Dataset(TIFS_FOLDER, False)
        for i in range(len(ds)):
            s_t_x, sp_x, t_x, st_x, months = self.to_tensor_with_batch_d(ds[i])
            masked_output = batch_subset_mask_galileo(
                s_t_x,
                sp_x,
                t_x,
                st_x,
                months,
                encode_ratio=0.25,
                decode_ratio=0.25,
                patch_size=patch_size,
                image_size=image_size,
                num_timesteps=num_timesteps,
                augmentation_strategies=None,
                masking_probabilities=[1] * len(MASKING_MODES),
                masking_function=MaskingFunctions.SPACE,
                max_unmasking_channels=4,
            )

            token_exit_cfgs = {
                "S1": depth,
                "S2_RGB": depth,
                "S2_Red_Edge": depth,
                "S2_NIR_10m": depth,
                "S2_NIR_20m": depth,
                "S2_SWIR": depth,
                "NDVI": depth,
                "ERA5": depth,
                "TC": depth,
                "VIIRS": depth,
                "SRTM": depth,
                "DW": depth,
                "WC": depth,
                "LS": depth,
                "location": depth,
                "DW_static": depth,
                "WC_static": depth,
            }

            encoder_output_depth = encoder(
                masked_output.space_time_x,
                masked_output.space_x,
                masked_output.time_x,
                masked_output.static_x,
                torch.zeros_like(masked_output.space_time_mask),
                torch.zeros_like(masked_output.space_mask),
                torch.zeros_like(masked_output.time_mask),
                torch.zeros_like(masked_output.static_mask),
                masked_output.months.long(),
                patch_size=patch_size,
                exit_after=depth,
            )

            encoder_output_depth_varied = encoder(
                masked_output.space_time_x,
                masked_output.space_x,
                masked_output.time_x,
                masked_output.static_x,
                torch.zeros_like(masked_output.space_time_mask),
                torch.zeros_like(masked_output.space_mask),
                torch.zeros_like(masked_output.time_mask),
                torch.zeros_like(masked_output.static_mask),
                masked_output.months.long(),
                patch_size=patch_size,
                token_exit_cfg=token_exit_cfgs,
                exit_after=None,
            )

            self.assertTrue(torch.equal(encoder_output_depth_varied[0], encoder_output_depth[0]))
            self.assertTrue(torch.equal(encoder_output_depth_varied[1], encoder_output_depth[1]))
            self.assertTrue(torch.equal(encoder_output_depth_varied[2], encoder_output_depth[2]))
            self.assertTrue(torch.equal(encoder_output_depth_varied[3], encoder_output_depth[3]))
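
    # The standalone single_file_galileo copy of the encoder should load exactly the
    # same weights as the packaged Encoder.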
    def test_single_file_galileo_matches_galileo(self):
        org_model = Encoder.load_from_folder(DATA_FOLDER / "models/nano")
        sf_model = SingleFileEncoder.load_from_folder(
            DATA_FOLDER / "models/nano", device=torch.device("cpu")
        )

        for model_p, sf_model_p in zip(org_model.parameters(), sf_model.parameters()):
            self.assertTrue(torch.equal(model_p, sf_model_p))