# NonReference_RGBN_x4/load.py - author: simon-donike
# Last commit 2a715d5 (verified): "Update NonReference_RGBN_x4/load.py" (3.61 kB)
import pathlib
import opensr_model
from omegaconf import OmegaConf
import safetensors.torch
import matplotlib.pyplot as plt
from sen2sr.models.tricks import HardConstraint
from sen2sr.nonreference import srmodel
# Model configuration. compiled_model() wraps it with OmegaConf.create() and
# passes it to opensr_model.SRLatentDiffusion. The meaning of individual keys
# is defined by opensr_model, not by this file.
config = {
    "apply_normalization": False,
    "ckpt_version": "opensr_10m_v4_v6.ckpt",
    "encode_conditioning": True,
    "denoiser_settings": {
        "linear_start": 0.0015,
        "linear_end": 0.0155,
        "timesteps": 1000,
        "sampling_eta": 1.0,
        "sampling_steps": 200,
        "sampling_temperature": 1.0,
    },
    "first_stage_config": {
        "embed_dim": 4,
        "double_z": True,
        "z_channels": 4,
        "resolution": 256,
        "in_channels": 4,
        "out_ch": 4,
        "ch": 128,
        "ch_mult": [1, 2, 4],
        "num_res_blocks": 2,
        "attn_resolutions": [],
        "dropout": 0.0,
    },
    "cond_stage_config": {
        "image_size": 64,
        "in_channels": 8,
        "model_channels": 160,
        "out_channels": 4,
        "num_res_blocks": 2,
        "attention_resolutions": [16, 8],
        "channel_mult": [1, 2, 2, 4],
        "num_head_channels": 32,
    },
    "other": {
        "concat_mode": True,
        "cond_stage_trainable": False,
        "first_stage_key": "image",
        "cond_stage_key": "LR_image",
    },
}
# MLSTAC API -----------------------------------------------------------------------
def example_data(path: pathlib.Path, device: str = "cuda:0", *args, **kwargs):
    """Load the bundled example low-/high-resolution image pair.

    Args:
        path: Directory containing ``example_data.safetensor``. A plain
            ``str`` is accepted as well as a ``pathlib.Path``.
        device: Torch device string the loaded tensors are moved to.
        *args, **kwargs: Accepted for MLSTAC API compatibility; ignored.

    Returns:
        Tuple ``(lr, hr)`` holding the file's ``"lr"`` and ``"hr"`` tensors,
        both moved to ``device``.
    """
    # Coerce to Path so callers may pass a string; "/" is undefined on str.
    data_f = pathlib.Path(path) / "example_data.safetensor"
    sample = safetensors.torch.load_file(data_f)
    return sample["lr"].to(device), sample["hr"].to(device)
def trainable_model(*args, **kwargs):
    """MLSTAC hook for a trainable model, which this package does not ship.

    Prints a notice to stdout and returns ``None``. Any positional or keyword
    arguments are accepted and ignored.
    """
    notice = "Trainable model is not available for this model."
    print(notice)
    return None
def compiled_model(path, device: str = "cuda:0", harconstraint_mode: bool = True, *args, **kwargs):
    """Build the frozen super-resolution model from the weights stored in ``path``.

    Args:
        path: Directory containing ``model.safetensors`` and, when
            ``harconstraint_mode`` is true, ``hard_constraint.safetensor``.
            A plain ``str`` is accepted as well as a ``pathlib.Path``.
        device: Torch device string passed to the model, the constraint and
            the constraint's low-pass mask.
        harconstraint_mode: When true, wrap the model with ``HardConstraint``
            through ``srmodel``; when false, return the bare model.
        *args, **kwargs: Accepted for MLSTAC API compatibility; ignored.

    Returns:
        ``srmodel(sr_model, hard_constraint, device)`` when
        ``harconstraint_mode`` is true, otherwise the
        ``opensr_model.SRLatentDiffusion`` instance itself.
    """
    # Coerce to Path so callers may pass a string; "/" is undefined on str.
    path = pathlib.Path(path)

    # Build the latent-diffusion model from the module-level config and load
    # its weights; strict=True makes any key mismatch an error.
    sr_model_weights = safetensors.torch.load_file(path / "model.safetensors")
    sr_model = opensr_model.SRLatentDiffusion(config=OmegaConf.create(config), device=device)
    sr_model.model.load_state_dict(sr_model_weights, strict=True)

    # Inference only: switch to eval mode and freeze every parameter.
    sr_model.model.eval()
    for param in sr_model.model.parameters():
        param.requires_grad = False

    if not harconstraint_mode:
        return sr_model

    # Wrap the model with the hard constraint built from the stored mask.
    hard_constraint_weights = safetensors.torch.load_file(path / "hard_constraint.safetensor")
    hard_constraint = HardConstraint(
        low_pass_mask=hard_constraint_weights["weights"].to(device),
        device=device,
    )
    return srmodel(sr_model, hard_constraint, device)
def _to_display_rgb(image, window: slice, gain: float = 3.0):
    """Return a brightened, display-ready RGB crop of ``image[0]`` as a NumPy array.

    Args:
        image: Tensor indexed as ``image[batch, channel, row, col]``. Only
            batch element 0 and channels 0-2 are used; these are assumed to
            be R, G, B (TODO confirm the band order).
        window: Slice applied to both rows and columns, giving a square crop.
        gain: Brightness multiplier applied before display.

    Returns:
        Array of shape (rows, cols, 3) with values clipped to [0, 1]. This is
        the range matplotlib's ``imshow`` accepts for float RGB data; it would
        otherwise clip silently and log a warning.
    """
    rgb = (image[0, 0:3, window, window] * gain).clamp(0.0, 1.0)
    # detach() keeps .numpy() from raising if the tensor is in an autograd graph.
    return rgb.detach().cpu().numpy().transpose(1, 2, 0)


def display_results(path: pathlib.Path, device: str = "cuda:0", harconstraint_mode: bool = True, *args, **kwargs):
    """Run the model on the bundled example and plot LR, SR and HR crops side by side.

    Args:
        path: Directory holding the model weights and the example data.
            A plain ``str`` is accepted as well as a ``pathlib.Path``.
        device: Torch device string used for the model and the data.
        harconstraint_mode: Forwarded to ``compiled_model``.
        *args, **kwargs: Accepted for MLSTAC API compatibility; ignored.

    Returns:
        The matplotlib ``Figure`` containing the three panels.
    """
    # Coerce to Path so callers may pass a string; the loaders join with "/".
    path = pathlib.Path(path)

    model = compiled_model(path=path, device=device, harconstraint_mode=harconstraint_mode)
    lr, hr = example_data(path=path, device=device)
    sr = model(lr)

    # Show the same ground area in every panel. The window is rows/cols
    # 16..111 of the low-resolution grid, scaled by the fixed 4x factor for
    # the super-resolved and true high-resolution images.
    scale = 4
    lr_window = slice(16, 32 + 80)
    hr_window = slice(lr_window.start * scale, lr_window.stop * scale)

    panels = [
        (_to_display_rgb(lr, lr_window), "Sentinel-2"),
        (_to_display_rgb(sr, hr_window), "Super-Resolved"),
        (_to_display_rgb(hr, hr_window), "True HR"),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    return fig