chengzeyi commited on
Commit
dc5bb05
·
1 Parent(s): da4ad68
Files changed (1) hide show
  1. app.py +68 -52
app.py CHANGED
@@ -30,6 +30,7 @@ try:
30
  except Exception as e:
31
  raise RuntimeError(f"安全模型加载失败: {str(e)}")
32
 
 
33
  # 会话管理
34
  class SessionManager:
35
  _instances = {}
@@ -50,12 +51,17 @@ class SessionManager:
50
  def cleanup_sessions(cls):
51
  with cls._lock:
52
  now = time.time()
53
- expired = [k for k, v in cls._instances.items() if now - v['last_active'] > 3600]
 
 
 
54
  for k in expired:
55
  del cls._instances[k]
56
 
 
57
  # 速率限制
58
  class RateLimiter:
 
59
  def __init__(self):
60
  self.clients = {}
61
  self.lock = threading.Lock()
@@ -74,9 +80,11 @@ class RateLimiter:
74
  self.clients[client_id]['count'] += 1
75
  return True
76
 
 
77
  session_manager = SessionManager()
78
  rate_limiter = RateLimiter()
79
 
 
80
  # 工具函数
81
  def create_error_image(message):
82
  """生成错误提示图片"""
@@ -90,28 +98,29 @@ def create_error_image(message):
90
  draw.text((50, 200), text, fill="#ff0000", font=font)
91
  return img
92
 
 
93
  @torch.no_grad()
94
  def classify_prompt(prompt):
95
  """安全分类"""
96
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
 
 
 
97
  outputs = model(**inputs)
98
  return torch.argmax(outputs.logits).item()
99
 
 
100
  def image_to_base64(file_path):
101
  """将图片转换为Base64格式"""
102
  with open(file_path, "rb") as f:
103
  file_ext = Path(file_path).suffix.lower()[1:]
104
- mime_type = f"image/{file_ext}" if file_ext in ["jpeg", "jpg", "png"] else "image/jpeg"
 
105
  return f"data:{mime_type};base64,{base64.b64encode(f.read()).decode()}"
106
 
 
107
  # 核心生成逻辑
108
- def generate_image(
109
- image_file,
110
- prompt,
111
- seed,
112
- session_id,
113
- enable_safety=True
114
- ):
115
  try:
116
  # 安全检查
117
  if enable_safety:
