ginipick commited on
Commit
2e394d9
ยท
verified ยท
1 Parent(s): 39b272a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -56
app.py CHANGED
@@ -9,59 +9,58 @@ from diffusers import DiffusionPipeline
9
  from custom_pipeline import FLUXPipelineWithIntermediateOutputs
10
  from transformers import pipeline
11
 
12
- # Translation model loading with device specification
13
  translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cpu")
14
 
15
- # Constants
16
  MAX_SEED = np.iinfo(np.int32).max
17
  MAX_IMAGE_SIZE = 2048
18
  DEFAULT_WIDTH = 1024
19
  DEFAULT_HEIGHT = 1024
20
  DEFAULT_INFERENCE_STEPS = 1
21
- GPU_DURATION = 15 # Reduced from 25 to stay within quota
22
 
23
- # Device and model setup with memory optimization
24
  def setup_model():
25
  dtype = torch.float16
26
  pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(
27
  "black-forest-labs/FLUX.1-schnell",
28
- torch_dtype=dtype,
29
- device_map="auto" # Enable model parallelism
30
- )
31
  return pipe
32
 
33
  pipe = setup_model()
34
 
35
- # Menu labels dictionary
36
- english_labels = {
37
- "Generated Image": "Generated Image",
38
- "Prompt": "Prompt",
39
- "Enhance Image": "Enhance Image",
40
- "Advanced Options": "Advanced Options",
41
- "Seed": "Seed",
42
- "Randomize Seed": "Randomize Seed",
43
- "Width": "Width",
44
- "Height": "Height",
45
- "Inference Steps": "Inference Steps",
46
- "Inspiration Gallery": "Inspiration Gallery"
47
  }
48
 
49
  def translate_if_korean(text):
50
- """Safely translate Korean text to English."""
51
  try:
52
  if any('\u3131' <= char <= '\u3163' or '\uac00' <= char <= '\ud7a3' for char in text):
53
  return translator(text)[0]['translation_text']
54
  return text
55
  except Exception as e:
56
- print(f"Translation error: {e}")
57
  return text
58
 
59
- # Modified inference function with error handling and memory management
60
  @spaces.GPU(duration=GPU_DURATION)
61
  def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT,
62
  randomize_seed=True, num_inference_steps=DEFAULT_INFERENCE_STEPS):
63
  try:
64
- # Input validation
65
  if not isinstance(seed, (int, type(None))):
66
  seed = None
67
  randomize_seed = True
@@ -71,7 +70,7 @@ def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT
71
  if seed is None or randomize_seed:
72
  seed = random.randint(0, MAX_SEED)
73
 
74
- # Ensure valid dimensions
75
  width = min(max(256, width), MAX_IMAGE_SIZE)
76
  height = min(max(256, height), MAX_IMAGE_SIZE)
77
 
@@ -79,7 +78,7 @@ def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT
79
 
80
  start_time = time.time()
81
 
82
- with torch.cuda.amp.autocast(): # Enable automatic mixed precision
83
  for img in pipe.generate_images(
84
  prompt=prompt,
85
  guidance_scale=0,
@@ -88,26 +87,25 @@ def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT
88
  height=height,
89
  generator=generator
90
  ):
91
- latency = f"Processing Time: {(time.time()-start_time):.2f} seconds"
92
 
93
- # Clear CUDA cache after generation
94
  if torch.cuda.is_available():
95
  torch.cuda.empty_cache()
96
 
97
  yield img, seed, latency
98
 
99
  except Exception as e:
100
- print(f"Error in generate_image: {e}")
101
- # Return a blank image or error message
102
- yield None, seed, f"Error: {str(e)}"
103
 
104
- # Example generator with error handling
105
  def generate_example_image(prompt):
106
  try:
107
  return next(generate_image(prompt, randomize_seed=True))
108
  except Exception as e:
109
- print(f"Error in example generation: {e}")
110
- return None, None, f"Error: {str(e)}"
111
 
