# HAL1993's picture
# Update app.py
# ba62ed4 verified
import glob
import os
from copy import deepcopy
import gradio as gr
import numpy as np
import PIL
import spaces
import torch
import yaml
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import ToPILImage, ToTensor
from transformers import AutoModelForImageSegmentation
from utils import extract_object, get_model_from_config, resize_and_center_crop
# Optional Hugging Face access token, read from the environment (None if unset).
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Supported generation resolutions. Each key is str(first / second) of its
# value; `evaluate` compares the keys against the foreground's width / height
# ratio and picks the closest entry.
ASPECT_RATIOS = {
    str(512 / 2048): (512, 2048),
    str(1024 / 1024): (1024, 1024),
    str(2048 / 512): (2048, 512),
    str(896 / 1152): (896, 1152),
    str(1152 / 896): (1152, 896),
    str(512 / 1920): (512, 1920),
    str(640 / 1536): (640, 1536),
    str(768 / 1280): (768, 1280),
    str(1280 / 768): (1280, 768),
    str(1536 / 640): (1536, 640),
    str(1920 / 512): (1920, 512),
}

# Fetch the LBM relighting weights and model config from the Hub.
MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "model.safetensors", token=huggingface_token)
CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "config.yaml", token=huggingface_token)

with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Build the model from its config and load the weights (strict: every key must match).
model = get_model_from_config(**config)
sd = load_file(MODEL_PATH)
model.load_state_dict(sd, strict=True)
model.to("cuda").to(torch.bfloat16)
# Modules start in training mode; switch to inference behaviour (dropout and
# normalization layers) since this model is only ever sampled from.
model.eval()

# BiRefNet segmentation model used to cut out the foreground subject.
# NOTE(review): trust_remote_code=True executes Python code from the Hub repo.
birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True).cuda()

# NOTE(review): not referenced anywhere else in this file as shown.
image_size = (1024, 1024)
@spaces.GPU
def evaluate(fg_image: PIL.Image.Image, bg_image: PIL.Image.Image, num_sampling_steps: int = 4):
    """Relight a foreground subject onto a new background with the LBM model.

    The subject is segmented out of ``fg_image`` with BiRefNet. Foreground,
    mask and background are then resized/center-cropped to the entry of
    ``ASPECT_RATIOS`` closest to the foreground's aspect ratio. The subject is
    pasted onto the background, and the composite is refined by the LBM sampler.

    Args:
        fg_image: Image containing the subject to relight.
        bg_image: New background image.
        num_sampling_steps: Number of LBM sampling steps.

    Returns:
        A ``(before, after)`` tuple of numpy arrays for the image slider. The
        first is the plain cut-and-paste composite and the second is the
        relit result. Both are resized back to ``fg_image``'s original size.

    Raises:
        gr.Error: If either input image is missing.
    """
    if fg_image is None or bg_image is None:
        raise gr.Error("Ju lutem ngarkoni të dy imazhet.")

    # PIL reports size as (width, height).
    orig_w, orig_h = fg_image.size
    aspect_ratio = orig_w / orig_h
    closest_ar = min(ASPECT_RATIOS, key=lambda ar: abs(float(ar) - aspect_ratio))
    target_w, target_h = ASPECT_RATIOS[closest_ar]

    # Segment on a copy so the original foreground is left untouched.
    _, fg_mask = extract_object(birefnet, deepcopy(fg_image))

    fg_image = resize_and_center_crop(fg_image, target_w, target_h)
    fg_mask = resize_and_center_crop(fg_mask, target_w, target_h)
    bg_image = resize_and_center_crop(bg_image, target_w, target_h)

    img_pasted = Image.composite(fg_image, bg_image, fg_mask)

    # ToTensor yields [0, 1]; map to [-1, 1] (the output is mapped back below).
    img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
    batch = {"source_image": img_pasted_tensor.cuda().to(torch.bfloat16)}

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        z_source = model.vae.encode(batch[model.source_key])
        sampled = model.sample(
            z=z_source,
            num_steps=int(num_sampling_steps),
            conditioner_inputs=batch,
            max_samples=1,
        ).clamp(-1, 1)

    sampled = (sampled[0].float().cpu() + 1) / 2
    output_image = ToPILImage()(sampled)
    # Keep the (resized) background untouched outside the subject mask.
    output_image = Image.composite(output_image, bg_image, fg_mask)

    # Image.resize returns a new image, so the result must be assigned.
    # Resize both images so the before/after slider compares same-sized pictures.
    original_size = (orig_w, orig_h)
    img_pasted = img_pasted.resize(original_size)
    output_image = output_image.resize(original_size)

    return (np.array(img_pasted), np.array(output_image))
with gr.Blocks() as app:
    # Page CSS: a 320px spacer above the page body, hides the Fullscreen and
    # Share toolbar buttons, and enlarges the Download button 3x.
    gr.HTML("""
<style>
body::before {
content: "";
display: block;
height: 320px;
background-color: var(--body-background-fill);
}
button[aria-label="Fullscreen"], button[aria-label="Fullscreen"]:hover {
display: none !important;
visibility: hidden !important;
opacity: 0 !important;
pointer-events: none !important;
}
button[aria-label="Share"], button[aria-label="Share"]:hover {
display: none !important;
}
button[aria-label="Download"] {
transform: scale(3);
transform-origin: top right;
margin: 0 !important;
padding: 6px !important;
}
</style>
""")
    gr.Markdown("# Ndrysho Sfondin")
    gr.Markdown("Zëvendëso sfondin e fotove me rindriçim të avancuar nga inteligjenca artificiale.")

    with gr.Row():
        # Left column: the two input images and the trigger button.
        with gr.Column():
            with gr.Row():
                fg_image = gr.Image(
                    type="pil",
                    label="Imazhi Kryesor",
                    image_mode="RGB",
                    height=360,
                )
                bg_image = gr.Image(
                    type="pil",
                    label="Sfondi i Ri",
                    image_mode="RGB",
                    height=360,
                )
            with gr.Row():
                submit_button = gr.Button("Rindriço")
            with gr.Row():
                # Hidden, so the UI always submits its value of 4 steps.
                num_inference_steps = gr.Slider(
                    minimum=1,
                    maximum=4,
                    value=4,
                    step=1,
                    visible=False,
                )
                # Hidden and not populated anywhere in this file as shown.
                bg_gallery = gr.Gallery(object_fit="contain", visible=False)

        # Right column: before/after comparison of the relighting result.
        with gr.Column():
            output_slider = gr.ImageSlider(label="Para / Pas", type="numpy")

    # Uploading into the slider re-runs the relighting with the current inputs.
    output_slider.upload(
        fn=evaluate,
        inputs=[fg_image, bg_image, num_inference_steps],
        outputs=[output_slider],
    )
    submit_button.click(
        fn=evaluate,
        inputs=[fg_image, bg_image, num_inference_steps],
        outputs=[output_slider],
        show_progress="full",
        show_api=False,
    )

    def bg_gallery_selected(gallery_items, select_data: gr.SelectData):
        """Return the image of the clicked gallery entry (first item of the entry)."""
        picked_entry = gallery_items[select_data.index]
        return picked_entry[0]

    # Clicking a gallery thumbnail copies it into the background input.
    bg_gallery.select(fn=bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)

if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    app.launch(share=True)