# EasyRef / app.py
# TempleX
# Update README
# 2283c6e
import torch
import numpy as np
import math
import spaces
from diffusers import StableDiffusionXLPipeline
from transformers import AutoFeatureExtractor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ip_adapter import EasyRef
from huggingface_hub import hf_hub_download
import gradio as gr
import os
import cv2
import pillow_avif
def adaptive_resize(w, h, size=1024, multiple=64):
    """Pick a resolution near ``size * size`` pixels for a ``w`` x ``h`` input.

    Square inputs map directly to ``size`` x ``size``. Otherwise the linear
    scale factor ``sqrt(w * h) / size`` is computed. If it is above 1.1 or
    below 0.8, both sides are divided by it, which keeps the aspect ratio and
    moves the area toward ``size**2``. In between, the input is left as is.
    Finally both sides are rounded up to a multiple of ``multiple``.

    Args:
        w: Input width in pixels; must be positive.
        h: Input height in pixels; must be positive.
        size: Target side length; the target area is ``size**2``.
        multiple: Each output side is rounded up to a multiple of this
            (64 by default, matching the previous hard-coded value).

    Returns:
        Tuple ``(new_w, new_h)`` of ints.

    Raises:
        ValueError: If ``w`` or ``h`` is not positive. Previously such input
            raised ZeroDivisionError, raised a math-domain error, or silently
            returned ``(size, size)``, depending on the values.
    """
    if w <= 0 or h <= 0:
        raise ValueError(f"image dimensions must be positive, got {w}x{h}")
    # Linear (not area) ratio between the input and the target resolution.
    times = math.sqrt(h * w / (size ** 2))
    if w == h:
        w, h = size, size
    elif not 0.8 <= times <= 1.1:
        # Too large (> 1.1) or too small (< 0.8): rescale uniformly.
        w, h = math.ceil(w / times), math.ceil(h / times)
    # Round each side up to the required multiple.
    new_w = multiple * math.ceil(w / multiple)
    new_h = multiple * math.ceil(h / multiple)
    return new_w, new_h
def res2string(w, h):
    """Format a resolution as ``"<w>x<h>"``, e.g. ``(1024, 768) -> "1024x768"``."""
    return "x".join((str(w), str(h)))
def get_image_path_list(folder_name):
    """Return the full paths of all entries in ``folder_name``, sorted.

    Entries come straight from ``os.listdir``. They are not filtered, so
    sub-directories and hidden files are included. Each entry is joined onto
    ``folder_name`` before the list is sorted.
    """
    paths = [os.path.join(folder_name, entry) for entry in os.listdir(folder_name)]
    paths.sort()
    return paths
def get_example():
    """Build the rows shown in the demo's ``gr.Examples`` table.

    Each row is ``[reference_image_paths, prompt, negative_prompt]``. This is
    the same order as the ``inputs=[files, prompt, negative_prompt]`` the
    examples are bound to. All rows share one negative prompt.
    """
    shared_negative = "A collage of images, naked, monochrome, lowres, bad anatomy, worst quality, low quality"
    presets = [
        ("./assets/aragaki_identity", "An oil painting of a smiling woman."),
        ("./assets/blindbox_style", "Donald Trump"),
    ]
    return [
        [get_image_path_list(folder), text, shared_negative]
        for folder, text in presets
    ]
def upload_example_to_gallery(images, prompt, negative_prompt):
    """Show a clicked example's reference images and hide the file picker.

    ``prompt`` and ``negative_prompt`` are accepted only so the signature
    matches the example inputs; they are not used here.

    Returns three ``gr.update`` values, in this order:
    1. the gallery, filled with ``images`` and made visible;
    2. the clear-button container, made visible;
    3. the file picker, hidden.
    """
    gallery_update = gr.update(value=images, visible=True)
    clear_update = gr.update(visible=True)
    picker_update = gr.update(visible=False)
    return gallery_update, clear_update, picker_update
