# climax-xview / climax.py
from torch import nn
from utils.par_embed import ParallelVarPatchEmbed
from utils.pos_embed import get_1d_sincos_pos_embed_from_grid,get_2d_sincos_pos_embed
from utils.feature_extractor import IterativeUpsampleStepParams
from utils.head import Head
import torch
from functools import lru_cache
from timm.models.vision_transformer import Block, PatchEmbed, trunc_normal_
import numpy as np
from typing import List, Literal
class ClimaXLegacy(nn.Module):
"""Implements the ClimaX model as described in the paper,
https://arxiv.org/abs/2301.10343
Args:
default_vars (list): list of default variables to be used for training
img_size (list): image size of the input data
patch_size (int): patch size of the input data
embed_dim (int): embedding dimension
depth (int): number of transformer layers
decoder_depth (int): number of decoder layers
num_heads (int): number of attention heads
mlp_ratio (float): ratio of mlp hidden dimension to embedding dimension
drop_path (float): stochastic depth rate
drop_rate (float): dropout rate
parallel_patch_embed (bool): whether to use parallel patch embedding
        upsampling_steps (List[IterativeUpsampleStepParams]): each dict describes one step of upsampling:
            step_scale_factor determines by what scale the 2D dimensions grow,
            new_channel_dim determines what the channel dimension should be at the end of that step,
            feature_dim determines the channel depth of the feature set that is stacked on at the start of the step,
            block_count determines how many blocks that step uses.
        out_dim (int): number of output channels per pixel
        feature_extractor_type (str): feature extractor used by the head, either "simple" or "res-net"
        pretrained (str or False): path to a checkpoint to load weights from, or False to skip loading
        double (bool): split the input channels into two halves, encode both with the shared backbone,
            and concatenate the resulting token features before the head
    """
def __init__(
self,
default_vars,
img_size=[32, 64],
patch_size=2,
embed_dim=1024,
depth=8,
decoder_depth=2,
num_heads=16,
mlp_ratio=4.0,
drop_path=0.1,
drop_rate=0.1,
out_dim=1,
parallel_patch_embed=False,
pretrained=False,
upsampling_steps: List[IterativeUpsampleStepParams] = [{"step_scale_factor": 2, "new_channel_dim" : 1, "feature_dim": 1, "block_count": 1}],
feature_extractor_type: Literal["simple", "res-net"] = "res-net",
double=False,
**kwargs
):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.default_vars = default_vars
self.parallel_patch_embed = parallel_patch_embed
# variable tokenization: separate embedding layer for each input variable
if self.parallel_patch_embed:
self.token_embeds = ParallelVarPatchEmbed(len(default_vars), img_size, patch_size, embed_dim)
self.num_patches = self.token_embeds.num_patches
else:
self.token_embeds = nn.ModuleList(
[PatchEmbed(img_size, patch_size, 1, embed_dim) for i in range(len(default_vars))]
)
self.num_patches = self.token_embeds[0].num_patches
# variable embedding to denote which variable each token belongs to
# helps in aggregating variables
self.var_embed, self.var_map = self.create_var_embedding(embed_dim)
# variable aggregation: a learnable query and a single-layer cross attention
self.var_query = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)
self.var_agg = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
# positional embedding and lead time embedding
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim), requires_grad=True)
self.lead_time_embed = nn.Linear(1, embed_dim)
# --------------------------------------------------------------------------
# ViT backbone
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList(
[
Block(
embed_dim,
num_heads,
mlp_ratio,
qkv_bias=True,
drop_path=dpr[i],
norm_layer=nn.LayerNorm,
proj_drop=drop_rate,
)
for i in range(depth)
]
)
self.norm = nn.LayerNorm(embed_dim)
# --------------------------------------------------------------------------
# prediction head
self.out_dim = out_dim
self.head = Head(embed_dim, decoder_depth, img_size, patch_size, upsampling_steps, out_dim, feature_extractor_type,double=double)
# --------------------------------------------------------------------------
self.initialize_weights()
        if pretrained:
            # `pretrained` holds a checkpoint path; strip the "net." prefix from parameter
            # names and keep the current weights for any tensor whose shape does not match.
            ckpt = torch.load(pretrained, map_location='cpu')
            state_dict = {k.replace('net.', ''): v for k, v in ckpt['state_dict'].items()}
            current_dict = self.state_dict()
            for k in state_dict:
                if k in current_dict and current_dict[k].shape != state_dict[k].shape:
                    print("do not load, shape mismatch:", k)
                    state_dict[k] = current_dict[k]
            self.load_state_dict(state_dict, strict=False)
self.is_double = double
def initialize_weights(self):
# initialize pos_emb and var_emb
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1],
int(self.img_size[0] / self.patch_size),
int(self.img_size[1] / self.patch_size),
cls_token=False,
)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
var_embed = get_1d_sincos_pos_embed_from_grid(self.var_embed.shape[-1], np.arange(len(self.default_vars)))
self.var_embed.data.copy_(torch.from_numpy(var_embed).float().unsqueeze(0))
# token embedding layer
if self.parallel_patch_embed:
for i in range(len(self.token_embeds.proj_weights)):
w = self.token_embeds.proj_weights[i].data
trunc_normal_(w.view([w.shape[0], -1]), std=0.02)
else:
for i in range(len(self.token_embeds)):
w = self.token_embeds[i].proj.weight.data
trunc_normal_(w.view([w.shape[0], -1]), std=0.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
    def create_var_embedding(self, dim):
        var_embed = nn.Parameter(torch.zeros(1, len(self.default_vars), dim), requires_grad=True)
        # mapping from variable name --> token index
        var_map = {var: idx for idx, var in enumerate(self.default_vars)}
        return var_embed, var_map
@lru_cache(maxsize=None)
def get_var_ids(self, vars, device):
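        # `vars` must be hashable (a tuple) for lru_cache to work; `forward_encoder`
        # converts list inputs to tuples before calling this.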
ids = np.array([self.var_map[var] for var in vars])
return torch.from_numpy(ids).to(device)
def get_var_emb(self, var_emb, vars):
ids = self.get_var_ids(vars, var_emb.device)
return var_emb[:, ids, :]
def unpatchify(self, x: torch.Tensor, h=None, w=None):
"""
x: (B, L, V * patch_size**2)
return imgs: (B, V, H, W)
"""
p = self.patch_size
c = self.out_dim
h = self.img_size[0] // p if h is None else h // p
w = self.img_size[1] // p if w is None else w // p
assert h * w == x.shape[1]
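        # Tokens are laid out row-major over the (h, w) patch grid; split each token into a
        # (p, p, c) patch, move channels ahead of the spatial axes, then merge patch pixels
        # back into the grid to recover a (B, c, h*p, w*p) image.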
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
def aggregate_variables(self, x: torch.Tensor):
"""
x: B, V, L, D
"""
b, _, l, _ = x.shape
x = torch.einsum("bvld->blvd", x)
x = x.flatten(0, 1) # BxL, V, D
var_query = self.var_query.repeat_interleave(x.shape[0], dim=0)
x, _ = self.var_agg(var_query, x, x) # BxL, D
x = x.squeeze()
x = x.unflatten(dim=0, sizes=(b, l)) # B, L, D
return x
def forward_head(self, x: torch.Tensor, original_image: torch.Tensor):
        return self.head(x, original_image)
def forward_encoder(self, x: torch.Tensor, lead_times: torch.Tensor, variables):
# x: `[B, V, H, W]` shape.
if isinstance(variables, list):
variables = tuple(variables)
# tokenize each variable separately
embeds = []
var_ids = self.get_var_ids(variables, x.device)
if self.parallel_patch_embed:
x = self.token_embeds(x, var_ids) # B, V, L, D
else:
for i in range(len(var_ids)):
id = var_ids[i]
embeds.append(self.token_embeds[id](x[:, i : i + 1]))
x = torch.stack(embeds, dim=1) # B, V, L, D
# add variable embedding
var_embed = self.get_var_emb(self.var_embed, variables)
x = x + var_embed.unsqueeze(2) # B, V, L, D
# variable aggregation
x = self.aggregate_variables(x) # B, L, D
# add pos embedding
x = x + self.pos_embed
# add lead time embedding
lead_time_emb = self.lead_time_embed(lead_times.unsqueeze(-1)) # B, D
lead_time_emb = lead_time_emb.unsqueeze(1)
x = x + lead_time_emb # B, L, D
x = self.pos_drop(x)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
def forward(self, x, lead_times=None, variables=None):
x = x.to(memory_format=torch.channels_last)
        if self.training or (x.shape[2] == self.img_size[0] and x.shape[3] == self.img_size[1]):
            return self.local_forward(x, lead_times, variables)
        else:
            # At eval time, larger inputs are split into a grid of non-overlapping
            # `img_size` crops and the per-crop predictions are stitched back together.
            assert x.shape[2] % self.img_size[0] == 0 and x.shape[3] % self.img_size[1] == 0
            nh = x.shape[2] // self.img_size[0]
            nw = x.shape[3] // self.img_size[1]
            output = torch.zeros(x.shape[0], self.out_dim, x.shape[2], x.shape[3]).to(x)
            for i in range(nh):
                for j in range(nw):
                    local_x = x[..., i * self.img_size[0]:(i + 1) * self.img_size[0],
                                     j * self.img_size[1]:(j + 1) * self.img_size[1]]
                    local_y = self.local_forward(local_x, lead_times, variables)
                    output[..., i * self.img_size[0]:(i + 1) * self.img_size[0],
                                j * self.img_size[1]:(j + 1) * self.img_size[1]] = local_y
            return output
def local_forward(self, x, lead_times=None, variables=None):
"""Forward pass through the model.
Args:
x: `[B, Vi, H, W]` shape. Input weather/climate variables
y: `[B, Vo, H, W]` shape. Target weather/climate variables
lead_times: `[B]` shape. Forecasting lead times of each element of the batch.
Returns:
loss (list): Different metrics.
preds (torch.Tensor): `[B, Vo, H, W]` shape. Predicted weather/climate variables.
"""
x_raw = x
if variables is None:
variables = self.default_vars
if len(x.shape) == 5: # N T C H W -> N C H W (dropping time)
assert x.shape[1] == 1
x = x.squeeze(1)
n = x.shape[0]
if lead_times is None:
lead_times = torch.ones(n).to(x)
        if self.is_double:
            # Split the channels into two 3-channel images, encode both with the shared
            # backbone in one batch, then concatenate their token features.
            x0, x1 = x.chunk(2, dim=1)
            assert x0.shape[1] == 3
            x = torch.cat([x0, x1])
            lead_times = torch.cat([lead_times, lead_times])
            out_transformers = self.forward_encoder(x, lead_times, variables)  # 2B, L, D
            x0, x1 = out_transformers.chunk(2, dim=0)
            out_transformers = torch.cat([x0, x1], dim=-1)  # B, L, 2D
        else:
            out_transformers = self.forward_encoder(x, lead_times, variables)  # B, L, D
preds = self.forward_head(x=out_transformers, original_image=x_raw) # B, L, V*p
preds = self.unpatchify(preds)
return preds
    def evaluate(self, x, y, lead_times, variables, out_variables, transform, metrics, lat, clim, log_postfix):
        preds = self.forward(x, lead_times, variables)
        return [m(preds, y, transform, out_variables, lat, clim, log_postfix) for m in metrics]
if __name__ == '__main__':
    # Smoke test: treat the RGB channels as three "variables" and run tiled inference
    # on an input larger than the configured image size.
    variables = ['R', 'G', 'B']
    z = ClimaXLegacy(img_size=[512, 512], patch_size=16, default_vars=variables)
    z.eval()
    x = torch.rand(3, 3, 1024, 1024)
    z(x)
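    # Sketch of a native-resolution forward pass with explicit lead times; the tensors and
    # values below are arbitrary and only illustrate the expected input shapes.
    x_native = torch.rand(2, 3, 512, 512)
    lead_times = torch.zeros(2)
    y = z(x_native, lead_times=lead_times, variables=variables)
    print(y.shape)  # (B, out_dim, H, W) if the head preserves the crop resolution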