import os
from copy import deepcopy

import gradio as gr
import numpy as np
import PIL
import spaces
import torch
import yaml
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import ToPILImage, ToTensor
from transformers import AutoModelForImageSegmentation

from utils import extract_object, get_model_from_config, resize_and_center_crop

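# Hugging Face token read from the environment; passed to hf_hub_download below
# in case the checkpoint repo requires authentication.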
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

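# Supported output resolutions, keyed by their width/height ratio (stringified)
# so an input image can be matched to the closest available bucket.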
ASPECT_RATIOS = {
    str(512 / 2048): (512, 2048),
    str(1024 / 1024): (1024, 1024),
    str(2048 / 512): (2048, 512),
    str(896 / 1152): (896, 1152),
    str(1152 / 896): (1152, 896),
    str(512 / 1920): (512, 1920),
    str(640 / 1536): (640, 1536),
    str(768 / 1280): (768, 1280),
    str(1280 / 768): (1280, 768),
    str(1536 / 640): (1536, 640),
    str(1920 / 512): (1920, 512),
}

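# Fetch the LBM relighting checkpoint and its config from the Hugging Face Hub.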
MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "model.safetensors", token=huggingface_token)
CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "config.yaml", token=huggingface_token)

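# Instantiate the model from its YAML config, load the safetensors weights, and
# move it to the GPU in bfloat16. BiRefNet supplies the foreground segmentation masks.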
with open(CONFIG_PATH, "r") as f:
    config = yaml.safe_load(f)

model = get_model_from_config(**config)
sd = load_file(MODEL_PATH)
model.load_state_dict(sd, strict=True)
model.to("cuda").to(torch.bfloat16)

birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True).cuda()
image_size = (1024, 1024)

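# Full relighting pipeline: segment the foreground, composite it onto the new
# background, then let the LBM model harmonize the lighting in latent space.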
@spaces.GPU
def evaluate(fg_image: PIL.Image.Image, bg_image: PIL.Image.Image, num_sampling_steps: int = 4):
    # PIL's .size is (width, height); pick the closest supported aspect-ratio bucket.
    ori_w, ori_h = fg_image.size
    ar = ori_w / ori_h
    closest_ar = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar))
    dimensions = ASPECT_RATIOS[closest_ar]

    # Segment the foreground object with BiRefNet.
    _, fg_mask = extract_object(birefnet, deepcopy(fg_image))

    fg_image = resize_and_center_crop(fg_image, dimensions[0], dimensions[1])
    fg_mask = resize_and_center_crop(fg_mask, dimensions[0], dimensions[1])
    bg_image = resize_and_center_crop(bg_image, dimensions[0], dimensions[1])

    # Paste the foreground onto the new background and normalize to [-1, 1].
    img_pasted = Image.composite(fg_image, bg_image, fg_mask)
    img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
    batch = {"source_image": img_pasted_tensor.cuda().to(torch.bfloat16)}

    # Encode to latent space, sample the relit image, and map back to [0, 1].
    z_source = model.vae.encode(batch[model.source_key])
    output_image = model.sample(z=z_source, num_steps=num_sampling_steps, conditioner_inputs=batch, max_samples=1).clamp(-1, 1)
    output_image = (output_image[0].float().cpu() + 1) / 2
    output_image = ToPILImage()(output_image)

    # Keep the background untouched outside the mask and restore the input size.
    # Note: Image.resize returns a new image, so the result must be reassigned.
    output_image = Image.composite(output_image, bg_image, fg_mask)
    output_image = output_image.resize((ori_w, ori_h))

    return (np.array(img_pasted), np.array(output_image))

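# Gradio UI: two image inputs (foreground and background), a relight button,
# and a before/after slider showing the naive paste next to the relit result.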
with gr.Blocks() as app:
    gr.HTML("""
    <style>
    body::before {
        content: "";
        display: block;
        height: 320px;
        background-color: var(--body-background-fill);
    }
    button[aria-label="Fullscreen"], button[aria-label="Fullscreen"]:hover {
        display: none !important;
        visibility: hidden !important;
        opacity: 0 !important;
        pointer-events: none !important;
    }
    button[aria-label="Share"], button[aria-label="Share"]:hover {
        display: none !important;
    }
    button[aria-label="Download"] {
        transform: scale(3);
        transform-origin: top right;
        margin: 0 !important;
        padding: 6px !important;
    }
    </style>
    """)

    gr.Markdown("# Change the Background")
    gr.Markdown("Replace photo backgrounds with advanced AI relighting.")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                fg_image = gr.Image(type="pil", label="Main Image", image_mode="RGB", height=360)
                bg_image = gr.Image(type="pil", label="New Background", image_mode="RGB", height=360)

            with gr.Row():
                submit_button = gr.Button("Relight")
            with gr.Row():
                num_inference_steps = gr.Slider(minimum=1, maximum=4, value=4, step=1, visible=False)

            bg_gallery = gr.Gallery(object_fit="contain", visible=False)

        with gr.Column():
            output_slider = gr.ImageSlider(label="Before / After", type="numpy")
            # Re-run the pipeline when an image is uploaded directly onto the slider.
            output_slider.upload(fn=evaluate, inputs=[fg_image, bg_image, num_inference_steps], outputs=[output_slider])

    submit_button.click(evaluate, inputs=[fg_image, bg_image, num_inference_steps], outputs=[output_slider], show_progress="full", show_api=False)

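    # When a gallery thumbnail is selected, use it as the new background image.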
    def bg_gallery_selected(gal, evt: gr.SelectData):
        return gal[evt.index][0]

    bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)

if __name__ == "__main__":
    app.launch(share=True)