|
from typing import cast |
|
|
|
import torch |
|
|
|
from src.data import ( |
|
SPACE_BAND_GROUPS_IDX, |
|
SPACE_TIME_BANDS_GROUPS_IDX, |
|
STATIC_BAND_GROUPS_IDX, |
|
TIME_BAND_GROUPS_IDX, |
|
) |
|
from src.data.dataset import ( |
|
SPACE_BANDS, |
|
SPACE_TIME_BANDS, |
|
STATIC_BANDS, |
|
TIME_BANDS, |
|
Normalizer, |
|
to_cartesian, |
|
) |
|
from src.data.earthengine.eo import ( |
|
DW_BANDS, |
|
ERA5_BANDS, |
|
LANDSCAN_BANDS, |
|
LOCATION_BANDS, |
|
S1_BANDS, |
|
S2_BANDS, |
|
SRTM_BANDS, |
|
TC_BANDS, |
|
VIIRS_BANDS, |
|
WC_BANDS, |
|
) |
|
from src.masking import MaskedOutput |
|
|
|
# Month value assigned to every timestep when the caller does not pass `months`.
# NOTE(review): whether months are 0- or 1-indexed is decided by whatever
# consumes `MaskedOutput.months`, not by this module - confirm before changing.
DEFAULT_MONTH = 5


def _consistent_size(sizes: list[int], error_message: str) -> int:
    """Return the single size shared by all entries of `sizes`.

    An empty list means no input constrains this dimension, in which case the
    dimension collapses to 1. Raises ValueError(error_message) if the entries
    disagree.
    """
    if not sizes:
        return 1
    if any(size != sizes[0] for size in sizes):
        raise ValueError(error_message)
    return sizes[0]


def _write_band_group(
    values: torch.Tensor,
    mask: torch.Tensor,
    x: torch.Tensor,
    all_bands,
    group_names,
    bands_list,
    group_key: str,
) -> None:
    """Copy `x` into its band slots of `values` and unmask its groups, in place.

    The band axis and the group axis are always the last axis of `values` and
    `mask` respectively, so `...` covers however many leading (H, W, T) axes
    the modality has.
    """
    indices = [idx for idx, band in enumerate(all_bands) if band in bands_list]
    # Group names are matched by substring, e.g. "S2" selects every S2_* group.
    groups_idx = [idx for idx, key in enumerate(group_names) if group_key in key]
    values[..., indices] = x
    mask[..., groups_idx] = 0


