vivaceailab committed on
Commit 2cc6477 · verified · 1 Parent(s): ad759df

Update app.py

Files changed (1):
  app.py +74 -93
app.py CHANGED
@@ -1,131 +1,112 @@
  import os
  from huggingface_hub import snapshot_download
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

- # Model revision and local save paths
  REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"
- LOCAL_FLORENCE = snapshot_download(
-     repo_id="microsoft/Florence-2-base",
-     revision=REVISION
- )
- LOCAL_TURBOX = snapshot_download(
-     repo_id="tensorart/stable-diffusion-3.5-large-TurboX"
- )

- import sys, types, importlib.machinery, importlib
- # Stub out flash_attn so importing it becomes a no-op
- spec = importlib.machinery.ModuleSpec('flash_attn', loader=None)
- mod = types.ModuleType('flash_attn')
- mod.__spec__ = spec
- sys.modules['flash_attn'] = mod

- import gradio as gr
- import torch
- import random
- from PIL import Image
- from transformers import AutoProcessor, AutoModelForCausalLM
- from transformers import (CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor)
- from diffusers import DiffusionPipeline, EulerDiscreteScheduler

- # 1. Lightweight options: apply FP16 + 8-bit quantization
- model_repo = "tensorart/stable-diffusion-3.5-large-TurboX"
- pipe = DiffusionPipeline.from_pretrained(
-     model_repo,
-     torch_dtype=torch.float16,
-     load_in_8bit=True,
-     device_map="auto",
-     safety_checker=None,
-     feature_extractor=None
- )
- pipe = pipe.to("cuda")  # use the GPU
- # Load the scheduler (Euler)
- pipe.scheduler = EulerDiscreteScheduler.from_pretrained(
-     model_repo, subfolder="scheduler", local_files_only=True
  )

- # Florence model setup (load on CPU, move to GPU when needed)
  florence_model = AutoModelForCausalLM.from_pretrained(
-     LOCAL_FLORENCE,
-     trust_remote_code=True,
-     torch_dtype=torch.float16,
-     load_in_8bit=True  # load the text-generation model in 8-bit as well
  )
- florence_model = florence_model.to("cpu")
- florence_model.eval()
- florence_processor = AutoProcessor.from_pretrained(
-     LOCAL_FLORENCE,
-     trust_remote_code=True
  )

- # Maximum seed value
  MAX_SEED = 2**31 - 1

- def preprocess_image(input_image: Image.Image) -> Image.Image:
-     """Resize the input image to 512x512."""
-     img = input_image.convert("RGB")
-     img = img.resize((512, 512), resample=Image.LANCZOS)  # high-quality resampling
-     return img
-
- # Turn the English caption into a cartoon-style Korean prompt
  def pseudo_translate_to_korean_style(en_prompt: str) -> str:
-     return f"Cartoon style: {en_prompt} beautiful person"

- # Image → detailed caption → cartoon prompt
  def generate_prompt(image):
-     img = preprocess_image(image)
-     inputs = florence_processor(
-         text="<MORE_DETAILED_CAPTION>",
-         images=img,
-         return_tensors="pt"
-     ).to(pipe.device)
-     generated_ids = florence_model.generate(
-         input_ids=inputs["input_ids"],
-         pixel_values=inputs["pixel_values"],
-         max_new_tokens=256,  # cap the number of new tokens
-         num_beams=2          # fewer beams
      )
-     generated_text = florence_processor.batch_decode(
-         generated_ids, skip_special_tokens=True
-     )[0]
-     return pseudo_translate_to_korean_style(generated_text)

- # Prompt → image generation
  def generate_image(prompt, seed=42, randomize_seed=False):
-     # Optional seed randomization
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)
-     generator = torch.Generator(device="cuda").manual_seed(seed)
-     # Keep the resolution at 512x512 to save memory
      image = pipe(
          prompt=prompt,
-         guidance_scale=1.2,
-         num_inference_steps=8,
          width=512,
          height=512,
          generator=generator
      ).images[0]
      return image, seed

- # Build the Gradio UI
  with gr.Blocks() as demo:
-     gr.Markdown("# 🖼 Image → Caption → Cartoon Image Generator")
-     gr.Markdown(
-         "**How to use**\n"
-         "1. Upload an image on the left.\n"
-         "2. The AI writes a detailed caption and converts it into a cartoon-style Korean prompt.\n"
-         "3. Check the generated cartoon image on the right."
-     )
      with gr.Row():
          with gr.Column():
-             input_img = gr.Image(type="pil", label="Upload source image")  # type: PIL image
-             run_button = gr.Button("Generate")
          with gr.Column():
-             prompt_out = gr.Textbox(label="Generated prompt", lines=2, show_copy_button=True)
-             output_img = gr.Image(label="Generated cartoon image")

-     # Run the whole process when the button is clicked
      def full_process(img):
          prompt = generate_prompt(img)
          image, seed = generate_image(prompt, randomize_seed=True)
@@ -133,4 +114,4 @@ with gr.Blocks() as demo:

      run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])

- demo.launch(share=True)  # use share=True when deploying to Hugging Face Spaces

  import os
+ import torch
+ import random
+ import importlib
+ from PIL import Image
  from huggingface_hub import snapshot_download
+ import gradio as gr
+ from transformers import AutoProcessor, AutoModelForCausalLM, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
+ from diffusers import StableDiffusionPipeline, DiffusionPipeline, EulerDiscreteScheduler, UNet2DConditionModel

+ # Environment setup
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
  REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"

+ # Local downloads
+ LOCAL_FLORENCE = snapshot_download("microsoft/Florence-2-base", revision=REVISION)
+ LOCAL_TURBOX = snapshot_download("tensorart/stable-diffusion-3.5-large-TurboX")

+ # Device and dtype selection
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32

+ # Model loading (per-component loading with the chosen dtype)
+ scheduler = EulerDiscreteScheduler.from_pretrained(
+     LOCAL_TURBOX, subfolder="scheduler", torch_dtype=dtype
  )
+ text_encoder = CLIPTextModel.from_pretrained(LOCAL_TURBOX, subfolder="text_encoder", torch_dtype=dtype)
+ tokenizer = CLIPTokenizer.from_pretrained(LOCAL_TURBOX, subfolder="tokenizer")
+ feature_extractor = CLIPFeatureExtractor.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="feature_extractor")
+ unet = UNet2DConditionModel.from_pretrained(LOCAL_TURBOX, subfolder="unet", torch_dtype=dtype)

  florence_model = AutoModelForCausalLM.from_pretrained(
+     LOCAL_FLORENCE, trust_remote_code=True, torch_dtype=dtype
  )
+ florence_model.to("cpu").eval()
+ florence_processor = AutoProcessor.from_pretrained(LOCAL_FLORENCE, trust_remote_code=True)
+
+ # Stable Diffusion pipeline
+ pipe = DiffusionPipeline.from_pretrained(
+     LOCAL_TURBOX,
+     torch_dtype=dtype,
+     trust_remote_code=True,
+     safety_checker=None,
+     feature_extractor=None
  )
+ pipe = pipe.to(device)
+ pipe.scheduler = scheduler
+ pipe.enable_attention_slicing()  # save memory

+ # Constants
  MAX_SEED = 2**31 - 1

+ # Text styler
  def pseudo_translate_to_korean_style(en_prompt: str) -> str:
+     return f"Cartoon styled {en_prompt} handsome or pretty people"

+ # Prompt generation
  def generate_prompt(image):
+     if not isinstance(image, Image.Image):
+         image = Image.fromarray(image)
+
+     inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to("cpu")
+     with torch.no_grad():
+         generated_ids = florence_model.generate(
+             input_ids=inputs["input_ids"],
+             pixel_values=inputs["pixel_values"],
+             max_new_tokens=256,
+             num_beams=3
+         )
+     generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+     parsed_answer = florence_processor.post_process_generation(
+         generated_text,
+         task="<MORE_DETAILED_CAPTION>",
+         image_size=(image.width, image.height)
      )
+     prompt_en = parsed_answer["<MORE_DETAILED_CAPTION>"]
+     cartoon_prompt = pseudo_translate_to_korean_style(prompt_en)
+     return cartoon_prompt

+ # Image generation function
  def generate_image(prompt, seed=42, randomize_seed=False):
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator().manual_seed(seed)
      image = pipe(
          prompt=prompt,
+         guidance_scale=1.5,
+         num_inference_steps=6,  # optimized step count
          width=512,
          height=512,
          generator=generator
      ).images[0]
      return image, seed

+ # Gradio UI
  with gr.Blocks() as demo:
+     gr.Markdown("# 🖼 Image → Caption → Cartoon Image Generator")
+     gr.Markdown("**📌 How to use**\n"
+                 "- Upload an image and the AI automatically runs captioning → style conversion → cartoon image generation.")
+
      with gr.Row():
          with gr.Column():
+             input_img = gr.Image(label="🎨 Upload source image")
+             run_button = gr.Button("✨ Generate")
+
          with gr.Column():
+             prompt_out = gr.Textbox(label="📝 Styled prompt", lines=3, show_copy_button=True)
+             output_img = gr.Image(label="🎉 Generated image")

      def full_process(img):
          prompt = generate_prompt(img)
          image, seed = generate_image(prompt, randomize_seed=True)

      run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])

+ demo.launch()
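
A note on the updated file: text_encoder, tokenizer, unet, and feature_extractor are loaded but never passed to DiffusionPipeline.from_pretrained, so the pipeline loads its own copies, and generate_image seeds a CPU torch.Generator even when the pipeline runs on CUDA. The sketch below is illustrative only, not part of the commit: it assumes the app.py above has already executed (so pipe, device, and generate_prompt are in scope) and uses "sample.jpg" as a placeholder input path, with the generator pinned to the pipeline's device.

# Illustrative smoke test of the new flow (assumes app.py ran; "sample.jpg" is a placeholder)
from PIL import Image
import torch

img = Image.open("sample.jpg").convert("RGB")
prompt = generate_prompt(img)  # Florence-2 caption -> cartoon-styled prompt
print("prompt:", prompt)

# Pin the generator to the pipeline's device so seeding behaves consistently on CPU and GPU
generator = torch.Generator(device=device).manual_seed(123)
image = pipe(
    prompt=prompt,
    guidance_scale=1.5,
    num_inference_steps=6,
    width=512,
    height=512,
    generator=generator,
).images[0]
image.save("cartoon_123.png")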