@@ -123,7 +132,8 @@ def generate_image(
123
 
124
  # 速率限制
125
  if not rate_limiter.check(session_id):
126
- error_img = create_error_image("Hourly limit exceeded (20 requests)")
 
127
  yield "❌ 请求过于频繁,请稍后再试", error_img, ""
128
  return
129
 
@@ -173,8 +183,7 @@ def generate_image(
173
  "https://api.wavespeed.ai/api/v2/wavespeed-ai/hidream-e1-full",
174
  headers=headers,
175
  json=payload,
176
- timeout=30
177
- )
178
  response.raise_for_status()
179
 
180
  # 处理响应
@@ -208,83 +217,90 @@ def generate_image(
208
  error_img = create_error_image(str(e))
209
  yield f"❌ 生成失败: {str(e)}", error_img, ""
210
 
 
211
  # 后台清理任务
212
  def cleanup_task():
213
  while True:
214
  session_manager.cleanup_sessions()
215
  time.sleep(3600)
216
 
 
217
  # 界面构建
218
- with gr.Blocks(
219
- theme=gr.themes.Soft(),
220
- css="""
221
  .status-box { padding: 10px; border-radius: 5px; margin: 5px; }
222
  .safe { background: #e8f5e9; border: 1px solid #a5d6a7; }
223
  .warning { background: #fff3e0; border: 1px solid #ffcc80; }
224
  .error { background: #ffebee; border: 1px solid #ef9a9a; }
225
- """
226
- ) as app:
227
-
228
  session_id = gr.State(str(uuid.uuid4()))
229
-
230
  gr.Markdown("# 🖼️Hidream-E1-Full Live On Wavespeed Ai")
231
  gr.Markdown("HiDream-E1 is an image editing model built on HiDream-I1.")
232
 
233
  with gr.Row():
234
  with gr.Column(scale=1):
235
- image_file = gr.Image(label="Upload Image", type="filepath", sources=["upload"], interactive=True, image_mode="RGB")
236
- prompt = gr.Textbox(label="prompt", placeholder="Please enter an English prompt...", lines=3 )
237
- seed = gr.Number(label="seed", value=-1, minimum=-1, maximum=999999, step=1)
 
 
 
 
 
 
 
 
 
 
 
238
  random_btn = gr.Button("random🎲seed", variant="secondary")
239
- enable_safety = gr.Checkbox(label="🔒 Enable Safety Checker", value=True, interactive=False)
 
 
240
  with gr.Column(scale=1):
241
  output_image = gr.Image(label="Generated Result")
242
- output_url = gr.Textbox(label="image url", interactive=True, visible=False)
 
 
243
  status = gr.Textbox(label="Status", elem_classes=["status-box"])
244
  submit_btn = gr.Button("开始生成", variant="primary")
245
- gr.Examples(
246
- examples=[
247
  [
248
- "Editing Instruction: Convert the image into a Ghibli style. Target Image Description: The woman with long, straight hair and a black top is depicted in a Ghibli style, with vibrant colors and soft, whimsical features, set in a gently lit room.",
249
- "https://hidream-ai-hidream-e1-full.hf.space/gradio_api/file=/tmp/gradio/0d28f171c5b089688d5693d7d2bbb02d1daa86ab6c97a8ebf3199dc5f68ab3dc/test_1.png"
250
  ],
251
  [
252
- "Editing Instruction: Transform the image into a Disney Pixar style. Target Image Description: The person now has stylized, animated features typical of Disney Pixar characters, with exaggerated proportions and vibrant colors, while maintaining the same pose and setting.",
253
- "https://hidream-ai-hidream-e1-full.hf.space/gradio_api/file=/tmp/gradio/0d28f171c5b089688d5693d7d2bbb02d1daa86ab6c97a8ebf3199dc5f68ab3dc/test_1.png"
254
  ],
255
  [
256
- "Editing Instruction: Add sunglasses to the girl. Target Image Description: A girl with long brown hair is standing indoors with sunlight streaming through a window, now wearing sunglasses.",
257
- "https://hidream-ai-hidream-e1-full.hf.space/gradio_api/file=/tmp/gradio/0d28f171c5b089688d5693d7d2bbb02d1daa86ab6c97a8ebf3199dc5f68ab3dc/test_1.png"
258
  ],
259
  [
260
- 'Editing Instruction: Convert the image into an ink sketch style. Target Image Description: An ink sketch of a wooden sign with the text "HiDream E1 Coming Soon" surrounded by ivy, with simplified outlines of mountains and clouds in the background.',
261
- "https://hidream-ai-hidream-e1-full.hf.space/gradio_api/file=/tmp/gradio/6cb6be7c9020587a278d8b6d1a7be3e666a96b8f1341361671f9805bd67c2d8b/test_2.jpg"
262
  ],
263
  [
264
- 'Editing Instruction: Add a butterfly to the scene. Target Image Description: A wooden sign with the text "HiDream E1 Coming Soon" surrounded by green vines, with a butterfly added to the scene, set against a backdrop of mountains and a clear sky with fluffy clouds.',
265
- "https://hidream-ai-hidream-e1-full.hf.space/gradio_api/file=/tmp/gradio/6cb6be7c9020587a278d8b6d1a7be3e666a96b8f1341361671f9805bd67c2d8b/test_2.jpg"
266
  ]
267
  ],
268
- inputs=[prompt, image_file],
269
- label="Examples"
270
- )
271
 
272
- random_btn.click(
273
- fn=lambda: random.randint(0, 999999),
274
- outputs=seed
275
- )
276
 
277
  submit_btn.click(
278
  generate_image,
279
  inputs=[image_file, prompt, seed, session_id, enable_safety],
280
- outputs=[status, output_image, output_url]
281
- )
282
 
283
  if __name__ == "__main__":
284
  threading.Thread(target=cleanup_task, daemon=True).start()
285
- app.queue(max_size=4).launch(
286
- server_name="0.0.0.0",
287
- server_port=7860,
288
- show_error=True,
289
- share=False
290
- )
 
30
  except Exception as e:
31
  raise RuntimeError(f"安全模型加载失败: {str(e)}")
32
 
33
+
34
  # 会话管理
35
  class SessionManager:
36
  _instances = {}
 
51
  def cleanup_sessions(cls):
52
  with cls._lock:
53
  now = time.time()
54
+ expired = [
55
+ k for k, v in cls._instances.items()
56
+ if now - v['last_active'] > 3600
57
+ ]
58
  for k in expired:
59
  del cls._instances[k]
60
 
61
+
62
  # 速率限制
63
  class RateLimiter:
64
+
65
  def __init__(self):
66
  self.clients = {}
67
  self.lock = threading.Lock()
 
80
  self.clients[client_id]['count'] += 1
81
  return True
82
 
83
+
84
  session_manager = SessionManager()
85
  rate_limiter = RateLimiter()
86
 
87
+
88
  # 工具函数
89
  def create_error_image(message):
90
  """生成错误提示图片"""
 
98
  draw.text((50, 200), text, fill="#ff0000", font=font)
99
  return img
100
 
101
+
102
  @torch.no_grad()
103
  def classify_prompt(prompt):
104
  """安全分类"""
105
+ inputs = tokenizer(prompt,
106
+ return_tensors="pt",
107
+ truncation=True,
108
+ max_length=512)
109
  outputs = model(**inputs)
110
  return torch.argmax(outputs.logits).item()
111
 
112
+
113
  def image_to_base64(file_path):
114
  """将图片转换为Base64格式"""
115
  with open(file_path, "rb") as f:
116
  file_ext = Path(file_path).suffix.lower()[1:]
117
+ mime_type = f"image/{file_ext}" if file_ext in ["jpeg", "jpg", "png"
118
+ ] else "image/jpeg"
119
  return f"data:{mime_type};base64,{base64.b64encode(f.read()).decode()}"
120
 
121
+
122
  # 核心生成逻辑
123
+ def generate_image(image_file, prompt, seed, session_id, enable_safety=True):
 
 
 
 
 
 
124
  try:
125
  # 安全检查
126
  if enable_safety:
 
132
 
133
  # 速率限制
134
  if not rate_limiter.check(session_id):
135
+ error_img = create_error_image(
136
+ "Hourly limit exceeded (20 requests)")
137
  yield "❌ 请求过于频繁,请稍后再试", error_img, ""
138
  return
139
 
 
183
  "https://api.wavespeed.ai/api/v2/wavespeed-ai/hidream-e1-full",
184
  headers=headers,
185
  json=payload,
186
+ timeout=30)
 
187
  response.raise_for_status()
188
 
189
  # 处理响应
 
217
  error_img = create_error_image(str(e))
218
  yield f"❌ 生成失败: {str(e)}", error_img, ""
219
 
220
+
221
  # 后台清理任务
222
  def cleanup_task():
223
  while True:
224
  session_manager.cleanup_sessions()
225
  time.sleep(3600)
226
 
227
+
228
  # 界面构建
229
+ with gr.Blocks(theme=gr.themes.Soft(),
230
+ css="""
 
231
  .status-box { padding: 10px; border-radius: 5px; margin: 5px; }
232
  .safe { background: #e8f5e9; border: 1px solid #a5d6a7; }
233
  .warning { background: #fff3e0; border: 1px solid #ffcc80; }
234
  .error { background: #ffebee; border: 1px solid #ef9a9a; }
235
+ """) as app:
236
+
 
237
  session_id = gr.State(str(uuid.uuid4()))
238
+
239
  gr.Markdown("# 🖼️Hidream-E1-Full Live On Wavespeed Ai")
240
  gr.Markdown("HiDream-E1 is an image editing model built on HiDream-I1.")
241
 
242
  with gr.Row():
243
  with gr.Column(scale=1):
244
+ image_file = gr.Image(label="Upload Image",
245
+ type="filepath",
246
+ sources=["upload"],
247
+ interactive=True,
248
+ image_mode="RGB")
249
+ prompt = gr.Textbox(
250
+ label="prompt",
251
+ placeholder="Please enter an English prompt...",
252
+ lines=3)
253
+ seed = gr.Number(label="seed",
254
+ value=-1,
255
+ minimum=-1,
256
+ maximum=999999,
257
+ step=1)
258
  random_btn = gr.Button("random🎲seed", variant="secondary")
259
+ enable_safety = gr.Checkbox(label="🔒 Enable Safety Checker",
260
+ value=True,
261
+ interactive=False)
262
  with gr.Column(scale=1):
263
  output_image = gr.Image(label="Generated Result")
264
+ output_url = gr.Textbox(label="image url",
265
+ interactive=True,
266
+ visible=False)
267
  status = gr.Textbox(label="Status", elem_classes=["status-box"])
268
  submit_btn = gr.Button("开始生成", variant="primary")
269
+ gr.Examples(examples=[
 
270
  [
271
+ "Convert the image into Claymation style.",
272
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
273
  ],
274
  [
275
+ "Convert the image into a Ghibli style.",
276
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg"
277
  ],
278
  [
279
+ "Add sunglasses to the face of the girl.",
280
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl2.png"
281
  ],
282
  [
283
+ 'Convert the image into an ink sketch style.',
284
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
285
  ],
286
  [
287
+ 'Add a butterfly to the scene.',
288
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_depth_result.png"
289
  ]
290
  ],
291
+ inputs=[prompt, image_file],
292
+ label="Examples")
 
293
 
294
+ random_btn.click(fn=lambda: random.randint(0, 999999), outputs=seed)
 
 
 
295
 
296
  submit_btn.click(
297
  generate_image,
298
  inputs=[image_file, prompt, seed, session_id, enable_safety],
299
+ outputs=[status, output_image, output_url])
 
300
 
301
  if __name__ == "__main__":
302
  threading.Thread(target=cleanup_task, daemon=True).start()
303
+ app.queue(max_size=4).launch(server_name="0.0.0.0",
304
+ server_port=7860,
305
+ show_error=True,
306
+ share=False)