# Hidream-E1-Full / app.py
# Latest commit: "fix" (dc5bb05), by chengzeyi
import os
import requests
import json
import time
import threading
import uuid
import base64
from pathlib import Path
from dotenv import load_dotenv
import gradio as gr
import random
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load environment variables (load_dotenv() is called before the lookup below).
load_dotenv()
API_KEY = os.environ.get("WAVESPEED_API_KEY")
# Fail fast at import time: every API call below needs the key.
if API_KEY is None or API_KEY == "":
    raise ValueError("WAVESPEED_API_KEY 未在环境变量中设置")
# Safety-classifier configuration.
# Model identifier handed to `from_pretrained` for both tokenizer and model.
MODEL_URL = "TostAI/nsfw-text-detection-large"
# Display label for each class index returned by `classify_prompt`.
CLASS_NAMES = {0: "✅ SAFE", 1: "⚠️ QUESTIONABLE", 2: "🚫 UNSAFE"}

# Load the safety model once, at import time. Any failure is re-raised as a
# RuntimeError, explicitly chained so the original cause stays in the traceback.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
except Exception as e:
    raise RuntimeError(f"安全模型加载失败: {str(e)}") from e
# Session management
class SessionManager:
    """Registry of per-session state dicts, keyed by session id.

    The registry and its lock are class attributes and both methods are
    classmethods, so the class and all of its instances share one registry.
    """

    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_session(cls, session_id):
        """Return the state dict for *session_id*, creating it on first use.

        A fresh record has ``count`` 0, an empty ``history`` list and
        ``last_active`` set to the current time. The stored dict itself is
        returned (not a copy), so changes made by the caller persist.
        """
        with cls._lock:
            try:
                return cls._instances[session_id]
            except KeyError:
                record = {
                    'count': 0,
                    'history': [],
                    'last_active': time.time(),
                }
                cls._instances[session_id] = record
                return record

    @classmethod
    def cleanup_sessions(cls):
        """Delete every session whose ``last_active`` is over 3600 s old."""
        with cls._lock:
            now = time.time()
            # Iterate over a snapshot so entries can be deleted while looping.
            for key, state in list(cls._instances.items()):
                if now - state['last_active'] > 3600:
                    del cls._instances[key]
# Rate limiting
class RateLimiter:
    """Fixed-window request counter keyed by client id.

    Each client may make up to ``max_requests`` calls to :meth:`check` per
    window. A client's window opens on its first request and lasts
    ``window_seconds``; the first request after it has elapsed opens a new
    window. The defaults (20 requests per 3600 seconds) match the previously
    hard-coded limits.
    """

    def __init__(self, max_requests=20, window_seconds=3600):
        """Create a limiter.

        Args:
            max_requests: Requests allowed per client per window (>= 1).
            window_seconds: Window length in seconds (> 0).

        Raises:
            ValueError: If either limit is out of range.
        """
        if max_requests < 1:
            raise ValueError("max_requests must be at least 1")
        if window_seconds <= 0:
            raise ValueError("window_seconds must be positive")
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # client_id -> {'count': requests in this window, 'reset': window end}
        self.clients = {}
        self.lock = threading.Lock()

    def check(self, client_id):
        """Record one request for *client_id* and report whether it is allowed.

        Returns:
            True if the request fits in the client's current window (the
            counter is incremented), False if the quota is already used up
            (the counter is left unchanged).
        """
        with self.lock:
            now = time.time()
            entry = self.clients.get(client_id)
            # Unknown client or expired window: open a new window that
            # already counts this request.
            if entry is None or now > entry['reset']:
                self.clients[client_id] = {
                    'count': 1,
                    'reset': now + self.window_seconds,
                }
                return True
            if entry['count'] >= self.max_requests:
                return False
            entry['count'] += 1
            return True
# Module-level singletons used by the request handler and the cleanup task.
# SessionManager keeps its registry on the class, so this instance is only a
# convenient handle to the classmethods.
session_manager = SessionManager()
# One shared limiter: request counts are tracked per client id inside it.
rate_limiter = RateLimiter()
# Utility functions
def create_error_image(message):
    """Render *message* as red text on a 512x512 light-red image.

    Messages longer than 60 characters are cut to their first 60 characters,
    prefixed with "Error: " and suffixed with "..."; messages of 60 characters
    or fewer are drawn exactly as given.

    Args:
        message: Text to draw.

    Returns:
        The PIL image containing the rendered text.
    """
    img = Image.new("RGB", (512, 512), "#ffdddd")
    try:
        font = ImageFont.truetype("arial.ttf", 24)
    except (OSError, ImportError):
        # Best effort: if Arial cannot be loaded (font file missing, or
        # presumably no FreeType support in this Pillow build), fall back to
        # Pillow's built-in font. Narrowed from a bare `except:` so that
        # KeyboardInterrupt/SystemExit and programming errors still propagate.
        font = ImageFont.load_default()
    draw = ImageDraw.Draw(img)
    text = f"Error: {message[:60]}..." if len(message) > 60 else message
    draw.text((50, 200), text, fill="#ff0000", font=font)
    return img
