import torch
import numpy as np
import math
import spaces
from diffusers import StableDiffusionXLPipeline
from transformers import AutoFeatureExtractor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ip_adapter import EasyRef
from huggingface_hub import hf_hub_download
import gradio as gr
import os
import cv2
import pillow_avif  # imported for its side effect: registers AVIF support in Pillow
def adaptive_resize(w, h, size=1024):
    # Rescale (w, h) so the total area is close to size*size, then round each side up to a multiple of 64.
    times = math.sqrt(h * w / (size**2))
    if w == h:
        w, h = size, size
    elif times > 1.1:
        w, h = math.ceil(w / times), math.ceil(h / times)
    elif times < 0.8:
        w, h = math.ceil(w / times), math.ceil(h / times)
    new_w, new_h = 64 * math.ceil(w / 64), 64 * math.ceil(h / 64)
    return new_w, new_h
def res2string(w, h):
    return str(w) + "x" + str(h)

def get_image_path_list(folder_name):
    image_basename_list = os.listdir(folder_name)
    image_path_list = sorted([os.path.join(folder_name, basename) for basename in image_basename_list])
    return image_path_list
def get_example():
    case = [
        [
            get_image_path_list('./assets/aragaki_identity'),
            "An oil painting of a smiling woman.",
            "A collage of images, naked, monochrome, lowres, bad anatomy, worst quality, low quality",
        ],
        [
            get_image_path_list('./assets/blindbox_style'),
            "Donald Trump",
            "A collage of images, naked, monochrome, lowres, bad anatomy, worst quality, low quality",
        ],
    ]
    return case
def upload_example_to_gallery(images, prompt, negative_prompt):
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
multimodal_llm_path = "Qwen/Qwen2-VL-2B-Instruct"
ip_ckpt = hf_hub_download(repo_id="zongzhuofan/EasyRef", filename="pytorch_model.bin", repo_type="model")

# Safety checker reused from the Stable Diffusion safety pipeline
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

device = "cuda"
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    feature_extractor=safety_feature_extractor,
    safety_checker=safety_checker,
    add_watermarker=False,
).to(device)
easyref = EasyRef(pipe, multimodal_llm_path, ip_ckpt, device, num_tokens=64, use_lora=True, lora_rank=128)

cv2.setNumThreads(1)
# The Space runs on ZeroGPU, so GPU access is requested per call via the spaces decorator.
@spaces.GPU
def generate_image(images, prompt, negative_prompt, scale, num_inference_steps, seed, progress=gr.Progress(track_tqdm=True)):
    print("Generating")
    template = "Visualize a scene that closely resembles the provided images, capturing the essence and details described in this prompt:\n"
    system_prompt = [template + prompt, template]
    image = easyref.generate(
        pil_image=images,
        system_prompt=system_prompt,
        prompt=prompt,
        negative_prompt=negative_prompt,
        scale=scale,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        seed=seed,
    )
    print(image)
    return image
# def change_style(style):
#     if style == "Photorealistic":
#         return(gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0))
#     else:
#         return(gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8))
def swap_to_gallery(images):
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)

def remove_back_to_files():
    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)

MAX_SEED = np.iinfo(np.int32).max

css = '''
h1{margin-bottom: 0 !important}
'''
with gr.Blocks(css=css) as demo:
    gr.Markdown("# EasyRef demo")
    gr.Markdown("Demo for the [zongzhuofan/EasyRef model](https://huggingface.co/zongzhuofan/EasyRef)")
    with gr.Row():
        with gr.Column():
            files = gr.Files(
                label="Multiple reference images",
                file_types=["image"]
            )
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=6, rows=1, height=125)
            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(label="Prompt",
                                placeholder="An oil painting of a [man/woman/person]...")
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="A collage of images, naked, monochrome, lowres, bad anatomy, worst quality, low quality")
            # style = gr.Radio(label="Generation type", info="For stylized try prompts like 'a watercolor painting of a woman'", choices=["Photorealistic", "Stylized"], value="Photorealistic")
            submit = gr.Button("Submit")
            with gr.Accordion(open=False, label="Advanced Options"):
                scale = gr.Slider(label="Scale", info="Scale for image reference", value=1.0, step=0.1, minimum=0.5, maximum=1.5)
                num_inference_steps = gr.Slider(label="Number of inference steps", value=30, step=1, minimum=1, maximum=60)
                seed = gr.Slider(label="Seed", value=24, step=1, minimum=0, maximum=MAX_SEED)
        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")

    gr.Examples(
        examples=get_example(),
        inputs=[files, prompt, negative_prompt],
        run_on_click=True,
        fn=upload_example_to_gallery,
        outputs=[uploaded_files, clear_button, files],
    )

    # style.change(fn=change_style,
    #              inputs=style,
    #              outputs=[preserve, face_strength, likeness_strength])
    files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
    remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
    submit.click(fn=generate_image,
                 inputs=[files, prompt, negative_prompt, scale, num_inference_steps, seed],
                 outputs=gallery)

    gr.Markdown("We release our checkpoints for research purposes only. Users are granted the freedom to create images using this tool, but they are expected to comply with local laws and utilize it in a responsible manner. The developers do not assume any responsibility for potential misuse by users.")

demo.launch()