Update app.py
app.py
CHANGED
@@ -1,131 +1,112 @@
import os
+import torch
+import random
+import importlib
+from PIL import Image
from huggingface_hub import snapshot_download
-
+import gradio as gr
+from transformers import AutoProcessor, AutoModelForCausalLM, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
+from diffusers import StableDiffusionPipeline, DiffusionPipeline, EulerDiscreteScheduler, UNet2DConditionModel

-#
+# Environment settings
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"
-LOCAL_FLORENCE = snapshot_download(
-    repo_id="microsoft/Florence-2-base",
-    revision=REVISION
-)
-LOCAL_TURBOX = snapshot_download(
-    repo_id="tensorart/stable-diffusion-3.5-large-TurboX"
-)

-
-
-
-mod = types.ModuleType('flash_attn')
-mod.__spec__ = spec
-sys.modules['flash_attn'] = mod
+# Download the models to the local cache
+LOCAL_FLORENCE = snapshot_download("microsoft/Florence-2-base", revision=REVISION)
+LOCAL_TURBOX = snapshot_download("tensorart/stable-diffusion-3.5-large-TurboX")

-
-
-
-from PIL import Image
-from transformers import AutoProcessor, AutoModelForCausalLM
-from transformers import (CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor)
-from diffusers import DiffusionPipeline, EulerDiscreteScheduler
+# Device and dtype settings
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16 if torch.cuda.is_available() else torch.float32

-#
-
-
-    model_repo,
-    torch_dtype=torch.float16,
-    load_in_8bit=True,
-    device_map="auto",
-    safety_checker=None,
-    feature_extractor=None
-)
-pipe = pipe.to("cuda")  # use the GPU
-# load the scheduler (Euler)
-pipe.scheduler = EulerDiscreteScheduler.from_pretrained(
-    model_repo, subfolder="scheduler", local_files_only=True
+# Model loading (load components individually, with the dtype applied)
+scheduler = EulerDiscreteScheduler.from_pretrained(
+    LOCAL_TURBOX, subfolder="scheduler", torch_dtype=dtype
)
+text_encoder = CLIPTextModel.from_pretrained(LOCAL_TURBOX, subfolder="text_encoder", torch_dtype=dtype)
+tokenizer = CLIPTokenizer.from_pretrained(LOCAL_TURBOX, subfolder="tokenizer")
+feature_extractor = CLIPFeatureExtractor.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="feature_extractor")
+unet = UNet2DConditionModel.from_pretrained(LOCAL_TURBOX, subfolder="unet", torch_dtype=dtype)

-# Florence model setup (load on CPU, move to GPU when needed)
florence_model = AutoModelForCausalLM.from_pretrained(
-    LOCAL_FLORENCE,
-    trust_remote_code=True,
-    torch_dtype=torch.float16,
-    load_in_8bit=True  # load the text-generation model in 8-bit as well
+    LOCAL_FLORENCE, trust_remote_code=True, torch_dtype=dtype
)
-florence_model
-
-
-
-
+florence_model.to("cpu").eval()
+florence_processor = AutoProcessor.from_pretrained(LOCAL_FLORENCE, trust_remote_code=True)
+
+# Stable Diffusion pipeline
+pipe = DiffusionPipeline.from_pretrained(
+    LOCAL_TURBOX,
+    torch_dtype=dtype,
+    trust_remote_code=True,
+    safety_checker=None,
+    feature_extractor=None
)
+pipe = pipe.to(device)
+pipe.scheduler = scheduler
+pipe.enable_attention_slicing()  # save memory

-#
+# Constants
MAX_SEED = 2**31 - 1

-
-    """
-    Resize the input image to 512x512
-    """
-    img = input_image.convert("RGB")
-    img = img.resize((512, 512), resample=Image.LANCZOS)  # high-quality resize
-    return img
-
-# Convert the English caption into a cartoon-style Korean prompt
+# Text styler
def pseudo_translate_to_korean_style(en_prompt: str) -> str:
-    return f"
+    return f"Cartoon styled {en_prompt} handsome or pretty people"

-#
+# Prompt generation
def generate_prompt(image):
-
-
-
-
-
-
-
-
-
-
-
+    if not isinstance(image, Image.Image):
+        image = Image.fromarray(image)
+
+    inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to("cpu")
+    with torch.no_grad():
+        generated_ids = florence_model.generate(
+            input_ids=inputs["input_ids"],
+            pixel_values=inputs["pixel_values"],
+            max_new_tokens=256,
+            num_beams=3
+        )
+    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+    parsed_answer = florence_processor.post_process_generation(
+        generated_text,
+        task="<MORE_DETAILED_CAPTION>",
+        image_size=(image.width, image.height)
    )
-
-
-
-    return pseudo_translate_to_korean_style(generated_text)
+    prompt_en = parsed_answer["<MORE_DETAILED_CAPTION>"]
+    cartoon_prompt = pseudo_translate_to_korean_style(prompt_en)
+    return cartoon_prompt

-#
+# Image generation function
def generate_image(prompt, seed=42, randomize_seed=False):
-    # option to randomize the seed
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator(
-    # keep the resolution at 512x512 to save memory
+    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
-        guidance_scale=1.
-        num_inference_steps=
+        guidance_scale=1.5,
+        num_inference_steps=6,  # optimized number of steps
        width=512,
        height=512,
        generator=generator
    ).images[0]
    return image, seed

-# Gradio UI
+# Gradio UI
with gr.Blocks() as demo:
-    gr.Markdown("# 🖼 Image → Caption → Cartoon Image Generator")
-    gr.Markdown(
-
-
-        "2. The AI writes a detailed caption and turns it into a cartoon-style Korean prompt.\n"
-        "3. Check the generated cartoon image on the right."
-    )
+    gr.Markdown("# 🖼 Image → Caption → Automatic Cartoon Image Generator")
+    gr.Markdown("**📌 Usage guide**\n"
+                "- Upload an image and the AI will automatically run captioning → style conversion → cartoon image generation.")
+
    with gr.Row():
        with gr.Column():
-            input_img = gr.Image(
-            run_button = gr.Button("Start generation")
+            input_img = gr.Image(label="🎨 Upload the source image")
+            run_button = gr.Button("✨ Start generation")
+
        with gr.Column():
-            prompt_out = gr.Textbox(label="
-            output_img = gr.Image(label="Generated
+            prompt_out = gr.Textbox(label="📝 Styled prompt", lines=3, show_copy_button=True)
+            output_img = gr.Image(label="🎉 Generated image")

-    # run the whole process when the button is clicked
    def full_process(img):
        prompt = generate_prompt(img)
        image, seed = generate_image(prompt, randomize_seed=True)
@@ -133,4 +114,4 @@ with gr.Blocks() as demo:

    run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])

-demo.launch(
+demo.launch()
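
Review note: the updated app.py still calls demo.launch() at import time, so the pipeline can only be exercised through the Gradio UI. Below is a minimal headless smoke-test sketch for the two core functions; it assumes the launch call is moved behind an `if __name__ == "__main__":` guard (not part of this diff), and `test.jpg` is a placeholder input path.

# Hypothetical smoke test for the updated Space (not part of the diff).
# Assumes app.py guards demo.launch() behind `if __name__ == "__main__":`
# so that importing the module loads the models without blocking.
from PIL import Image
from app import generate_prompt, generate_image  # import triggers model loading

img = Image.open("test.jpg")   # placeholder input path
prompt = generate_prompt(img)  # Florence-2 caption -> cartoon-styled prompt
print("prompt:", prompt)

image, seed = generate_image(prompt, randomize_seed=True)
print("seed:", seed)
image.save("out.png")          # 512x512 render from the TurboX pipeline

Running this once on the Space's hardware would surface model-download, dtype, and device-placement errors before a user ever hits the UI.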