# Hub id of the SDXL base model, and of the multimodal LLM handed to EasyRef.
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
multimodal_llm_path = "Qwen/Qwen2-VL-2B-Instruct"
# Download (or reuse the cached copy of) the EasyRef checkpoint; the result
# is a local file path passed to EasyRef below.
ip_ckpt = hf_hub_download(repo_id="zongzhuofan/EasyRef", filename="pytorch_model.bin", repo_type="model")
# Safety checker model and its image preprocessor, both loaded from the same
# Hub repo.
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
# Hard-coded: the app requires a CUDA GPU (no CPU fallback).
device = "cuda"
# SDXL base pipeline in fp16, moved to `device`.
# `add_watermarker=False` turns off SDXL's invisible watermark.
# NOTE(review): StableDiffusionXLPipeline does not appear to declare a
# `safety_checker` component, so diffusers may drop that kwarg with a warning
# and never run the checker. Confirm before relying on it for NSFW filtering.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    feature_extractor=safety_feature_extractor,
    safety_checker=safety_checker,
    add_watermarker=False,
).to(device)
# Wrap the pipeline with EasyRef (multimodal LLM + downloaded checkpoint).
# The meaning of num_tokens / use_lora / lora_rank is defined in
# ip_adapter.EasyRef and is not visible in this file.
easyref = EasyRef(pipe, multimodal_llm_path, ip_ckpt, device, num_tokens=64, use_lora=True, lora_rank=128)
# Restrict OpenCV to a single worker thread.
# NOTE(review): presumably to avoid CPU-thread oversubscription alongside
# torch; confirm the intent.
cv2.setNumThreads(1)
@spaces.GPU(enable_queue=True)
def generate_image(images, prompt, negative_prompt, scale, num_inference_steps, seed,
                   progress=gr.Progress(track_tqdm=True)):
    """Generate one image from the uploaded reference images and a text prompt.

    Args:
        images: Value of the ``gr.Files`` component holding the reference
            images (presumably a list of file paths; the exact type depends on
            the Gradio version. TODO: confirm what ``EasyRef.generate`` expects
            for ``pil_image``).
        prompt: Text prompt. It is passed as ``prompt`` and also appended to
            the instruction template in ``system_prompt``.
        negative_prompt: Negative text prompt.
        scale: Strength of the image reference, forwarded as ``scale``.
        num_inference_steps: Number of denoising steps.
        seed: Random seed for generation.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm
            progress bars into the UI.

    Returns:
        The value of ``easyref.generate(..., num_samples=1)``. It is displayed
        in the output ``gr.Gallery`` (presumably a list of images).

    Raises:
        gr.Error: If no reference images were uploaded. The user then sees a
            readable message instead of a failure deep inside the pipeline.
    """
    if not images:
        raise gr.Error("Please upload at least one reference image.")
    print("Generating")
    template = "Visualize a scene that closely resembles the provided images, capturing the essence and details described in this prompt:\n"
    # Two instructions: the first embeds the user prompt, the second is the
    # bare template (presumably the unconditional branch; TODO confirm in
    # EasyRef.generate).
    system_prompt = [template + prompt, template]
    image = easyref.generate(
        pil_image=images,
        system_prompt=system_prompt,
        prompt=prompt,
        negative_prompt=negative_prompt,
        scale=scale,
        num_samples=1,
        # Slider values can arrive as floats (e.g. from API clients). The
        # step count and the seed must be ints for the scheduler / RNG.
        num_inference_steps=int(num_inference_steps),
        seed=int(seed))
    print(image)
    return image
# def change_style(style):
# if style == "Photorealistic":
# return(gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0))
# else:
# return(gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8))
def swap_to_gallery(images):
    """After an upload, show the images in the preview gallery instead of the picker.

    Returns three ``gr.update`` values, in this order:
    1. the gallery, filled with ``images`` and made visible;
    2. the clear-button container, made visible;
    3. the file picker, hidden.
    """
    show_gallery = gr.update(value=images, visible=True)
    show_clear = gr.update(visible=True)
    hide_picker = gr.update(visible=False)
    return show_gallery, show_clear, hide_picker
