vivaceailab committed on
Commit
9e1bace
·
verified ·
1 Parent(s): 2cc6477

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -110
app.py CHANGED
@@ -1,117 +1,148 @@
1
- import os
2
  import torch
3
- import random
4
- import importlib
5
- from PIL import Image
6
- from huggingface_hub import snapshot_download
7
  import gradio as gr
8
- from transformers import AutoProcessor, AutoModelForCausalLM, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
9
- from diffusers import StableDiffusionPipeline, DiffusionPipeline, EulerDiscreteScheduler, UNet2DConditionModel
10
 
11
- # 환경 설정
12
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
13
- REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"
14
-
15
- # 로컬 다운로드
16
- LOCAL_FLORENCE = snapshot_download("microsoft/Florence-2-base", revision=REVISION)
17
- LOCAL_TURBOX = snapshot_download("tensorart/stable-diffusion-3.5-large-TurboX")
18
-
19
- # 디바이스 및 dtype 설정
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
22
 
23
- # 모델 로딩 (부분별 로딩 + dtype 적용)
24
- scheduler = EulerDiscreteScheduler.from_pretrained(
25
- LOCAL_TURBOX, subfolder="scheduler", torch_dtype=dtype
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  )
27
- text_encoder = CLIPTextModel.from_pretrained(LOCAL_TURBOX, subfolder="text_encoder", torch_dtype=dtype)
28
- tokenizer = CLIPTokenizer.from_pretrained(LOCAL_TURBOX, subfolder="tokenizer")
29
- feature_extractor = CLIPFeatureExtractor.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="feature_extractor")
30
- unet = UNet2DConditionModel.from_pretrained(LOCAL_TURBOX, subfolder="unet", torch_dtype=dtype)
31
 
32
- florence_model = AutoModelForCausalLM.from_pretrained(
33
- LOCAL_FLORENCE, trust_remote_code=True, torch_dtype=dtype
34
- )
35
- florence_model.to("cpu").eval()
36
- florence_processor = AutoProcessor.from_pretrained(LOCAL_FLORENCE, trust_remote_code=True)
37
-
38
- # Stable Diffusion 파이프라인
39
- pipe = DiffusionPipeline.from_pretrained(
40
- LOCAL_TURBOX,
41
- torch_dtype=dtype,
42
- trust_remote_code=True,
43
- safety_checker=None,
44
- feature_extractor=None
45
- )
46
- pipe = pipe.to(device)
47
- pipe.scheduler = scheduler
48
- pipe.enable_attention_slicing() # 메모리 절약
49
-
50
- # 상수
51
- MAX_SEED = 2**31 - 1
52
-
53
- # 텍스트 스타일러
54
- def pseudo_translate_to_korean_style(en_prompt: str) -> str:
55
- return f"Cartoon styled {en_prompt} handsome or pretty people"
56
-
57
- # 프롬프트 생성
58
- def generate_prompt(image):
59
- if not isinstance(image, Image.Image):
60
- image = Image.fromarray(image)
61
-
62
- inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to("cpu")
63
- with torch.no_grad():
64
- generated_ids = florence_model.generate(
65
- input_ids=inputs["input_ids"],
66
- pixel_values=inputs["pixel_values"],
67
- max_new_tokens=256,
68
- num_beams=3
69
- )
70
- generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
71
- parsed_answer = florence_processor.post_process_generation(
72
- generated_text,
73
- task="<MORE_DETAILED_CAPTION>",
74
- image_size=(image.width, image.height)
75
- )
76
- prompt_en = parsed_answer["<MORE_DETAILED_CAPTION>"]
77
- cartoon_prompt = pseudo_translate_to_korean_style(prompt_en)
78
- return cartoon_prompt
79
-
80
- # 이미지 생성 함수
81
- def generate_image(prompt, seed=42, randomize_seed=False):
82
- if randomize_seed:
83
- seed = random.randint(0, MAX_SEED)
84
- generator = torch.Generator().manual_seed(seed)
85
- image = pipe(
86
- prompt=prompt,
87
- guidance_scale=1.5,
88
- num_inference_steps=6, # 최적화된 step
89
- width=512,
90
- height=512,
91
- generator=generator
92
- ).images[0]
93
- return image, seed
94
-
95
- # Gradio UI
96
- with gr.Blocks() as demo:
97
- gr.Markdown("# 🖼 이미지 설명 생성 → 카툰 이미지 자동 생성기")
98
- gr.Markdown("**📌 사용법 안내 (한국어)**\n"
99
- "- 이미지를 업로드하면 AI가 설명 → 스타일 변환 → 카툰 이미지 생성까지 자동으로 수행합니다.")
100
-
101
- with gr.Row():
102
- with gr.Column():
103
- input_img = gr.Image(label="🎨 원본 이미지 업로드")
104
- run_button = gr.Button("✨ 생성 시작")
105
-
106
- with gr.Column():
107
- prompt_out = gr.Textbox(label="📝 스타일 적용된 프롬프트", lines=3, show_copy_button=True)
108
- output_img = gr.Image(label="🎉 생성된 이미지")
109
-
110
- def full_process(img):
111
- prompt = generate_prompt(img)
112
- image, seed = generate_image(prompt, randomize_seed=True)
113
- return prompt, image
114
-
115
- run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])
116
-
117
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from diffusers import StableDiffusionPipeline
 
 
 
