from functools import lru_cache
from typing import List, Literal

import numpy as np
import torch
from timm.models.vision_transformer import Block, PatchEmbed, trunc_normal_
from torch import nn

from utils.feature_extractor import IterativeUpsampleStepParams
from utils.head import Head
from utils.par_embed import ParallelVarPatchEmbed
from utils.pos_embed import get_1d_sincos_pos_embed_from_grid, get_2d_sincos_pos_embed

class ClimaXLegacy(nn.Module):
    """Implements the ClimaX model as described in the paper: https://arxiv.org/abs/2301.10343

    Args:
        default_vars (list): list of default variables to be used for training
        img_size (list): image size of the input data
        patch_size (int): patch size of the input data
        embed_dim (int): embedding dimension
        depth (int): number of transformer layers
        decoder_depth (int): number of decoder layers
        num_heads (int): number of attention heads
        mlp_ratio (float): ratio of mlp hidden dimension to embedding dimension
        drop_path (float): stochastic depth rate
        drop_rate (float): dropout rate
        out_dim (int): number of output channels per pixel
        parallel_patch_embed (bool): whether to use parallel patch embedding
        pretrained (str): path to a checkpoint to warm-start from; False (default) trains from scratch
        upsampling_steps (List[IterativeUpsampleStepParams]): each dict represents one upsampling step.
            step_scale_factor determines by what scale the 2D dimensions grow,
            new_channel_dim determines what the channel dimension should be at the end of that step,
            feature_dim determines the channel depth of the feature set that is stacked on at the start of the step.
        feature_extractor_type (str): feature extractor used by the head, either "simple" or "res-net"
        double (bool): whether the input stacks two 3-channel images along the channel axis,
            which are encoded as a doubled batch and re-joined along the feature dimension
    """

    def __init__(
        self,
        default_vars,
        img_size=[32, 64],
        patch_size=2,
        embed_dim=1024,
        depth=8,
        decoder_depth=2,
        num_heads=16,
        mlp_ratio=4.0,
        drop_path=0.1,
        drop_rate=0.1,
        out_dim=1,
        parallel_patch_embed=False,
        pretrained=False,
        upsampling_steps: List[IterativeUpsampleStepParams] = [{"step_scale_factor": 2, "new_channel_dim": 1, "feature_dim": 1, "block_count": 1}],
        feature_extractor_type: Literal["simple", "res-net"] = "res-net",
        double=False,
        **kwargs
    ):
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.default_vars = default_vars
        self.parallel_patch_embed = parallel_patch_embed

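        # Variable tokenization: each input variable is patchified into its own token
        # sequence, either by a fused ParallelVarPatchEmbed or by one PatchEmbed per variable.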
        if self.parallel_patch_embed:
            self.token_embeds = ParallelVarPatchEmbed(len(default_vars), img_size, patch_size, embed_dim)
            self.num_patches = self.token_embeds.num_patches
        else:
            self.token_embeds = nn.ModuleList(
                [PatchEmbed(img_size, patch_size, 1, embed_dim) for _ in range(len(default_vars))]
            )
            self.num_patches = self.token_embeds[0].num_patches

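        # Learnable per-variable embeddings (filled with a 1D sin-cos pattern in
        # initialize_weights) let the model distinguish tokens of different variables.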
        self.var_embed, self.var_map = self.create_var_embedding(embed_dim)

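        # Variable aggregation: a single learnable query cross-attends over the
        # per-variable tokens at each spatial location, collapsing V sequences into one.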
        self.var_query = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)
        self.var_agg = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

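        # Position embedding (filled with a fixed 2D sin-cos pattern in initialize_weights)
        # and a linear embedding of the scalar lead time.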
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim), requires_grad=True)
        self.lead_time_embed = nn.Linear(1, embed_dim)

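        # ViT backbone with a linearly increasing stochastic-depth (drop path) schedule.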
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    drop_path=dpr[i],
                    norm_layer=nn.LayerNorm,
                    proj_drop=drop_rate,
                )
                for i in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim)

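        # Prediction head: maps encoder tokens back to out_dim output channels using
        # decoder_depth decoder layers and the configured iterative upsampling steps.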
        self.out_dim = out_dim
        self.head = Head(
            embed_dim, decoder_depth, img_size, patch_size, upsampling_steps, out_dim, feature_extractor_type, double=double
        )

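        # Optionally load weights from a checkpoint: the 'net.' prefix is stripped from keys,
        # and tensors whose shapes do not match the current model are kept from the model instead.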
        self.initialize_weights()
        if pretrained:
            ckpt = torch.load(pretrained, map_location='cpu')
            state_dict = {k.replace('net.', ''): v for k, v in ckpt['state_dict'].items()}
            current_dict = self.state_dict()
            for k in state_dict:
                if k in current_dict and current_dict[k].shape != state_dict[k].shape:
                    print("do not load, shape mismatch:", k)
                    state_dict[k] = current_dict[k]

            self.load_state_dict(state_dict, strict=False)

        self.is_double = double

    def initialize_weights(self):
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            self.img_size[0] // self.patch_size,
            self.img_size[1] // self.patch_size,
            cls_token=False,
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        var_embed = get_1d_sincos_pos_embed_from_grid(self.var_embed.shape[-1], np.arange(len(self.default_vars)))
        self.var_embed.data.copy_(torch.from_numpy(var_embed).float().unsqueeze(0))

        if self.parallel_patch_embed:
            for i in range(len(self.token_embeds.proj_weights)):
                w = self.token_embeds.proj_weights[i].data
                trunc_normal_(w.view([w.shape[0], -1]), std=0.02)
        else:
            for i in range(len(self.token_embeds)):
                w = self.token_embeds[i].proj.weight.data
                trunc_normal_(w.view([w.shape[0], -1]), std=0.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def create_var_embedding(self, dim):
        var_embed = nn.Parameter(torch.zeros(1, len(self.default_vars), dim), requires_grad=True)

        var_map = {var: idx for idx, var in enumerate(self.default_vars)}
        return var_embed, var_map

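    # Cached per (variables, device) pair; `vars` must therefore be hashable, which is why
    # forward_encoder converts variable lists to tuples before calling this.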
    @lru_cache(maxsize=None)
    def get_var_ids(self, vars, device):
        ids = np.array([self.var_map[var] for var in vars])
        return torch.from_numpy(ids).to(device)

    def get_var_emb(self, var_emb, vars):
        ids = self.get_var_ids(vars, var_emb.device)
        return var_emb[:, ids, :]

    def unpatchify(self, x: torch.Tensor, h=None, w=None):
        """
        x: (B, L, out_dim * patch_size**2)
        return imgs: (B, out_dim, H, W)
        """
        p = self.patch_size
        c = self.out_dim
        h = self.img_size[0] // p if h is None else h // p
        w = self.img_size[1] // p if w is None else w // p
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def aggregate_variables(self, x: torch.Tensor):
        """
        x: (B, V, L, D) -> (B, L, D), pooling the variable dimension with cross-attention.
        """
        b, _, l, _ = x.shape
        x = torch.einsum("bvld->blvd", x)
        x = x.flatten(0, 1)

        var_query = self.var_query.repeat_interleave(x.shape[0], dim=0)
        x, _ = self.var_agg(var_query, x, x)
        x = x.squeeze(dim=1)

        x = x.unflatten(dim=0, sizes=(b, l))
        return x

    def forward_head(self, x: torch.Tensor, original_image: torch.Tensor):
        return self.head(x, original_image)

    def forward_encoder(self, x: torch.Tensor, lead_times: torch.Tensor, variables):
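        """Tokenize per-variable inputs, aggregate across variables, add position and
        lead-time embeddings, and run the transformer blocks.

        x: (B, V, H, W), lead_times: (B,), variables: list/tuple of V variable names.
        Returns tokens of shape (B, L, D) with L = num_patches.
        """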
        if isinstance(variables, list):
            variables = tuple(variables)

        embeds = []
        var_ids = self.get_var_ids(variables, x.device)

        if self.parallel_patch_embed:
            x = self.token_embeds(x, var_ids)
        else:
            for i in range(len(var_ids)):
                var_id = var_ids[i]
                embeds.append(self.token_embeds[var_id](x[:, i : i + 1]))
            x = torch.stack(embeds, dim=1)

        var_embed = self.get_var_emb(self.var_embed, variables)
        x = x + var_embed.unsqueeze(2)

        x = self.aggregate_variables(x)

        x = x + self.pos_embed

        lead_time_emb = self.lead_time_embed(lead_times.unsqueeze(-1))
        lead_time_emb = lead_time_emb.unsqueeze(1)
        x = x + lead_time_emb

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x

    def forward(self, x, lead_times=None, variables=None):
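        """Run the model at native resolution, or tile larger inputs.

        In training mode, or when the spatial size equals img_size, the input goes through
        local_forward directly. Otherwise, inputs whose height and width are integer multiples
        of img_size are split into non-overlapping img_size tiles, each tile is predicted
        independently, and the predictions are stitched back together.
        """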
        x = x.to(memory_format=torch.channels_last)
        if self.training or (x.shape[2] == self.img_size[0] and x.shape[3] == self.img_size[1]):
            return self.local_forward(x, lead_times, variables)
        else:
            assert x.shape[2] % self.img_size[0] == 0 and x.shape[3] % self.img_size[1] == 0
            nh = x.shape[2] // self.img_size[0]
            nw = x.shape[3] // self.img_size[1]
            output = torch.zeros(x.shape[0], self.out_dim, x.shape[2], x.shape[3]).to(x)
            for i in range(nh):
                for j in range(nw):
                    local_y = self.local_forward(
                        x[..., i * self.img_size[0]:(i + 1) * self.img_size[0], j * self.img_size[1]:(j + 1) * self.img_size[1]],
                        lead_times,
                        variables,
                    )
                    output[..., i * self.img_size[0]:(i + 1) * self.img_size[0], j * self.img_size[1]:(j + 1) * self.img_size[1]] = local_y
            return output

    def local_forward(self, x, lead_times=None, variables=None):
        """Forward pass through the model at native resolution.

        Args:
            x: `[B, Vi, H, W]` shape. Input weather/climate variables
            lead_times: `[B]` shape. Forecasting lead times of each element of the batch.
            variables: names of the input variables; defaults to self.default_vars.

        Returns:
            preds (torch.Tensor): `[B, Vo, H, W]` shape. Predicted weather/climate variables.
        """
        x_raw = x
        if variables is None:
            variables = self.default_vars
        if len(x.shape) == 5:
            assert x.shape[1] == 1
            x = x.squeeze(1)
        n = x.shape[0]
        if lead_times is None:
            lead_times = torch.ones(n).to(x)
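        # "Double" mode: the input stacks two 3-channel images along the channel axis; they are
        # encoded as a doubled batch and their token features are re-joined along the last dim.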
        if self.is_double:
            x0, x1 = x.chunk(2, dim=1)
            assert x0.shape[1] == 3
            x = torch.cat([x0, x1])
            lead_times = torch.cat([lead_times, lead_times])
            out_transformers = self.forward_encoder(x, lead_times, variables)
            x0, x1 = out_transformers.chunk(2, dim=0)
            out_transformers = torch.cat([x0, x1], dim=-1)
        else:
            out_transformers = self.forward_encoder(x, lead_times, variables)
        preds = self.forward_head(x=out_transformers, original_image=x_raw)

        preds = self.unpatchify(preds)
        return preds

    def evaluate(self, x, y, lead_times, variables, out_variables, transform, metrics, lat, clim, log_postfix):
        preds = self.forward(x, lead_times, variables)
        return [m(preds, y, transform, out_variables, lat, clim, log_postfix) for m in metrics]

if __name__ == '__main__':
    variables = ['R', 'G', 'B']
    z = ClimaXLegacy(img_size=[512, 512], patch_size=16, default_vars=variables)
    z.eval()
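    # An input larger than img_size exercises the tiled inference path in forward().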
    x = torch.rand(3, 3, 1024, 1024)
    with torch.no_grad():
        print(z(x).shape)