112
  # Example prompts
113
  examples = [
@@ -119,63 +117,63 @@ examples = [
119
  "A cosmic coffee shop where baristas are constellations serving drinks made of stardust"
120
  ]
121
 
 
122
  css = """
123
  footer {
124
  visibility: hidden;
125
  }
126
  """
127
 
128
- # --- Gradio UI with improved error handling ---
129
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
130
  with gr.Column(elem_id="app-container"):
131
  with gr.Row():
132
  with gr.Column(scale=3):
133
- result = gr.Image(label=english_labels["Generated Image"],
134
  show_label=False,
135
  interactive=False)
136
  with gr.Column(scale=1):
137
  prompt = gr.Text(
138
- label=english_labels["Prompt"],
139
- placeholder="Describe the image you want to generate...",
140
  lines=3,
141
  show_label=False,
142
  container=False,
143
  )
144
- enhanceBtn = gr.Button(f"๐Ÿš€ {english_labels['Enhance Image']}")
145
 
146
- with gr.Column(english_labels["Advanced Options"]):
147
  with gr.Row():
148
  latency = gr.Text(show_label=False)
149
  with gr.Row():
150
- # Modified Number component with proper validation
151
  seed = gr.Number(
152
- label=english_labels["Seed"],
153
  value=42,
154
  precision=0,
155
  minimum=0,
156
  maximum=MAX_SEED
157
  )
158
  randomize_seed = gr.Checkbox(
159
- label=english_labels["Randomize Seed"],
160
  value=True
161
  )
162
  with gr.Row():
163
  width = gr.Slider(
164
- label=english_labels["Width"],
165
  minimum=256,
166
  maximum=MAX_IMAGE_SIZE,
167
  step=32,
168
  value=DEFAULT_WIDTH
169
  )
170
  height = gr.Slider(
171
- label=english_labels["Height"],
172
  minimum=256,
173
  maximum=MAX_IMAGE_SIZE,
174
  step=32,
175
  value=DEFAULT_HEIGHT
176
  )
177
  num_inference_steps = gr.Slider(
178
- label=english_labels["Inference Steps"],
179
  minimum=1,
180
  maximum=4,
181
  step=1,
@@ -183,7 +181,7 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
183
  )
184
 
185
  with gr.Row():
186
- gr.Markdown(f"### ๐ŸŒŸ {english_labels['Inspiration Gallery']}")
187
  with gr.Row():
188
  gr.Examples(
189
  examples=examples,
@@ -193,7 +191,14 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
193
  cache_examples=False
194
  )
195
 
196
- # Event handling with improved error handling
 
 
 
 
 
 
 
197
  enhanceBtn.click(
198
  fn=generate_image,
199
  inputs=[prompt, seed, width, height],
@@ -203,14 +208,6 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
203
  queue=False
204
  )
205
 
206
- # Modified event handler with proper input validation
207
- def validated_generate(*args):
208
- try:
209
- return next(generate_image(*args))
210
- except Exception as e:
211
- print(f"Error in validated_generate: {e}")
212
- return None, args[1], f"Error: {str(e)}"
213
-
214
  gr.on(
215
  triggers=[prompt.input, width.input, height.input, num_inference_steps.input],
216
  fn=validated_generate,
@@ -223,4 +220,5 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
223
  )
224
 
225
  if __name__ == "__main__":
226
- demo.launch()
 
 
9
  from custom_pipeline import FLUXPipelineWithIntermediateOutputs
10
  from transformers import pipeline
11
 
12
+ # ๋ฒˆ์—ญ ๋ชจ๋ธ ์„ค์ • (CPU ์‚ฌ์šฉ)
13
  translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cpu")
14
 
15
+ # ์ƒ์ˆ˜ ์ •์˜
16
  MAX_SEED = np.iinfo(np.int32).max
17
  MAX_IMAGE_SIZE = 2048
18
  DEFAULT_WIDTH = 1024
19
  DEFAULT_HEIGHT = 1024
20
  DEFAULT_INFERENCE_STEPS = 1
21
+ GPU_DURATION = 15 # GPU ํ• ๋‹น ์‹œ๊ฐ„ ์ถ•์†Œ
22
 
23
+ # ๋ชจ๋ธ ์„ค์ •
24
  def setup_model():
25
  dtype = torch.float16
26
  pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(
27
  "black-forest-labs/FLUX.1-schnell",
28
+ torch_dtype=dtype
29
+ ).to("cuda")
 
30
  return pipe
31
 
32
  pipe = setup_model()
33
 
34
+ # ๋ฉ”๋‰ด ๋ ˆ์ด๋ธ”
35
+ labels = {
36
+ "Generated Image": "์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€",
37
+ "Prompt": "ํ”„๋กฌํ”„ํŠธ",
38
+ "Enhance Image": "์ด๋ฏธ์ง€ ํ–ฅ์ƒ",
39
+ "Advanced Options": "๊ณ ๊ธ‰ ์„ค์ •",
40
+ "Seed": "์‹œ๋“œ",
41
+ "Randomize Seed": "๋žœ๋ค ์‹œ๋“œ",
42
+ "Width": "๋„ˆ๋น„",
43
+ "Height": "๋†’์ด",
44
+ "Inference Steps": "์ถ”๋ก  ๋‹จ๊ณ„",
45
+ "Inspiration Gallery": "์˜๊ฐ ๊ฐค๋Ÿฌ๋ฆฌ"
46
  }
47
 
48
  def translate_if_korean(text):
49
+ """ํ•œ๊ธ€ ํ…์ŠคํŠธ๋ฅผ ์˜์–ด๋กœ ์•ˆ์ „ํ•˜๊ฒŒ ๋ฒˆ์—ญ"""
50
  try:
51
  if any('\u3131' <= char <= '\u3163' or '\uac00' <= char <= '\ud7a3' for char in text):
52
  return translator(text)[0]['translation_text']
53
  return text
54
  except Exception as e:
55
+ print(f"๋ฒˆ์—ญ ์˜ค๋ฅ˜: {e}")
56
  return text
57
 
58
+ # ์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ•จ์ˆ˜
59
  @spaces.GPU(duration=GPU_DURATION)
60
  def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT,
61
  randomize_seed=True, num_inference_steps=DEFAULT_INFERENCE_STEPS):
