"""Run the ``vopeai/claris_rf_channel`` model on one input/reference image pair.

Reads ``sample_img.png`` and ``ref_img.png`` from the directory containing
this script, runs a single forward pass with gradients disabled, and writes
the result to ``output_img_rfchannel.png`` in that same directory.
"""
import os

import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModel

# Change working directory to the script's folder so the relative image paths
# below resolve next to this file regardless of where it is launched from.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Set device: use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model.
# NOTE: trust_remote_code=True allows transformers to execute Python code
# shipped in the model repository; only use it with repositories you trust.
model = AutoModel.from_pretrained("vopeai/claris_rf_channel", trust_remote_code=True)
model.to(device)
model.eval()

# Load input + reference frames, normalised to 3-channel RGB.
input_img = Image.open("sample_img.png").convert("RGB")
ref_img = Image.open("ref_img.png").convert("RGB")

# Inference. Gradients are not needed, so disable autograd tracking.
# NOTE(review): the PIL images are passed to the model directly; presumably
# the model's remote code handles tensor conversion and device placement --
# confirm against the model repository.
with torch.no_grad():
    output = model(input_img, ref_img)

# Convert to PIL and save.
# NOTE(review): assumes `output` is a single image tensor that ToPILImage
# accepts (no batch dimension) -- confirm against the model's return value.
output_pil = transforms.ToPILImage()(output.cpu())
output_pil.save("output_img_rfchannel.png")
print("Saved output as 'output_img_rfchannel.png'")