# Uploaded by csaybar ("Upload 5 files"), commit d19ad5c (verified); file size 4.09 kB.
import pathlib
import safetensors.torch
import matplotlib.pyplot as plt
from sen2sr.models.opensr_baseline.swin import Swin2SR
from sen2sr.models.tricks import HardConstraint
from sen2sr.referencex2 import srmodel
# MLSTAC API -----------------------------------------------------------------------
def example_data(path: pathlib.Path, *args, **kwargs):
    """Load the bundled example low-resolution input tensor.

    Args:
        path: Directory containing ``example_data.safetensor``. A plain ``str``
            is also accepted and converted to ``pathlib.Path``.
        *args, **kwargs: Accepted and ignored.

    Returns:
        The tensor stored under the ``"lr"`` key of the safetensors file.
        NOTE(review): callers index it as ``lr[0, band]``, so it is presumably
        (batch, band, H, W) — confirm against the file.
    """
    # Coerce so a str path works too (``str / str`` would raise TypeError).
    data_f = pathlib.Path(path) / "example_data.safetensor"
    return safetensors.torch.load_file(data_f)["lr"]
def trainable_model(path, device: str = "cpu", *args, **kwargs):
    """Build the reference SR model in its default (trainable) state.

    Args:
        path: Directory containing ``model.safetensor`` (Swin2SR weights) and
            ``hard_constraint.safetensor`` (low-pass mask). ``str`` or
            ``pathlib.Path``.
        device: Torch device the network and the low-pass mask are moved to.
        *args, **kwargs: Accepted and ignored (presumably so every MLSTAC
            loader shares one call signature).

    Returns:
        An ``srmodel`` wrapping the Swin2SR network and its ``HardConstraint``.
    """
    # Coerce so a str path works too (``str / str`` would raise TypeError).
    path = pathlib.Path(path)

    # Swin2SR architecture. ``load_state_dict`` is strict by default, so these
    # values must match the checkpoint stored in model.safetensor exactly.
    params = {
        "img_size": (64, 64),
        "in_channels": 10,
        "out_channels": 6,
        "embed_dim": 192,
        "depths": [8] * 8,
        "num_heads": [8] * 8,
        "window_size": 4,
        "mlp_ratio": 4.0,
        "upscale": 1,
        "resi_connection": "1conv",
        "upsampler": "pixelshuffle",
    }
    sr_model = Swin2SR(**params)
    sr_model.load_state_dict(safetensors.torch.load_file(path / "model.safetensor"))
    sr_model = sr_model.to(device)

    # Hard constraint built from the stored low-pass mask.
    # NOTE(review): bands 0..5 presumably index the 6 Swin2SR output
    # channels (out_channels=6 above) — confirm in HardConstraint.
    hard_constraint_weights = safetensors.torch.load_file(path / "hard_constraint.safetensor")
    hard_constraint = HardConstraint(
        low_pass_mask=hard_constraint_weights["weights"].to(device),
        bands=[0, 1, 2, 3, 4, 5],
        device=device,
    )
    return srmodel(sr_model=sr_model, hard_constraint=hard_constraint, device=device)