62
  try:
63
+ # ์ž…๋ ฅ๊ฐ’ ๊ฒ€์ฆ
64
  if not isinstance(seed, (int, type(None))):
65
  seed = None
66
  randomize_seed = True
 
70
  if seed is None or randomize_seed:
71
  seed = random.randint(0, MAX_SEED)
72
 
73
+ # ํฌ๊ธฐ ์œ ํšจ์„ฑ ๊ฒ€์‚ฌ
74
  width = min(max(256, width), MAX_IMAGE_SIZE)
75
  height = min(max(256, height), MAX_IMAGE_SIZE)
76
 
 
78
 
79
  start_time = time.time()
80
 
81
+ with torch.cuda.amp.autocast():
82
  for img in pipe.generate_images(
83
  prompt=prompt,
84
  guidance_scale=0,
 
87
  height=height,
88
  generator=generator
89
  ):
90
+ latency = f"์ฒ˜๋ฆฌ ์‹œ๊ฐ„: {(time.time()-start_time):.2f} ์ดˆ"
91
 
92
+ # CUDA ์บ์‹œ ์ •๋ฆฌ
93
  if torch.cuda.is_available():
94
  torch.cuda.empty_cache()
95
 
96
  yield img, seed, latency
97
 
98
  except Exception as e:
99
+ print(f"์ด๋ฏธ์ง€ ์ƒ์„ฑ ์˜ค๋ฅ˜: {e}")
100
+ yield None, seed, f"์˜ค๋ฅ˜: {str(e)}"
 