def remove_back_to_files():
    """Undo ``swap_to_gallery``: hide the gallery and clear button, show the picker again.

    Returns three ``gr.update`` values, in this order:
    1. the gallery, hidden;
    2. the clear-button container, hidden;
    3. the file picker, shown.
    """
    hide_gallery = gr.update(visible=False)
    hide_clear = gr.update(visible=False)
    show_picker = gr.update(visible=True)
    return hide_gallery, hide_clear, show_picker
# Upper bound of the seed slider: the largest signed 32-bit integer
# (the same value as np.iinfo(np.int32).max).
MAX_SEED = 2**31 - 1
# Extra page CSS: drop the bottom margin under the <h1> title.
css = "\nh1{margin-bottom: 0 !important}\n"
# Gradio UI. The left column holds the reference-image upload, prompts and
# options; the right column shows the results. Example presets and event
# wiring follow, and the app is launched at the end.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# EasyRef demo")
    gr.Markdown("Demo for the [zongzhuofan/EasyRef model](https://huggingface.co/zongzhuofan/EasyRef)")
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            # Multi-file picker for the reference images. It is hidden once
            # images are loaded, and the preview gallery below replaces it.
            files = gr.Files(
                label="Multiple reference images",
                file_types=["image"]
            )
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=6, rows=1, height=125)
            # Hidden until images are loaded. The button clears `files`, and
            # its click handler (below) restores the picker.
            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(label="Prompt",
                                placeholder="An oil painting of a [man/woman/person]...")
            # The placeholder is only a hint; the submitted negative prompt
            # is empty unless the user types one.
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="A collage of images, naked, monochrome, lowres, bad anatomy, worst quality, low quality")
            # style = gr.Radio(label="Generation type", info="For stylized try prompts like 'a watercolor painting of a woman'", choices=["Photorealistic", "Stylized"], value="Photorealistic")
            submit = gr.Button("Submit")
            with gr.Accordion(open=False, label="Advanced Options"):
                # Forwarded to generate_image as `scale`, `num_inference_steps`, `seed`.
                scale = gr.Slider(label="Scale", info="Scale for image reference", value=1.0, step=0.1, minimum=0.5, maximum=1.5)
                num_inference_steps = gr.Slider(label="Number of inference steps", value=30, step=1, minimum=1, maximum=60)
                seed = gr.Slider(label="Seed", value=24, step=1, minimum=0, maximum=MAX_SEED)
        # Right column: generated output.
        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")
    # Clicking an example fills `files` and both prompts. With run_on_click,
    # it also runs upload_example_to_gallery to swap in the preview gallery.
    gr.Examples(
        examples=get_example(),
        inputs=[files, prompt, negative_prompt],
        run_on_click=True,
        fn=upload_example_to_gallery,
        outputs=[uploaded_files, clear_button, files],
    )
    # style.change(fn=change_style,
    #     inputs=style,
    #     outputs=[preserve, face_strength, likeness_strength])
    # After an upload, show the preview gallery and clear button, and hide the picker.
    files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
    # After clearing, hide the gallery and clear button, and show the picker again.
    remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
    # Generation reads the `files` component (not the preview gallery) as the references.
    submit.click(fn=generate_image,
                 inputs=[files, prompt, negative_prompt, scale, num_inference_steps, seed],
                 outputs=gallery)
    gr.Markdown("We release our checkpoints for research purposes only. Users are granted the freedom to create images using this tool, but they are expected to comply with local laws and utilize it in a responsible manner. The developers do not assume any responsibility for potential misuse by users.")
# Launched at import time (no __main__ guard), as the Space runs this file directly.
demo.launch()