def construct_galileo_input(
    s1: torch.Tensor | None = None,
    s2: torch.Tensor | None = None,
    era5: torch.Tensor | None = None,
    tc: torch.Tensor | None = None,
    viirs: torch.Tensor | None = None,
    srtm: torch.Tensor | None = None,
    dw: torch.Tensor | None = None,
    wc: torch.Tensor | None = None,
    landscan: torch.Tensor | None = None,
    latlon: torch.Tensor | None = None,
    months: torch.Tensor | None = None,
    normalize: bool = False,
):
    """Assemble per-source tensors into a single `MaskedOutput`.

    Every modality buffer is allocated zero-filled with its mask set to 1.
    For each source that is supplied, its bands are copied into the matching
    slots and the channel groups belonging to that source are set to 0 in the
    mask. Sources that are not supplied therefore stay zero-valued with a
    mask of 1.

    Args:
        s1, s2: space-time inputs, shaped [H, W, T, D].
        era5, tc, viirs: time-only inputs, shaped [T, D].
        srtm, dw, wc: space-only inputs, shaped [H, W, D].
        landscan: static input, shaped [D].
        latlon: static input with at least two elements; elements 0 and 1 are
            passed to `to_cartesian` and the result is stored in the location
            bands (presumably latitude then longitude - confirm against
            `to_cartesian`).
        months: one month per timestep, shaped [T]. Defaults to
            `DEFAULT_MONTH` repeated T times.
        normalize: if True, pass each value buffer through
            `Normalizer(std=False)` before returning.

    H, W and T are inferred from the supplied inputs; any dimension that no
    supplied input carries defaults to 1.

    Returns:
        A `MaskedOutput` holding the four value buffers, their four masks and
        `months`.

    Raises:
        ValueError: if no input is given, if the inputs live on different
            devices, if they disagree on T, H or W, or if `months` does not
            have T entries.
    """
    space_time_inputs = [s1, s2]
    time_inputs = [era5, tc, viirs]
    space_inputs = [srtm, dw, wc]
    static_inputs = [landscan, latlon]

    devices = [
        x.device
        for x in space_time_inputs + time_inputs + space_inputs + static_inputs
        if x is not None
    ]
    if not devices:
        raise ValueError("At least one input must be not None")
    device = devices[0]
    if any(other != device for other in devices):
        raise ValueError("Received tensors on multiple devices")

    # First, check that all the input shapes are consistent.
    given_space_time = [x for x in space_time_inputs if x is not None]
    given_time = [x for x in time_inputs if x is not None]
    given_space = [x for x in space_inputs if x is not None]

    # Time-only inputs are [T, D] (they are written into t_x[:, indices]
    # below), so their timestep count is axis 0; axis 1 is the band count.
    t = _consistent_size(
        [x.shape[2] for x in given_space_time] + [x.shape[0] for x in given_time],
        "Inconsistent number of timesteps per input",
    )
    spatial = given_space_time + given_space
    h = _consistent_size([x.shape[0] for x in spatial], "Inconsistent heights per input")
    w = _consistent_size([x.shape[1] for x in spatial], "Inconsistent widths per input")

    # Value buffers start at zero and masks at one; supplied sources
    # overwrite their slots and zero their mask groups below.
    s_t_x = torch.zeros((h, w, t, len(SPACE_TIME_BANDS)), dtype=torch.float, device=device)
    s_t_m = torch.ones(
        (h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)), dtype=torch.float, device=device
    )
    sp_x = torch.zeros((h, w, len(SPACE_BANDS)), dtype=torch.float, device=device)
    sp_m = torch.ones((h, w, len(SPACE_BAND_GROUPS_IDX)), dtype=torch.float, device=device)
    t_x = torch.zeros((t, len(TIME_BANDS)), dtype=torch.float, device=device)
    t_m = torch.ones((t, len(TIME_BAND_GROUPS_IDX)), dtype=torch.float, device=device)
    st_x = torch.zeros(len(STATIC_BANDS), dtype=torch.float, device=device)
    st_m = torch.ones(len(STATIC_BAND_GROUPS_IDX), dtype=torch.float, device=device)

    for x, bands_list, group_key in zip([s1, s2], [S1_BANDS, S2_BANDS], ["S1", "S2"]):
        if x is not None:
            _write_band_group(
                s_t_x, s_t_m, x, SPACE_TIME_BANDS, SPACE_TIME_BANDS_GROUPS_IDX,
                bands_list, group_key,
            )

    for x, bands_list, group_key in zip(
        [srtm, dw, wc], [SRTM_BANDS, DW_BANDS, WC_BANDS], ["SRTM", "DW", "WC"]
    ):
        if x is not None:
            _write_band_group(
                sp_x, sp_m, x, SPACE_BANDS, SPACE_BAND_GROUPS_IDX, bands_list, group_key
            )

    for x, bands_list, group_key in zip(
        [era5, tc, viirs], [ERA5_BANDS, TC_BANDS, VIIRS_BANDS], ["ERA5", "TC", "VIIRS"]
    ):
        if x is not None:
            _write_band_group(
                t_x, t_m, x, TIME_BANDS, TIME_BAND_GROUPS_IDX, bands_list, group_key
            )

    for x, bands_list, group_key in zip(
        [landscan, latlon], [LANDSCAN_BANDS, LOCATION_BANDS], ["LS", "location"]
    ):
        if x is not None:
            if group_key == "location":
                # The location bands hold the output of to_cartesian, not the
                # raw two-element input.
                x = cast(torch.Tensor, to_cartesian(x[0], x[1]))
            _write_band_group(
                st_x, st_m, x, STATIC_BANDS, STATIC_BAND_GROUPS_IDX, bands_list, group_key
            )

    if months is None:
        months = torch.full((t,), DEFAULT_MONTH, dtype=torch.long, device=device)
    elif months.shape[0] != t:
        raise ValueError("Incorrect number of input months")

    if normalize:
        normalizer = Normalizer(std=False)
        # The normalizer is fed NumPy arrays, so each buffer makes a round
        # trip through the CPU before being moved back to the input device.
        s_t_x, sp_x, t_x, st_x = (
            torch.from_numpy(normalizer(buffer.cpu().numpy())).to(device)
            for buffer in (s_t_x, sp_x, t_x, st_x)
        )

    return MaskedOutput(
        space_time_x=s_t_x,
        space_time_mask=s_t_m,
        space_x=sp_x,
        space_mask=sp_m,
        time_x=t_x,
        time_mask=t_m,
        static_x=st_x,
        static_mask=st_m,
        months=months,
    )
|
|