3
  import gradio as gr
 
 
4
 
5
+ # GPU 사용 가능 여부 확인
 
 
 
 
 
 
 
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
7
 
8
+ # 파이프라인 로딩
9
+ pipe = StableDiffusionPipeline.from_pretrained(
10
+ "runwayml/stable-diffusion-v1-5",
11
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
12
+ ).to(device)
13
+
14
+ # 생성 함수
15
+ def generate(prompt):
16
+ image = pipe(prompt).images[0]
17
+ return image
18
+
19
+ # Gradio 인터페이스 정의
20
+ interface = gr.Interface(
21
+ fn=generate,
22
+ inputs=gr.Textbox(label="프롬프트를 입력하세요", placeholder="예: a cute caricature of a cat in a hat"),
23
+ outputs=gr.Image(type="pil"),
24
+ title="Text to Image - Stable Diffusion",
25
+ description="Stable Diffusion을 사용한 텍스트-이미지 생성기입니다."
26
  )
 
 
 
 
27
 
28
+ if __name__ == "__main__":
29
+ interface.launch()
30
+
31
+
32
+ # import os
33
+ # import torch
34
+ # import random
35
+ # import importlib
36
+ # from PIL import Image
37
+ # from huggingface_hub import snapshot_download
38
+ # import gradio as gr
39
+ # from transformers import AutoProcessor, AutoModelForCausalLM, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
40
+ # from diffusers import StableDiffusionPipeline, DiffusionPipeline, EulerDiscreteScheduler, UNet2DConditionModel
41
+
42
+ # # 환경 설정
43
+ # os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
44
+ # REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"
45
+
46
+ # # 로컬 다운로드
47
+ # LOCAL_FLORENCE = snapshot_download("microsoft/Florence-2-base", revision=REVISION)
48
+ # LOCAL_TURBOX = snapshot_download("tensorart/stable-diffusion-3.5-large-TurboX")
49
+
50
+ # # 디바이스 dtype 설정
51
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
52
+ # dtype = torch.float16 if torch.cuda.is_available() else torch.float32
53
+
54
+ # # 모델 로딩 (부분별 로딩 + dtype 적용)
55
+ # scheduler = EulerDiscreteScheduler.from_pretrained(
56
+ # LOCAL_TURBOX, subfolder="scheduler", torch_dtype=dtype
57
+ # )
58
+ # text_encoder = CLIPTextModel.from_pretrained(LOCAL_TURBOX, subfolder="text_encoder", torch_dtype=dtype)
59
+ # tokenizer = CLIPTokenizer.from_pretrained(LOCAL_TURBOX, subfolder="tokenizer")
60
+ # feature_extractor = CLIPFeatureExtractor.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="feature_extractor")
61
+ # unet = UNet2DConditionModel.from_pretrained(LOCAL_TURBOX, subfolder="unet", torch_dtype=dtype)
62
+
63
+ # florence_model = AutoModelForCausalLM.from_pretrained(
64
+ # LOCAL_FLORENCE, trust_remote_code=True, torch_dtype=dtype
65
+ # )
66
+ # florence_model.to("cpu").eval()
67
+ # florence_processor = AutoProcessor.from_pretrained(LOCAL_FLORENCE, trust_remote_code=True)
68
+
69
+ # # Stable Diffusion 파이프라인
70
+ # pipe = DiffusionPipeline.from_pretrained(
71
+ # LOCAL_TURBOX,
72
+ # torch_dtype=dtype,
73
+ # trust_remote_code=True,
74
+ # safety_checker=None,
75
+ # feature_extractor=None
76
+ # )
77
+ # pipe = pipe.to(device)
78
+ # pipe.scheduler = scheduler
79
+ # pipe.enable_attention_slicing() # 메모리 절약
80
+
81
+ # # 상수
82
+ # MAX_SEED = 2**31 - 1
83
+
84
+ # # 텍스트 스타일러
85
+ # def pseudo_translate_to_korean_style(en_prompt: str) -> str:
86
+ # return f"Cartoon styled {en_prompt} handsome or pretty people"
87
+
88
+ # # 프롬프트 생성
89
+ # def generate_prompt(image):
90
+ # if not isinstance(image, Image.Image):
91
+ # image = Image.fromarray(image)
92
+
93
+ # inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to("cpu")
94
+ # with torch.no_grad():
95
+ # generated_ids = florence_model.generate(
96
+ # input_ids=inputs["input_ids"],
97
+ # pixel_values=inputs["pixel_values"],
98
+ # max_new_tokens=256,
99
+ # num_beams=3
100
+ # )
101
+ # generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
102
+ # parsed_answer = florence_processor.post_process_generation(
103
+ # generated_text,
104
+ # task="<MORE_DETAILED_CAPTION>",
105
+ # image_size=(image.width, image.height)
106
+ # )
107
+ # prompt_en = parsed_answer["<MORE_DETAILED_CAPTION>"]
108
+ # cartoon_prompt = pseudo_translate_to_korean_style(prompt_en)
109
+ # return cartoon_prompt
110
+
111
+ # # 이미지 생성 함수
112
+ # def generate_image(prompt, seed=42, randomize_seed=False):
113
+ # if randomize_seed:
114
+ # seed = random.randint(0, MAX_SEED)
115
+ # generator = torch.Generator().manual_seed(seed)
116
+ # image = pipe(
117
+ # prompt=prompt,
118
+ # guidance_scale=1.5,
119
+ # num_inference_steps=6, # 최적화된 step 수
120
+ # width=512,
121
+ # height=512,
122
+ # generator=generator
123
+ # ).images[0]
124
+ # return image, seed
125
+
126
+ # # Gradio UI
127
+ # with gr.Blocks() as demo:
128
+ # gr.Markdown("# 🖼 이미지 → 설명 생성 → 카툰 이미지 자동 생성기")
129
+ # gr.Markdown("**📌 사용법 안내 (한국어)**\n"
130
+ # "- 이미지를 업로드하면 AI가 설명 → 스타일 변환 → 카툰 이미지 생성까지 자동으로 수행합니다.")
131
+
132
+ # with gr.Row():
133
+ # with gr.Column():
134
+ # input_img = gr.Image(label="🎨 원본 이미지 업로드")
135
+ # run_button = gr.Button("✨ 생성 시작")
136
+
137
+ # with gr.Column():
138
+ # prompt_out = gr.Textbox(label="📝 스타일 적용된 프롬프트", lines=3, show_copy_button=True)
139
+ # output_img = gr.Image(label="🎉 생성된 이미지")
140
+
141
+ # def full_process(img):
142
+ # prompt = generate_prompt(img)
143
+ # image, seed = generate_image(prompt, randomize_seed=True)
144
+ # return prompt, image
145
+
146
+ # run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])
147
+
148
+ # demo.launch()