# galileo/src/data/utils.py
# Uploaded by openfree; deployed from GitHub repository (commit 3dcb328, verified).
from typing import cast
import torch
from src.data import (
SPACE_BAND_GROUPS_IDX,
SPACE_TIME_BANDS_GROUPS_IDX,
STATIC_BAND_GROUPS_IDX,
TIME_BAND_GROUPS_IDX,
)
from src.data.dataset import (
SPACE_BANDS,
SPACE_TIME_BANDS,
STATIC_BANDS,
TIME_BANDS,
Normalizer,
to_cartesian,
)
from src.data.earthengine.eo import (
DW_BANDS,
ERA5_BANDS,
LANDSCAN_BANDS,
LOCATION_BANDS,
S1_BANDS,
S2_BANDS,
SRTM_BANDS,
TC_BANDS,
VIIRS_BANDS,
WC_BANDS,
)
from src.masking import MaskedOutput
# Month value assigned to every timestep when construct_galileo_input is called
# without `months`. NOTE(review): unclear from here whether months are 0- or
# 1-indexed (i.e. whether 5 means June or May) — confirm against the model's
# month encoding.
DEFAULT_MONTH = 5
def construct_galileo_input(
s1: torch.Tensor | None = None, # [H, W, T, D]
s2: torch.Tensor | None = None, # [H, W, T, D]
era5: torch.Tensor | None = None, # [T, D]
tc: torch.Tensor | None = None, # [T, D]
viirs: torch.Tensor | None = None, # [T, D]
srtm: torch.Tensor | None = None, # [H, W, D]
dw: torch.Tensor | None = None, # [H, W, D]
wc: torch.Tensor | None = None, # [H, W, D]
landscan: torch.Tensor | None = None, # [D]
latlon: torch.Tensor | None = None, # [D]
months: torch.Tensor | None = None, # [T]
normalize: bool = False,
):
space_time_inputs = [s1, s2]
time_inputs = [era5, tc, viirs]
space_inputs = [srtm, dw, wc]
static_inputs = [landscan, latlon]
devices = [
x.device
for x in space_time_inputs + time_inputs + space_inputs + static_inputs
if x is not None
]
if len(devices) == 0:
raise ValueError("At least one input must be not None")
if not all(devices[0] == device for device in devices):
raise ValueError("Received tensors on multiple devices")
device = devices[0]
# first, check all the input shapes are consistent
timesteps_list = [x.shape[2] for x in space_time_inputs if x is not None] + [
x.shape[1] for x in time_inputs if x is not None
]
height_list = [x.shape[0] for x in space_time_inputs if x is not None] + [
x.shape[0] for x in space_inputs if x is not None
]
width_list = [x.shape[1] for x in space_time_inputs if x is not None] + [
x.shape[1] for x in space_inputs if x is not None
]
if len(timesteps_list) > 0:
if not all(timesteps_list[0] == timestep for timestep in timesteps_list):
raise ValueError("Inconsistent number of timesteps per input")
t = timesteps_list[0]
else:
t = 1
if len(height_list) > 0:
if not all(height_list[0] == height for height in height_list):
raise ValueError("Inconsistent heights per input")
if not all(width_list[0] == width for width in width_list):
raise ValueError("Inconsistent widths per input")
h = height_list[0]
w = width_list[0]
else:
h, w = 1, 1
# now, we can construct our empty input tensors. By default, everything is masked
s_t_x = torch.zeros((h, w, t, len(SPACE_TIME_BANDS)), dtype=torch.float, device=device)
s_t_m = torch.ones(
(h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)), dtype=torch.float, device=device
)
sp_x = torch.zeros((h, w, len(SPACE_BANDS)), dtype=torch.float, device=device)
sp_m = torch.ones((h, w, len(SPACE_BAND_GROUPS_IDX)), dtype=torch.float, device=device)
t_x = torch.zeros((t, len(TIME_BANDS)), dtype=torch.float, device=device)
t_m = torch.ones((t, len(TIME_BAND_GROUPS_IDX)), dtype=torch.float, device=device)
st_x = torch.zeros((len(STATIC_BANDS)), dtype=torch.float, device=device)
st_m = torch.ones((len(STATIC_BAND_GROUPS_IDX)), dtype=torch.float, device=device)
for x, bands_list, group_key in zip([s1, s2], [S1_BANDS, S2_BANDS], ["S1", "S2"]):
if x is not None:
indices = [idx for idx, val in enumerate(SPACE_TIME_BANDS) if val in bands_list]
groups_idx = [
idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if group_key in key
]
s_t_x[:, :, :, indices] = x
s_t_m[:, :, :, groups_idx] = 0
for x, bands_list, group_key in zip(
[srtm, dw, wc], [SRTM_BANDS, DW_BANDS, WC_BANDS], ["SRTM", "DW", "WC"]
):
if x is not None:
indices = [idx for idx, val in enumerate(SPACE_BANDS) if val in bands_list]
groups_idx = [idx for idx, key in enumerate(SPACE_BAND_GROUPS_IDX) if group_key in key]
sp_x[:, :, indices] = x
sp_m[:, :, groups_idx] = 0
for x, bands_list, group_key in zip(
[era5, tc, viirs], [ERA5_BANDS, TC_BANDS, VIIRS_BANDS], ["ERA5", "TC", "VIIRS"]
):
if x is not None:
indices = [idx for idx, val in enumerate(TIME_BANDS) if val in bands_list]
groups_idx = [idx for idx, key in enumerate(TIME_BAND_GROUPS_IDX) if group_key in key]
t_x[:, indices] = x
t_m[:, groups_idx] = 0
for x, bands_list, group_key in zip(
[landscan, latlon], [LANDSCAN_BANDS, LOCATION_BANDS], ["LS", "location"]
):
if x is not None:
if group_key == "location":
# transform latlon to cartesian
x = cast(torch.Tensor, to_cartesian(x[0], x[1]))
indices = [idx for idx, val in enumerate(STATIC_BANDS) if val in bands_list]
groups_idx = [
idx for idx, key in enumerate(STATIC_BAND_GROUPS_IDX) if group_key in key
]
st_x[indices] = x
st_m[groups_idx] = 0
if months is None:
months = torch.ones((t), dtype=torch.long, device=device) * DEFAULT_MONTH
else:
if months.shape[0] != t:
raise ValueError("Incorrect number of input months")
if normalize:
normalizer = Normalizer(std=False)
s_t_x = torch.from_numpy(normalizer(s_t_x.cpu().numpy())).to(device)
sp_x = torch.from_numpy(normalizer(sp_x.cpu().numpy())).to(device)
t_x = torch.from_numpy(normalizer(t_x.cpu().numpy())).to(device)
st_x = torch.from_numpy(normalizer(st_x.cpu().numpy())).to(device)
return MaskedOutput(
space_time_x=s_t_x,
space_time_mask=s_t_m,
space_x=sp_x,
space_mask=sp_m,
time_x=t_x,
time_mask=t_m,
static_x=st_x,
static_mask=st_m,
months=months,
)