@torch.no_grad()
def classify_prompt(prompt):
    """Classify *prompt* with the safety model and return the top class index.

    The prompt is tokenized into PyTorch tensors (truncated to at most 512
    tokens), passed through the module-level ``model``, and the index of the
    largest logit is returned as a plain Python int. Gradient tracking is
    disabled for the whole call.
    """
    encoded = tokenizer(
        prompt,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    logits = model(**encoded).logits
    return logits.argmax().item()
def image_to_base64(file_path):
    """Read an image file and return it as a base64 ``data:`` URI.

    The MIME type is chosen from the (case-insensitive) file extension:
    ``.jpg`` and ``.jpeg`` map to ``image/jpeg`` and ``.png`` to
    ``image/png``. Any other extension falls back to ``image/jpeg``.

    Args:
        file_path: Path of the image file to encode.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # "image/jpg" is not a registered MIME type, so the extension cannot be
    # pasted into the type verbatim; map each one explicitly instead.
    mime_types = {
        "jpg": "image/jpeg",
        "jpeg": "image/jpeg",
        "png": "image/png",
    }
    file_ext = Path(file_path).suffix.lower()[1:]
    mime_type = mime_types.get(file_ext, "image/jpeg")
    with open(file_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode()
    return f"data:{mime_type};base64,{encoded}"
# Core generation logic
def generate_image(image_file, prompt, seed, session_id, enable_safety=True):
    """Submit an image-editing job to the WaveSpeed API and stream its progress.

    This is a generator. Every yielded item is a ``(status_text, image, url)``
    triple matching the UI outputs it is wired to: progress updates carry
    ``None`` for image and url, failures carry an error image and an empty
    url, and the final success item carries the first output of the API
    result in both the image and url slots.

    Steps, in order: prompt safety check (when enabled), per-session rate
    limiting, session bookkeeping, input validation, upload as a base64 data
    URI, then polling the result endpoint once per second for up to 60 polls.

    Args:
        image_file: Filesystem path of the uploaded image (JPG/PNG).
        prompt: Editing instruction text.
        seed: Seed for the request; ``-1`` or ``None`` selects a random seed
            in ``[0, 999999]``.
        session_id: Identifier used for rate limiting and session history.
        enable_safety: When true, classify the prompt locally first and also
            ask the API to run its own safety checker.

    Yields:
        ``(status_text, image, url)`` triples as described above. Exceptions
        are not propagated; they are reported as a final error item.
    """
    try:
        # Safety check: anything other than class 0 (safe) is blocked.
        if enable_safety:
            safety_level = classify_prompt(prompt)
            if safety_level != 0:
                error_img = create_error_image(CLASS_NAMES[safety_level])
                yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img, ""
                return

        # Rate limiting
        if not rate_limiter.check(session_id):
            error_img = create_error_image(
                "Hourly limit exceeded (20 requests)")
            yield "❌ 请求过于频繁,请稍后再试", error_img, ""
            return

        # Session bookkeeping
        session = session_manager.get_session(session_id)
        session['last_active'] = time.time()
        session['count'] += 1

        # Input validation: collect every problem so all are shown at once.
        error_messages = []
        if not image_file:
            error_messages.append("请上传图片文件")
        elif not Path(image_file).exists():
            error_messages.append("文件不存在")
        elif Path(image_file).suffix.lower()[1:] not in ["jpg", "jpeg", "png"]:
            error_messages.append("仅支持JPG/PNG格式")
        if not prompt.strip():
            error_messages.append("提示语不能为空")
        if error_messages:
            error_img = create_error_image(" | ".join(error_messages))
            yield "❌ 输入验证失败", error_img, ""
            return

        # Convert the upload to a base64 data URI.
        try:
            base64_image = image_to_base64(image_file)
        except Exception as e:
            error_img = create_error_image(f"文件处理失败: {str(e)}")
            yield "❌ 文件处理失败", error_img, ""
            return

        # A cleared seed field arrives as None; treat it like -1 ("random")
        # instead of letting int(None) abort the request.
        if seed is None or seed == -1:
            seed_value = random.randint(0, 999999)
        else:
            seed_value = int(seed)

        # Build the request.
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_KEY}",
        }
        payload = {
            "enable_base64_output": True,
            "enable_safety_checker": enable_safety,
            "image": base64_image,
            "prompt": prompt,
            "seed": seed_value,
        }

        # Submit the job.
        response = requests.post(
            "https://api.wavespeed.ai/api/v2/wavespeed-ai/hidream-e1-full",
            headers=headers,
            json=payload,
            timeout=30)
        response.raise_for_status()

        # Extract the job id and derive the polling URL.
        request_id = response.json()["data"]["id"]
        result_url = f"https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
        start_time = time.time()

        # Poll for the result.
        for _ in range(60):
            time.sleep(1)
            # Without a timeout a stalled connection would block this worker
            # (and its queue slot) forever; use the same bound as the POST.
            resp = requests.get(result_url, headers=headers, timeout=30)
            resp.raise_for_status()
            data = resp.json()["data"]
            status = data["status"]
            if status == "completed":
                elapsed = time.time() - start_time
                image_url = data["outputs"][0]
                session["history"].append(image_url)
                yield f"🎉 生成成功! 耗时 {elapsed:.1f}s", image_url, image_url
                return
            elif status == "failed":
                raise Exception(data.get("error", "Unknown error"))
            else:
                yield f"⏳ 当前状态: {status.capitalize()}...", None, None
        raise Exception("生成超时")
    except Exception as e:
        # Top-level boundary for the UI handler: report the failure to the
        # user as an error image instead of propagating it.
        error_img = create_error_image(str(e))
        yield f"❌ 生成失败: {str(e)}", error_img, ""
# Background cleanup task
def cleanup_task():
    """Purge expired sessions, sleep 3600 seconds, and repeat forever.

    Never returns; meant to run on a background thread.
    """
    pause_seconds = 3600
    while True:
        session_manager.cleanup_sessions()
        time.sleep(pause_seconds)
# UI construction
with gr.Blocks(theme=gr.themes.Soft(),
               css="""
.status-box { padding: 10px; border-radius: 5px; margin: 5px; }
.safe { background: #e8f5e9; border: 1px solid #a5d6a7; }
.warning { background: #fff3e0; border: 1px solid #ffcc80; }
.error { background: #ffebee; border: 1px solid #ef9a9a; }
""") as app:
    # Pass a callable so a fresh id is generated on every app load. A plain
    # `str(uuid.uuid4())` would be evaluated once at import time, giving every
    # visitor the same session id and therefore one shared rate-limit bucket.
    session_id = gr.State(lambda: str(uuid.uuid4()))
    gr.Markdown("# 🖼️Hidream-E1-Full Live On Wavespeed Ai")
    gr.Markdown("HiDream-E1 is an image editing model built on HiDream-I1.")
    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            image_file = gr.Image(label="Upload Image",
                                  type="filepath",
                                  sources=["upload"],
                                  interactive=True,
                                  image_mode="RGB")
            prompt = gr.Textbox(
                label="prompt",
                placeholder="Please enter an English prompt...",
                lines=3)
            seed = gr.Number(label="seed",
                             value=-1,
                             minimum=-1,
                             maximum=999999,
                             step=1)
            random_btn = gr.Button("random🎲seed", variant="secondary")
            # Shown checked and non-interactive: users cannot turn it off.
            enable_safety = gr.Checkbox(label="🔒 Enable Safety Checker",
                                        value=True,
                                        interactive=False)
        # Right column: outputs.
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Result")
            output_url = gr.Textbox(label="image url",
                                    interactive=True,
                                    visible=False)
            status = gr.Textbox(label="Status", elem_classes=["status-box"])
            submit_btn = gr.Button("开始生成", variant="primary")
    # Each example is [prompt, image URL], matching `inputs` below.
    gr.Examples(examples=[
        [
            "Convert the image into Claymation style.",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
        ],
        [
            "Convert the image into a Ghibli style.",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg"
        ],
        [
            "Add sunglasses to the face of the girl.",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl2.png"
        ],
        [
            'Convert the image into an ink sketch style.',
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
        ],
        [
            'Add a butterfly to the scene.',
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_depth_result.png"
        ]
    ],
                inputs=[prompt, image_file],
                label="Examples")
    # The dice button fills the seed field with a random value.
    random_btn.click(fn=lambda: random.randint(0, 999999), outputs=seed)
    # generate_image is a generator, so status updates stream to the outputs.
    submit_btn.click(
        generate_image,
        inputs=[image_file, prompt, seed, session_id, enable_safety],
        outputs=[status, output_image, output_url])
if __name__ == "__main__":
    # Run the session cleanup loop on a daemon thread so it does not keep the
    # process alive on shutdown.
    janitor = threading.Thread(target=cleanup_task, daemon=True)
    janitor.start()
    # Limit the request queue to 4 entries, then serve on all interfaces.
    queued_app = app.queue(max_size=4)
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False,
    )