101
 
102
+ # ์˜ˆ์ œ ์ด๋ฏธ์ง€ ์ƒ์„ฑ
103
  def generate_example_image(prompt):
104
  try:
105
  return next(generate_image(prompt, randomize_seed=True))
106
  except Exception as e:
107
+ print(f"์˜ˆ์ œ ์ƒ์„ฑ ์˜ค๋ฅ˜: {e}")
108
+ return None, None, f"์˜ค๋ฅ˜: {str(e)}"
109
 
110
  # Example prompts
111
  examples = [
 
117
  "A cosmic coffee shop where baristas are constellations serving drinks made of stardust"
118
  ]
119
 
120
+
121
  css = """
122
  footer {
123
  visibility: hidden;
124
  }
125
  """
126
 
127
+ # Gradio UI ๊ตฌ์„ฑ
128
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
129
  with gr.Column(elem_id="app-container"):
130
  with gr.Row():
131
  with gr.Column(scale=3):
132
+ result = gr.Image(label=labels["Generated Image"],
133
  show_label=False,
134
  interactive=False)
135
  with gr.Column(scale=1):
136
  prompt = gr.Text(
137
+ label=labels["Prompt"],
138
+ placeholder="์ƒ์„ฑํ•˜๊ณ  ์‹ถ์€ ์ด๋ฏธ์ง€๋ฅผ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”...",
139
  lines=3,
140
  show_label=False,
141
  container=False,
142
  )
143
+ enhanceBtn = gr.Button(f"๐Ÿš€ {labels['Enhance Image']}")
144
 
145
+ with gr.Column(labels["Advanced Options"]):
146
  with gr.Row():
147
  latency = gr.Text(show_label=False)
148
  with gr.Row():
 
149
  seed = gr.Number(
150
+ label=labels["Seed"],
151
  value=42,
152
  precision=0,
153
  minimum=0,
154
  maximum=MAX_SEED
155
  )
156
  randomize_seed = gr.Checkbox(
157
+ label=labels["Randomize Seed"],
158
  value=True
159
  )
160
  with gr.Row():
161
  width = gr.Slider(
162
+ label=labels["Width"],
163
  minimum=256,
164
  maximum=MAX_IMAGE_SIZE,
165
  step=32,
166
  value=DEFAULT_WIDTH
167
  )
168
  height = gr.Slider(
169
+ label=labels["Height"],
170
  minimum=256,
171
  maximum=MAX_IMAGE_SIZE,
172
  step=32,
173
  value=DEFAULT_HEIGHT
174
  )
175
  num_inference_steps = gr.Slider(
176
+ label=labels["Inference Steps"],
177
  minimum=1,
178
  maximum=4,
179
  step=1,
 
181
  )
182
 
183
  with gr.Row():
184
+ gr.Markdown(f"### ๐ŸŒŸ {labels['Inspiration Gallery']}")
185
  with gr.Row():
186
  gr.Examples(
187
  examples=examples,
 
191
  cache_examples=False
192
  )
193
 
194
+ # ์ด๋ฒคํŠธ ์ฒ˜๋ฆฌ
195
+ def validated_generate(*args):
196
+ try:
197
+ return next(generate_image(*args))
198
+ except Exception as e:
199
+ print(f"๊ฒ€์ฆ ์ƒ์„ฑ ์˜ค๋ฅ˜: {e}")
200
+ return None, args[1], f"์˜ค๋ฅ˜: {str(e)}"
201
+
202
  enhanceBtn.click(
203
  fn=generate_image,
204
  inputs=[prompt, seed, width, height],
 
208
  queue=False
209
  )
210
 
 
 
 
 
 
 
 
 
211
  gr.on(
212
  triggers=[prompt.input, width.input, height.input, num_inference_steps.input],
213
  fn=validated_generate,
 
220
  )
221
 
222
  if __name__ == "__main__":
223
+ demo.launch()
224
+