def compiled_model(path, device: str = "cpu", *args, **kwargs):
    """Build the reference SR model frozen for inference.

    The model is built like the trainable variant, but both sub-modules are
    switched to ``eval()`` mode with every parameter's ``requires_grad``
    turned off. This function does not call ``torch.compile`` or TorchScript;
    "compiled" here means "ready for inference".

    Args:
        path: Directory containing ``model.safetensor`` (Swin2SR weights) and
            ``hard_constraint.safetensor`` (low-pass mask). ``str`` or
            ``pathlib.Path``.
        device: Torch device the networks and the low-pass mask are moved to.
        *args, **kwargs: Accepted and ignored (presumably so every MLSTAC
            loader shares one call signature).

    Returns:
        An ``srmodel`` wrapping the frozen Swin2SR network and ``HardConstraint``.
    """
    # Coerce so a str path works too (``str / str`` would raise TypeError).
    path = pathlib.Path(path)

    # Swin2SR architecture. ``load_state_dict`` is strict by default, so these
    # values must match the checkpoint stored in model.safetensor exactly.
    params = {
        "img_size": (64, 64),
        "in_channels": 10,
        "out_channels": 6,
        "embed_dim": 192,
        "depths": [8] * 8,
        "num_heads": [8] * 8,
        "window_size": 4,
        "mlp_ratio": 4.0,
        "upscale": 1,
        "resi_connection": "1conv",
        "upsampler": "pixelshuffle",
    }
    sr_model = Swin2SR(**params)
    sr_model.load_state_dict(safetensors.torch.load_file(path / "model.safetensor"))
    # eval() puts mode-dependent layers into inference mode.
    # requires_grad_(False) freezes every parameter, like a manual loop over
    # .parameters(). All three calls return the module itself.
    sr_model = sr_model.to(device).eval().requires_grad_(False)

    # Hard constraint built from the stored low-pass mask, frozen the same way.
    # NOTE(review): bands 0..5 presumably index the 6 Swin2SR output
    # channels (out_channels=6 above) — confirm in HardConstraint.
    hard_constraint_weights = safetensors.torch.load_file(path / "hard_constraint.safetensor")
    hard_constraint = HardConstraint(
        low_pass_mask=hard_constraint_weights["weights"].to(device),
        bands=[0, 1, 2, 3, 4, 5],
        device=device,
    )
    hard_constraint = hard_constraint.eval().requires_grad_(False)

    return srmodel(sr_model=sr_model, hard_constraint=hard_constraint, device=device)
def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Run the frozen model on the bundled example and plot LR vs. SR composites.

    Builds a 3x2 grid. Each row shows one three-band composite (RGB, SWIR,
    RED). The left panel is the low-resolution input; the right panel is the
    model output. Both are cropped to the same row window and brightened by a
    factor of 2.

    Args:
        path: Model directory (``str`` or ``pathlib.Path``) holding the weight
            files and ``example_data.safetensor``.
        device: Torch device used for inference.
        *args, **kwargs: Accepted and ignored.

    Returns:
        The matplotlib ``Figure`` containing the six panels.
    """
    import torch  # the module level only imports ``safetensors.torch``

    # Coerce so a str path works too (``str / str`` would raise TypeError).
    path = pathlib.Path(path)

    model = compiled_model(path, device)
    lr = example_data(path)

    # Pure inference: no autograd bookkeeping. This also guarantees the output
    # can be converted with ``.numpy()`` below.
    with torch.no_grad():
        sr = model(lr.to(device))

    def to_image(tensor, bands):
        # Take the first sample and the selected bands, and return an
        # (H, W, 3) NumPy array for imshow.
        # NOTE(review): assumes tensors are (batch, band, H, W), and that the
        # SR output carries all input bands (indices up to 9 are used) — confirm.
        return tensor[0, bands].cpu().numpy().transpose(1, 2, 0)

    # (panel label, band indices mapped to R, G, B)
    composites = (
        ("RGB", [2, 1, 0]),
        ("SWIR", [9, 8, 7]),
        ("RED", [6, 5, 4]),
    )

    # Crop rows 16..111. The HR window uses a scale factor of 1, i.e. the SR
    # output is shown on the same pixel grid as the input.
    lr_slice = slice(16, 32 + 80)
    hr_slice = slice(lr_slice.start * 1, lr_slice.stop * 1)

    fig, ax = plt.subplots(3, 2, figsize=(8, 12))
    ax = ax.flatten()
    for row, (label, bands) in enumerate(composites):
        lr_img = to_image(lr, bands)[lr_slice] * 2
        sr_img = to_image(sr, bands)[hr_slice] * 2
        # imshow clips float RGB data to [0, 1] anyway. Clipping here gives the
        # same picture without matplotlib's per-panel clipping warning.
        ax[2 * row].imshow(lr_img.clip(0, 1))
        ax[2 * row].set_title(f"LR {label}")
        ax[2 * row + 1].imshow(sr_img.clip(0, 1))
        ax[2 * row + 1].set_title(f"SR {label}")

    for a in ax:
        a.axis("off")
    fig.tight_layout()
    return fig