SeeSharp

Real-time video super-resolution (4×) using a teacher model with multi-branch dilated convolutions and feature alignment. It produces a super-resolved center frame from 3 consecutive low-resolution frames.

Model summary

  • Task: Video Super-Resolution (VSR), 4× upscale
  • Input: 3 frames (previous, current, next), RGB in [0,1], shape (B, C=3, T=3, H, W), i.e. channels before frames
  • Output: Super-resolved center frame, RGB in [0,1], shape (B, 3, 4H, 4W)
  • Backbone: Feature alignment + SR network with subpixel upsampling (ESPCN-style)
  • Key blocks: Multi-Branch Dilated Convolution (MBD), UpsamplingBlock (PixelShuffle)
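To make the MBD idea concrete, here is a minimal PyTorch sketch of a multi-branch dilated block. It is not the ERSVR source: the class name `MBDModuleSketch`, the dilation rates (1, 2, 4), the ReLU activations, the 1×1 fusion conv and the local residual are all assumptions. A 3×3 kernel with dilation d covers a (2d+1)×(2d+1) window, so setting padding=d keeps H and W unchanged while each branch sees a different amount of context.

```python
import torch
import torch.nn as nn


class MBDModuleSketch(nn.Module):
    """Hypothetical multi-branch dilated convolution block (not the repo's MBDModule).

    Each branch is a 3x3 conv with a different dilation rate; padding equals
    the dilation so every branch preserves H and W. Branch outputs are
    concatenated along channels and fused back with a 1x1 conv.
    """

    def __init__(self, channels: int = 64, dilations=(1, 2, 4)):
        super().__init__()
        self.branches = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(channels, channels, kernel_size=3, padding=d, dilation=d),
                nn.ReLU(inplace=True),
            )
            for d in dilations
        )
        # 1x1 conv mixes the concatenated multi-dilation features.
        self.fuse = nn.Conv2d(channels * len(dilations), channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = torch.cat([branch(x) for branch in self.branches], dim=1)
        return self.fuse(feats) + x  # local residual connection (assumption)
```

Usage: `MBDModuleSketch(64)(torch.randn(1, 64, 32, 32))` returns a tensor with the same shape as its input.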

Architecture

  • FeatureAlignmentBlock: initial conv stack + MBDModule to aggregate multi-dilation context
  • SRNetwork: deep conv stack + PixelShuffle upsampling + residual add with bicubic upsample of center frame
  • Residual path: bicubic(x_center) added to network output
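The SRNetwork data flow (features, sub-pixel upsampling, bicubic residual) can be sketched as follows. This is a hypothetical stand-in, not the repository's SRNetwork: `TinySRHeadSketch`, its layer widths, the single PixelShuffle(4) stage, and stacking the three frames along channels in place of FeatureAlignmentBlock are assumptions. The conv before PixelShuffle outputs 3 · 4² = 48 channels, which PixelShuffle rearranges into a 3-channel image at 4H × 4W.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySRHeadSketch(nn.Module):
    """Hypothetical x4 SR head: conv features -> PixelShuffle -> + bicubic(center)."""

    def __init__(self, in_frames: int = 3, channels: int = 32, scale: int = 4):
        super().__init__()
        self.scale = scale
        # Stand-in for FeatureAlignmentBlock: the frames are simply stacked along channels.
        self.body = nn.Sequential(
            nn.Conv2d(3 * in_frames, channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        # ESPCN-style upsampling: predict 3 * scale**2 channels, then
        # PixelShuffle rearranges them into a (scale*H, scale*W) RGB image.
        self.to_subpixels = nn.Conv2d(channels, 3 * scale ** 2, 3, padding=1)
        self.shuffle = nn.PixelShuffle(scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, t, h, w = x.shape  # (B, 3, 3, H, W) as in the model card
        center = x[:, :, t // 2]  # middle frame, (B, 3, H, W)
        feats = self.body(x.reshape(b, c * t, h, w))
        residual = self.shuffle(self.to_subpixels(feats))
        base = F.interpolate(center, scale_factor=self.scale, mode="bicubic", align_corners=False)
        return base + residual  # network output added to bicubic(x_center)
```

Because the network only has to predict the residual on top of bicubic(x_center), zeroing the last conv makes the output collapse to plain bicubic upsampling, which is a convenient sanity check.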

Intended uses & limitations

  • Use for: Upscaling videos or frame triplets where temporal adjacency exists.
  • Not ideal for: Single images, unless a triplet is approximated (e.g., by repeating the frame); domains far from the training distribution.
  • Performance: The teacher is heavier than the student model; it gives better visual quality but runs slower, especially on CPU.

Quick start (inference)

Clone this repo, or make sure the model source files (ersvr/models/*.py) are available locally; the weights are downloaded from the Hub.

import sys

import numpy as np
import torch
from huggingface_hub import hf_hub_download

# If you cloned the model repo contents locally, make the ersvr package importable:
# sys.path.append(".")

from ersvr.models.ersvr import ERSVR

# Download the recommended weights from the Hub
ckpt_path = hf_hub_download(
    repo_id="Abhinavexists/SeeSharp",
    filename="weights/ersvr_best.pth",
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ERSVR(scale_factor=4).to(device)

# A checkpoint may be a raw state dict or a dict wrapping it under "model_state_dict"
state = torch.load(ckpt_path, map_location=device)
if isinstance(state, dict) and "model_state_dict" in state:
    state = state["model_state_dict"]
model.load_state_dict(state)
model.eval()

# Prepare a triplet: (T=3, H, W, C=3) with values in [0, 1]
img = np.random.rand(128, 128, 3).astype(np.float32)
triplet = np.stack([img, img, img], axis=0)  # demo: the same frame repeated 3 times
# (T, H, W, C) -> (C, T, H, W), then add the batch dim -> (1, 3, 3, H, W)
tensor = torch.from_numpy(triplet).permute(3, 0, 1, 2).unsqueeze(0).to(device)

with torch.no_grad():
    out = model(tensor).clamp(0, 1)  # (1, 3, 4H, 4W)
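To save or display the result, the (1, 3, 4H, 4W) float output has to go back to a (4H, 4W, 3) uint8 image. A minimal NumPy helper for that, assuming you first move the tensor to the CPU with `out.cpu().numpy()` (`to_uint8_image` is a hypothetical name, not part of the repo):

```python
import numpy as np


def to_uint8_image(sr: np.ndarray) -> np.ndarray:
    """Convert a (1, 3, H, W) or (3, H, W) float array in [0, 1] to (H, W, 3) uint8."""
    if sr.ndim == 4:
        sr = sr[0]  # drop the batch dimension
    hwc = np.transpose(sr, (1, 2, 0))  # CHW -> HWC
    # Clip to [0, 1], scale to [0, 255] and round to the nearest integer.
    return (np.clip(hwc, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
```

For example, `Image.fromarray(to_uint8_image(out.cpu().numpy())).save("sr.png")` writes the frame with Pillow.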

I/O details

  • Normalization: expects [0,1] floats; convert from uint8 with img.astype(np.float32)/255.0
  • Center frame: the residual path uses bicubic upsampling of the middle frame
  • Temporal window: exactly 3 frames
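Because the model needs exactly 3 frames, upscaling a whole clip means sliding a 3-frame window over it. A sketch, assuming the video is already decoded into an (N, H, W, 3) uint8 array; `frame_triplets` is a hypothetical helper, and repeating the first and last frames at the clip boundaries is an assumption rather than documented behaviour:

```python
import numpy as np


def frame_triplets(frames: np.ndarray):
    """Yield one (3, H, W, 3) float32 triplet in [0, 1] per frame of an (N, H, W, 3) uint8 video.

    Boundary frames reuse themselves as the missing neighbour (edge
    replication); this is an assumption, not something the model card specifies.
    """
    n = len(frames)
    for i in range(n):
        prev_i, next_i = max(i - 1, 0), min(i + 1, n - 1)
        triplet = np.stack([frames[prev_i], frames[i], frames[next_i]], axis=0)
        yield triplet.astype(np.float32) / 255.0  # uint8 -> [0, 1] floats
```

Each yielded triplet can then go through the same `torch.from_numpy(triplet).permute(3, 0, 1, 2).unsqueeze(0)` conversion as in the quick start.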

Weights

  • weights/ersvr_best.pth (recommended)
  • weights/ersvr_epoch_10.pth, weights/ersvr_epoch_20.pth, weights/ersvr_epoch_30.pth (training checkpoints)

Metrics

  • Reported VSR metrics:
    • PSNR: 34.2 dB
    • SSIM: 0.94
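For reference, PSNR for images in [0, 1] is 10 · log10(MAX² / MSE) with MAX = 1. The sketch below is a plain NumPy version; the helpers in ersvr/train.py may differ in details such as evaluating on the Y channel or cropping borders, so numbers from this function are not guaranteed to match the reported ones.

```python
import numpy as np


def psnr(sr: np.ndarray, hr: np.ndarray, max_val: float = 1.0) -> float:
    """PSNR in dB between two same-shape arrays with values in [0, max_val]."""
    mse = np.mean((sr.astype(np.float64) - hr.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_val ** 2 / mse))
```

A uniform error of 0.1 gives MSE = 0.01 and therefore 20 dB.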

Training

  • 4× upscale, triplet-based supervision.
  • See training utilities in ersvr/train.py for metric computation helpers.
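A single step of triplet-based supervision might look like the sketch below: the LR triplet goes in, and the predicted center frame is compared against the HR center frame. `train_step` is hypothetical, and the L1 loss is an assumption; check ersvr/train.py for the loss and optimizer actually used.

```python
import torch
import torch.nn as nn


def train_step(model: nn.Module, optimizer: torch.optim.Optimizer,
               lr_triplet: torch.Tensor, hr_center: torch.Tensor) -> float:
    """One supervised step: LR triplet (B,3,3,H,W) -> SR center, compared with HR center (B,3,4H,4W)."""
    model.train()
    optimizer.zero_grad()
    sr = model(lr_triplet)
    # L1 reconstruction loss on the center frame only (assumed loss choice).
    loss = nn.functional.l1_loss(sr, hr_center)
    loss.backward()
    optimizer.step()
    return loss.item()
```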

License

  • MIT