|
|
import pathlib |
|
|
import opensr_model |
|
|
from omegaconf import OmegaConf |
|
|
import safetensors.torch |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
from sen2sr.models.tricks import HardConstraint |
|
|
from sen2sr.nonreference import srmodel |
|
|
|
|
|
|
|
|
# Model configuration. ``compiled_model`` wraps ``config`` with
# ``OmegaConf.create`` and hands it to ``opensr_model.SRLatentDiffusion``.
# The nested sections are declared separately so each group of settings can
# be read on its own; ``config`` assembles them in a fixed key order.

# Denoiser schedule / sampling settings.
_DENOISER_SETTINGS = {
    "linear_start": 0.0015,
    "linear_end": 0.0155,
    "timesteps": 1000,
    "sampling_eta": 1.0,
    "sampling_steps": 200,
    "sampling_temperature": 1.0,
}

# Settings stored under "first_stage_config".
_FIRST_STAGE_CONFIG = {
    "embed_dim": 4,
    "double_z": True,
    "z_channels": 4,
    "resolution": 256,
    "in_channels": 4,
    "out_ch": 4,
    "ch": 128,
    "ch_mult": [1, 2, 4],
    "num_res_blocks": 2,
    "attn_resolutions": [],
    "dropout": 0.0,
}

# Settings stored under "cond_stage_config".
_COND_STAGE_CONFIG = {
    "image_size": 64,
    "in_channels": 8,
    "model_channels": 160,
    "out_channels": 4,
    "num_res_blocks": 2,
    "attention_resolutions": [16, 8],
    "channel_mult": [1, 2, 2, 4],
    "num_head_channels": 32,
}

# Remaining flags and batch keys stored under "other".
_OTHER_SETTINGS = {
    "concat_mode": True,
    "cond_stage_trainable": False,
    "first_stage_key": "image",
    "cond_stage_key": "LR_image",
}

config = {
    "apply_normalization": False,
    "ckpt_version": "opensr_10m_v4_v6.ckpt",
    "encode_conditioning": True,
    "denoiser_settings": _DENOISER_SETTINGS,
    "first_stage_config": _FIRST_STAGE_CONFIG,
    "cond_stage_config": _COND_STAGE_CONFIG,
    "other": _OTHER_SETTINGS,
}
|
|
|
|
|
|
|
|
|
|
|
def example_data(path: pathlib.Path, device: str = "cuda:0", *args, **kwargs):
    """Load the example low-/high-resolution pair stored under ``path``.

    Args:
        path: Directory containing ``example_data.safetensor``.
        device: Target handed to ``.to()`` for both loaded tensors.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        A ``(lr, hr)`` tuple taken from the file's ``"lr"`` and ``"hr"``
        entries, each moved to ``device``.
    """
    tensors = safetensors.torch.load_file(path / "example_data.safetensor")
    low_res = tensors["lr"].to(device)
    high_res = tensors["hr"].to(device)
    return low_res, high_res
|
|
|
|
|
def trainable_model(*args, **kwargs):
    """Report that no trainable variant exists for this model.

    All positional and keyword arguments are accepted and ignored.

    Returns:
        Always ``None``; a notice is printed to standard output.
    """
    notice = "Trainable model is not available for this model."
    print(notice)
    return None
|
|
|
|
|
def compiled_model(path, device: str = "cuda:0", harconstraint_mode: bool = True, *args, **kwargs):
    """Build the frozen, eval-mode super-resolution model from files in ``path``.

    Args:
        path: Directory holding ``model.safetensors`` and, when
            ``harconstraint_mode`` is true, ``hard_constraint.safetensor``.
            It is combined with file names via ``/``, so it must support
            that operator (e.g. a ``pathlib.Path``).
        device: Device passed to the diffusion model, to the constraint,
            and used for the constraint's low-pass mask.
        harconstraint_mode: When true, return the diffusion model combined
            with a ``HardConstraint`` through ``srmodel``; otherwise return
            the bare ``SRLatentDiffusion`` object.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        ``srmodel(sr_model, hard_constraint, device)`` when
        ``harconstraint_mode`` is true, else the ``SRLatentDiffusion``
        instance.
    """
    # Read the weights first so a missing file fails before the model is built.
    state_dict = safetensors.torch.load_file(path / "model.safetensors")

    diffusion = opensr_model.SRLatentDiffusion(config=OmegaConf.create(config), device=device)
    network = diffusion.model
    network.load_state_dict(state_dict, strict=True)

    # Inference only: switch to eval mode and freeze every parameter.
    network.eval()
    for weight in network.parameters():
        weight.requires_grad = False

    if not harconstraint_mode:
        return diffusion

    # Wrap the diffusion model together with the hard constraint.
    constraint_tensors = safetensors.torch.load_file(path / "hard_constraint.safetensor")
    low_pass_mask = constraint_tensors["weights"].to(device)
    constraint = HardConstraint(low_pass_mask=low_pass_mask, device=device)
    return srmodel(diffusion, constraint, device)
|
|
|
|
|
|
|
|
def display_results(path: pathlib.Path, device: str = "cuda:0", harconstraint_mode: bool = True, *args, **kwargs):
    """Run the model on the example data and plot LR, SR and HR crops.

    Args:
        path: Directory forwarded to ``compiled_model`` and ``example_data``.
        device: Device forwarded to both helpers.
        harconstraint_mode: Forwarded to ``compiled_model``.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        The matplotlib figure with three panels titled "Sentinel-2",
        "Super-Resolved" and "True HR", all with axes hidden.
    """

    def to_rgb(batch):
        # First sample, first three channels, leading axis moved last.
        return batch[0, 0:3].cpu().numpy().transpose(1, 2, 0)

    def draw(axis, image, window, title):
        # Pixel values are multiplied by 3 before display
        # (presumably to brighten the preview -- confirm).
        axis.imshow(image[window, window] * 3)
        axis.set_title(title)

    sr_net = compiled_model(path=path, device=device, harconstraint_mode=harconstraint_mode)
    low_res, high_res = example_data(path=path, device=device)
    super_res = sr_net(low_res)

    low_res_rgb = to_rgb(low_res)
    super_res_rgb = to_rgb(super_res)

    # The HR window is the LR window with both bounds scaled by 4
    # (presumably the super-resolution factor -- confirm).
    lr_window = slice(16, 32 + 80)
    hr_window = slice(lr_window.start * 4, lr_window.stop * 4)

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    left, middle, right = axes
    draw(left, low_res_rgb, lr_window, "Sentinel-2")
    draw(middle, super_res_rgb, hr_window, "Super-Resolved")
    draw(right, to_rgb(high_res), hr_window, "True HR")

    for panel in axes:
        panel.axis("off")

    fig.tight_layout()
    return fig
|
|
|