File size: 4,747 Bytes
a88bb44 4f9fb0d a88bb44 130940f a88bb44 6e5acc2 a88bb44 4f9fb0d a88bb44 cd5da5d 130940f a88bb44 1176392 130940f a88bb44 130940f a88bb44 1176392 a88bb44 1176392 a88bb44 77779dc ba62ed4 a88bb44 1176392 a88bb44 ba66b21 130940f a88bb44 130940f a88bb44 3d9d07f 130940f a88bb44 130940f a88bb44 77779dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import glob
import os
from copy import deepcopy
import gradio as gr
import numpy as np
import PIL
import spaces
import torch
import yaml
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import ToPILImage, ToTensor
from transformers import AutoModelForImageSegmentation
from utils import extract_object, get_model_from_config, resize_and_center_crop
# Optional token used for authenticated Hugging Face Hub downloads.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Supported (width, height) working resolutions, keyed by str(width / height).
# `evaluate` picks the entry whose ratio is closest to the input image's.
ASPECT_RATIOS = {
    str(w / h): (w, h)
    for w, h in (
        (512, 2048),
        (1024, 1024),
        (2048, 512),
        (896, 1152),
        (1152, 896),
        (512, 1920),
        (640, 1536),
        (768, 1280),
        (1280, 768),
        (1536, 640),
        (1920, 512),
    )
}

MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "model.safetensors", token=huggingface_token)
CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "config.yaml", token=huggingface_token)

with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

model = get_model_from_config(**config)
sd = load_file(MODEL_PATH)
model.load_state_dict(sd, strict=True)
model.to("cuda").to(torch.bfloat16)
# nn.Module instances start in training mode; this app only runs inference,
# so make dropout/normalisation layers (if any) behave deterministically.
model.eval()

birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True).cuda()
image_size = (1024, 1024)
@spaces.GPU
def evaluate(fg_image: PIL.Image.Image, bg_image: PIL.Image.Image, num_sampling_steps: int = 4):
    """Paste the foreground subject onto a new background and relight it.

    The subject is segmented with BiRefNet. The foreground, its mask and the
    background are then center-cropped to the supported resolution whose aspect
    ratio is closest to the foreground's. The naive composite is refined by the
    LBM model, and only the masked (subject) region of the model output is kept
    on top of the cropped background.

    Args:
        fg_image: Image containing the subject to keep.
        bg_image: Replacement background image.
        num_sampling_steps: Number of LBM sampling steps.

    Returns:
        Tuple ``(before, after)`` of numpy arrays: the naive composite and the
        relit composite, both resized to the foreground's original size.

    Raises:
        gr.Error: If either image is missing (e.g. the button is pressed
            before uploading).
    """
    if fg_image is None or bg_image is None:
        raise gr.Error("Ju lutem ngarkoni imazhin kryesor dhe sfondin e ri.")

    # PIL reports size as (width, height); the working resolution follows the
    # foreground's aspect ratio and the background is cropped to match it.
    ori_w, ori_h = fg_image.size
    aspect_ratio = ori_w / ori_h
    closest_ar = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - aspect_ratio))
    # NOTE(review): assumes resize_and_center_crop(image, width, height) — confirm in utils.
    target_w, target_h = ASPECT_RATIOS[closest_ar]

    _, fg_mask = extract_object(birefnet, deepcopy(fg_image))
    fg_image = resize_and_center_crop(fg_image, target_w, target_h)
    fg_mask = resize_and_center_crop(fg_mask, target_w, target_h)
    bg_image = resize_and_center_crop(bg_image, target_w, target_h)
    img_pasted = Image.composite(fg_image, bg_image, fg_mask)

    # ToTensor yields values in [0, 1]; rescale to [-1, 1], the range the model
    # output is clamped to below.
    img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
    batch = {"source_image": img_pasted_tensor.cuda().to(torch.bfloat16)}

    # Pure inference: skip autograd bookkeeping to save GPU memory.
    with torch.no_grad():
        z_source = model.vae.encode(batch[model.source_key])
        output = model.sample(
            z=z_source,
            num_steps=num_sampling_steps,
            conditioner_inputs=batch,
            max_samples=1,
        ).clamp(-1, 1)

    output = (output[0].float().cpu() + 1) / 2
    output_image = ToPILImage()(output)
    # Keep only the relit subject; everything outside the mask is the cropped background.
    output_image = Image.composite(output_image, bg_image, fg_mask)

    # Image.resize returns a new image. The original call discarded it, so the
    # result never went back to the input size. Resize both halves so the
    # before/after slider compares images of equal size.
    original_size = (ori_w, ori_h)
    img_pasted = img_pasted.resize(original_size)
    output_image = output_image.resize(original_size)
    return (np.array(img_pasted), np.array(output_image))
# Page-level CSS injected through an HTML component. It reserves a 320px spacer
# above the page body, hides the Fullscreen and Share buttons, and scales the
# Download button 3x, anchored at its top-right corner.
_PAGE_STYLE = """
<style>
body::before {
content: "";
display: block;
height: 320px;
background-color: var(--body-background-fill);
}
button[aria-label="Fullscreen"], button[aria-label="Fullscreen"]:hover {
display: none !important;
visibility: hidden !important;
opacity: 0 !important;
pointer-events: none !important;
}
button[aria-label="Share"], button[aria-label="Share"]:hover {
display: none !important;
}
button[aria-label="Download"] {
transform: scale(3);
transform-origin: top right;
margin: 0 !important;
padding: 6px !important;
}
</style>
"""

with gr.Blocks() as app:
    gr.HTML(_PAGE_STYLE)
    gr.Markdown("# Ndrysho Sfondin")
    gr.Markdown("Zëvendëso sfondin e fotove me rindriçim të avancuar nga inteligjenca artificiale.")

    with gr.Row():
        # Left column: the two input images, the trigger button and hidden controls.
        with gr.Column():
            with gr.Row():
                fg_image = gr.Image(type="pil", label="Imazhi Kryesor", image_mode="RGB", height=360)
                bg_image = gr.Image(type="pil", label="Sfondi i Ri", image_mode="RGB", height=360)
            with gr.Row():
                submit_button = gr.Button("Rindriço")
            with gr.Row():
                num_inference_steps = gr.Slider(minimum=1, maximum=4, value=4, step=1, visible=False)
                bg_gallery = gr.Gallery(object_fit="contain", visible=False)
        # Right column: before/after comparison of the composite and the relit result.
        with gr.Column():
            output_slider = gr.ImageSlider(label="Para / Pas", type="numpy")

    # Both the upload-to-slider and button-click events run the same relighting call.
    relight_inputs = [fg_image, bg_image, num_inference_steps]
    output_slider.upload(fn=evaluate, inputs=relight_inputs, outputs=[output_slider])
    submit_button.click(
        evaluate,
        inputs=relight_inputs,
        outputs=[output_slider],
        show_progress="full",
        show_api=False,
    )

    def bg_gallery_selected(gallery_items, selection: gr.SelectData):
        """Return the image part of the gallery entry the user clicked."""
        picked_entry = gallery_items[selection.index]
        return picked_entry[0]

    bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)

if __name__ == "__main__":
    app.launch(share=True)