TempleX commited on
Commit
949981c
·
1 Parent(s): 0aae67d

initial commit

Browse files
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import math
4
+ import spaces
5
+ from diffusers import StableDiffusionXLPipeline
6
+ from transformers import AutoFeatureExtractor
7
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
8
+ from ip_adapter.ip_adapter import EasyRef
9
+ from huggingface_hub import hf_hub_download
10
+ import gradio as gr
11
+ import cv2
12
+ import pillow_avif
13
+
14
+ def adaptive_resize(w, h, size=1024):
15
+ times = math.sqrt(h * w / (size**2))
16
+ if w==h:
17
+ w, h = size, size
18
+ elif times > 1.1:
19
+ w, h = math.ceil(w / times), math.ceil(h / times)
20
+ elif times < 0.8:
21
+ w, h = math.ceil(w / times), math.ceil(h / times)
22
+ new_w, new_h = 64 * (math.ceil(w / 64)), 64 * (math.ceil(h / 64))
23
+ return new_w, new_h
24
+
25
+ def res2string(w, h):
26
+ return str(w)+"x"+str(h)
27
+
28
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
29
+ multimodal_llm_path = "Qwen/Qwen2-VL-2B-Instruct"
30
+ ip_ckpt = hf_hub_download(repo_id="zongzhuofan/EasyRef", filename="pytorch_model.bin", repo_type="model")
31
+
32
+ safety_model_id = "CompVis/stable-diffusion-safety-checker"
33
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
34
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
35
+
36
+ device = "cuda"
37
+
38
+ pipe = StableDiffusionXLPipeline.from_pretrained(
39
+ base_model_path,
40
+ torch_dtype=torch.float16,
41
+ feature_extractor=safety_feature_extractor,
42
+ safety_checker=safety_checker,
43
+ add_watermarker=False,
44
+ ).to(device)
45
+
46
+ easyref = EasyRef(pipe, multimodal_llm_path, ip_ckpt, device, num_tokens=64, use_lora=True, lora_rank=128)
47
+
48
+ cv2.setNumThreads(1)
49
+
50
+ @spaces.GPU(enable_queue=True)
51
+ def generate_image(images, prompt, negative_prompt, height, width, scale, num_inference_steps, seed, progress=gr.Progress(track_tqdm=True)):
52
+ print("Generating")
53
+ new_width, new_height = adaptive_resize(width, height)
54
+ print("The output resolution is adaptively resized from {ori} to {new}".format(ori=res2string(width, height), new=res2string(new_width, new_height)))
55
+ template = "Visualize a scene that closely resembles the provided images, capturing the essence and details described in this prompt:\n"
56
+ system_prompt = [template + prompt, template]
57
+ image = easyref.generate(
58
+ pil_image=images,
59
+ system_prompt=system_prompt,
60
+ prompt=prompt,
61
+ negative_prompt=negative_prompt,
62
+ num_samples=1,
63
+ num_inference_steps=num_inference_steps,
64
+ seed=seed,
65
+ height=new_height,
66
+ width=new_width)
67
+ print(image)
68
+ return image
69
+
70
+ # def change_style(style):
71
+ # if style == "Photorealistic":
72
+ # return(gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0))
73
+ # else:
74
+ # return(gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8))
75
+
76
+ def swap_to_gallery(images):
77
+ return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
78
+
79
+ def remove_back_to_files():
80
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
81
+
82
+ MAX_SEED = np.iinfo(np.int32).max
83
+ css = '''
84
+ h1{margin-bottom: 0 !important}
85
+ '''
86
+ with gr.Blocks(css=css) as demo:
87
+ gr.Markdown("# EasyRef demo")
88
+ gr.Markdown("Demo for the [zongzhuofan/EasyRef model](https://huggingface.co/zongzhuofan/EasyRef)")
89
+ with gr.Row():
90
+ with gr.Column():
91
+ files = gr.Files(
92
+ label="Upload several reference images",
93
+ file_types=["image"]
94
+ )
95
+ uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)
96
+ with gr.Column(visible=False) as clear_button:
97
+ remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
98
+ prompt = gr.Textbox(label="Prompt",
99
+ placeholder="An oil painting of a [man/woman/person]...")
100
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="A collage of images, naked, monochrome, lowres, bad anatomy, worst quality, low quality")
101
+ # style = gr.Radio(label="Generation type", info="For stylized try prompts like 'a watercolor painting of a woman'", choices=["Photorealistic", "Stylized"], value="Photorealistic")
102
+ submit = gr.Button("Submit")
103
+ with gr.Accordion(open=False, label="Advanced Options"):
104
+ height = gr.Slider(label="Height", info="This will be adaptively resized", value=1024, step=64, minimum=896, maximum=1280)
105
+ width = gr.Slider(label="Width", info="This will be adaptively resized", value=1024, step=64, minimum=896, maximum=1280)
106
+ scale = gr.Slider(label="Scale", info="Scale for reference images", value=1.0, step=0.1, minimum=0.5, maximum=1.5)
107
+ num_inference_steps = gr.Slider(label="Number of inference steps", value=30, step=1, minimum=1, maximum=60)
108
+ seed = gr.Slider(label="Seed", value=24, step=1, minimum=0, maximum=MAX_SEED)
109
+ with gr.Column():
110
+ gallery = gr.Gallery(label="Generated Images")
111
+ # style.change(fn=change_style,
112
+ # inputs=style,
113
+ # outputs=[preserve, face_strength, likeness_strength])
114
+ files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
115
+ remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
116
+ submit.click(fn=generate_image,
117
+ inputs=[files,prompt,negative_prompt,height, width, scale, num_inference_steps, seed],
118
+ outputs=gallery)
119
+ examples = gr.Examples(
120
+ examples=[
121
+ [
122
+ [
123
+ "assets/aragaki_identity/1.jpg",
124
+ "assets/aragaki_identity/2.webp",
125
+ "assets/aragaki_identity/3.webp",
126
+ "assets/aragaki_identity/4.jpeg",
127
+ "assets/aragaki_identity/5.webp",
128
+ ],
129
+ "An oil painting of a smiling woman.",
130
+ ],
131
+ [
132
+ [
133
+ "assets/blindbox_style/1.jpg",
134
+ "assets/blindbox_style/2.jpg",
135
+ "assets/blindbox_style/3.jpg",
136
+ "assets/blindbox_style/4.jpg",
137
+ "assets/blindbox_style/5.jpg",
138
+ ],
139
+ "Donald Trump"
140
+ ],
141
+ ],
142
+ inputs=[files, prompt],
143
+ )
144
+
145
+ gr.Markdown("We release our checkpoints for research purposes only. Users are granted the freedom to create images using this tool, but they are expected to comply with local laws and utilize it in a responsible manner. The developers do not assume any responsibility for potential misuse by users.")
146
+
147
+ demo.launch()
assets/aragaki_identity/1.jpg ADDED
assets/aragaki_identity/2.webp ADDED
assets/aragaki_identity/3.webp ADDED
assets/aragaki_identity/4.jpeg ADDED
assets/aragaki_identity/5.webp ADDED
assets/blindbox_style/1.jpg ADDED
assets/blindbox_style/2.jpg ADDED
assets/blindbox_style/3.jpg ADDED
assets/blindbox_style/4.jpg ADDED
assets/blindbox_style/5.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.22.1
2
+ torch==2.3.1
3
+ torchvision==0.18.1
4
+ transformers==4.45.2
5
+ qwen_vl_utils==0.0.8
6
+ accelerate
7
+ safetensors
8
+ einops
9
+ omegaconf
10
+ peft
11
+ pillow-avif-plugin
12
+ huggingface-hub==0.23.3
13
+ opencv-python
14
+ git+https://github.com/TempleX98/EasyRef.git