# claris-RF-channel / inference_example.py
# noeedc
# Add README and update inference example for clarity and usage instructions
# 31b5c87
from transformers import AutoModel
import torch
from PIL import Image
import os
from torchvision import transforms
def main() -> None:
    """Run the claris_rf_channel model on one input/reference image pair.

    Loads ``vopeai/claris_rf_channel`` via ``AutoModel.from_pretrained``.
    ``trust_remote_code=True`` lets transformers execute the model code hosted
    in that Hub repository. The model is run on ``sample_img.png`` and
    ``ref_img.png`` from this script's folder, and the result is saved in the
    same folder as ``output_img_rfchannel.png``.
    """
    # Resolve every file relative to the script's folder instead of calling
    # os.chdir, so the script works from any working directory without
    # mutating process-wide state.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    input_path = os.path.join(script_dir, "sample_img.png")
    ref_path = os.path.join(script_dir, "ref_img.png")
    output_path = os.path.join(script_dir, "output_img_rfchannel.png")

    # Prefer the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the model, move it to the chosen device, and switch to eval mode.
    model = AutoModel.from_pretrained("vopeai/claris_rf_channel", trust_remote_code=True)
    model.to(device)
    model.eval()

    # Load input + reference frames. `with` releases the file handle
    # deterministically; .convert() returns a new in-memory image.
    with Image.open(input_path) as img:
        input_img = img.convert("RGB")
    with Image.open(ref_path) as img:
        ref_img = img.convert("RGB")

    # Inference without autograd bookkeeping.
    # NOTE(review): the model receives PIL images, so it presumably does its
    # own preprocessing and device placement — confirm in the remote code.
    with torch.no_grad():
        output = model(input_img, ref_img)

    # Convert to PIL and save.
    # NOTE(review): ToPILImage needs a 2-D (HW) or 3-D (CHW) tensor; this
    # assumes the model returns a single image without a batch dim — confirm.
    output_pil = transforms.ToPILImage()(output.cpu())
    output_pil.save(output_path)
    print("Saved output as 'output_img_rfchannel.png'")


if __name__ == "__main__":
    main()