"""Gradio front end for the WaveSpeed HiDream-E1 image-editing API.

The app classifies the prompt with a local text-safety model, applies a
per-session hourly rate limit, submits the image and prompt to the WaveSpeed
API and polls for the result while streaming status updates to the UI.
"""

import base64
import json
import os
import random
import threading
import time
import uuid
from pathlib import Path

import gradio as gr
import requests
import torch
from dotenv import load_dotenv
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load environment variables; the API key is mandatory.
load_dotenv()
API_KEY = os.getenv("WAVESPEED_API_KEY")
if not API_KEY:
    raise ValueError("WAVESPEED_API_KEY 未在环境变量中设置")

# Safety-classifier configuration.
MODEL_URL = "TostAI/nsfw-text-detection-large"
CLASS_NAMES = {0: "✅ SAFE", 1: "⚠️ QUESTIONABLE", 2: "🚫 UNSAFE"}

# WaveSpeed API endpoints and request tuning.
SUBMIT_URL = "https://api.wavespeed.ai/api/v2/wavespeed-ai/hidream-e1-full"
RESULT_URL_TEMPLATE = "https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
REQUEST_TIMEOUT_SECONDS = 30
POLL_ATTEMPTS = 60
POLL_INTERVAL_SECONDS = 1

# Session and rate-limit tuning.
SESSION_TTL_SECONDS = 3600
RATE_LIMIT_WINDOW_SECONDS = 3600
RATE_LIMIT_MAX_REQUESTS = 20

# Accepted upload extensions mapped to their registered MIME types.
_MIME_TYPES = {"jpg": "image/jpeg", "jpeg": "image/jpeg", "png": "image/png"}

CUSTOM_CSS = """
.status-box { padding: 10px; border-radius: 5px; margin: 5px; }
.safe { background: #e8f5e9; border: 1px solid #a5d6a7; }
.warning { background: #fff3e0; border: 1px solid #ffcc80; }
.error { background: #ffebee; border: 1px solid #ef9a9a; }
"""

# Load the safety model once at import time; the app cannot run without it.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
except Exception as e:
    raise RuntimeError(f"安全模型加载失败: {str(e)}") from e


# Session management
class SessionManager:
    """Process-wide registry of per-session state, guarded by a lock."""

    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_session(cls, session_id):
        """Return the state dict for ``session_id``, creating it if absent.

        The dict holds ``count``, ``history`` and ``last_active`` keys.
        """
        with cls._lock:
            if session_id not in cls._instances:
                cls._instances[session_id] = {
                    'count': 0,
                    'history': [],
                    'last_active': time.time(),
                }
            return cls._instances[session_id]

    @classmethod
    def cleanup_sessions(cls):
        """Drop every session that has been idle longer than the TTL."""
        with cls._lock:
            now = time.time()
            expired = [
                key for key, state in cls._instances.items()
                if now - state['last_active'] > SESSION_TTL_SECONDS
            ]
            for key in expired:
                del cls._instances[key]


# Rate limiting
class RateLimiter:
    """Fixed-window request counter keyed by client id."""

    def __init__(self, max_requests=RATE_LIMIT_MAX_REQUESTS,
                 window_seconds=RATE_LIMIT_WINDOW_SECONDS):
        """Create a limiter allowing ``max_requests`` per ``window_seconds``."""
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.clients = {}
        self.lock = threading.Lock()

    def check(self, client_id):
        """Record one request for ``client_id``.

        Returns:
            True if the request is within the limit (and has been counted),
            False if the client has exhausted the current window.
        """
        with self.lock:
            now = time.time()
            entry = self.clients.get(client_id)
            # First request ever, or the previous window has elapsed.
            if entry is None or now > entry['reset']:
                self.clients[client_id] = {
                    'count': 1,
                    'reset': now + self.window_seconds,
                }
                return True
            if entry['count'] >= self.max_requests:
                return False
            entry['count'] += 1
            return True


session_manager = SessionManager()
rate_limiter = RateLimiter()


# Utility functions
def create_error_image(message):
    """Render ``message`` onto a 512x512 pink placeholder image.

    The text is always prefixed with ``Error: `` and truncated to 60
    characters (with an ellipsis) when longer.

    Returns:
        A ``PIL.Image.Image`` carrying the error text.
    """
    img = Image.new("RGB", (512, 512), "#ffdddd")
    try:
        font = ImageFont.truetype("arial.ttf", 24)
    except (OSError, ImportError):
        # Font file missing or FreeType unavailable: use the built-in font.
        font = ImageFont.load_default()

    message = str(message)
    body = f"{message[:60]}..." if len(message) > 60 else message
    text = f"Error: {body}"

    draw = ImageDraw.Draw(img)
    try:
        draw.text((50, 200), text, fill="#ff0000", font=font)
    except UnicodeEncodeError:
        # The fallback bitmap font on older Pillow versions only handles
        # Latin-1; never let the error renderer itself fail.
        ascii_text = text.encode("ascii", "replace").decode("ascii")
        draw.text((50, 200), ascii_text, fill="#ff0000", font=font)
    return img


@torch.no_grad()
def classify_prompt(prompt):
    """Classify ``prompt`` with the safety model.

    Returns:
        The index of the highest-scoring class (see ``CLASS_NAMES``).
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       max_length=512)
    outputs = model(**inputs)
    return torch.argmax(outputs.logits).item()


def image_to_base64(file_path):
    """Read ``file_path`` and return its contents as a base64 data URI.

    Unknown extensions fall back to ``image/jpeg``.
    """
    file_ext = Path(file_path).suffix.lower()[1:]
    mime_type = _MIME_TYPES.get(file_ext, "image/jpeg")
    with open(file_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode()
    return f"data:{mime_type};base64,{encoded}"


def _validate_inputs(image_file, prompt):
    """Return a list of user-facing validation errors (empty when valid)."""
    error_messages = []
    if not image_file:
        error_messages.append("请上传图片文件")
    elif not Path(image_file).exists():
        error_messages.append("文件不存在")
    elif Path(image_file).suffix.lower()[1:] not in _MIME_TYPES:
        error_messages.append("仅支持JPG/PNG格式")
    if not (prompt or "").strip():
        error_messages.append("提示语不能为空")
    return error_messages


# Core generation logic
def generate_image(image_file, prompt, seed, session_id, enable_safety=True):
    """Edit ``image_file`` according to ``prompt`` via the WaveSpeed API.

    This is a generator: it yields ``(status_text, image, url)`` tuples so the
    UI can show progress while the remote job runs.

    Args:
        image_file: Path of the uploaded image (JPG/PNG).
        prompt: Editing instruction.
        seed: Seed for the request; ``-1`` or an empty field picks a random one.
        session_id: Key used for rate limiting and session history.
        enable_safety: Whether to run the local prompt classifier and ask the
            API to run its own safety checker.

    Any exception is caught and reported as a final error tuple rather than
    propagated.
    """
    try:
        # Validate first so malformed requests neither run the classifier nor
        # consume a rate-limit slot.
        error_messages = _validate_inputs(image_file, prompt)
        if error_messages:
            error_img = create_error_image(" | ".join(error_messages))
            yield "❌ 输入验证失败", error_img, ""
            return

        # Safety check
        if enable_safety:
            safety_level = classify_prompt(prompt)
            if safety_level != 0:
                error_img = create_error_image(CLASS_NAMES[safety_level])
                yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img, ""
                return

        # Rate limiting
        if not rate_limiter.check(session_id):
            error_img = create_error_image(
                "Hourly limit exceeded (20 requests)")
            yield "❌ 请求过于频繁,请稍后再试", error_img, ""
            return

        # Session bookkeeping
        session = session_manager.get_session(session_id)
        session['last_active'] = time.time()
        session['count'] += 1

        # Encode the upload as a data URI
        try:
            base64_image = image_to_base64(image_file)
        except Exception as e:
            error_img = create_error_image(f"文件处理失败: {str(e)}")
            yield "❌ 文件处理失败", error_img, ""
            return

        # Build the request
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_KEY}",
        }
        # A cleared Number field arrives as None; treat it like -1 (random).
        if seed is None or seed == -1:
            resolved_seed = random.randint(0, 999999)
        else:
            resolved_seed = int(seed)
        payload = {
            "enable_base64_output": True,
            "enable_safety_checker": enable_safety,
            "image": base64_image,
            "prompt": prompt,
            "seed": resolved_seed,
        }

        # Submit the job
        response = requests.post(
            SUBMIT_URL,
            headers=headers,
            json=payload,
            timeout=REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()

        request_id = response.json()["data"]["id"]
        result_url = RESULT_URL_TEMPLATE.format(request_id=request_id)
        start_time = time.time()

        # Poll for the result
        for _ in range(POLL_ATTEMPTS):
            time.sleep(POLL_INTERVAL_SECONDS)
            # The timeout keeps a stalled connection from pinning a queue
            # worker forever.
            resp = requests.get(result_url, headers=headers,
                                timeout=REQUEST_TIMEOUT_SECONDS)
            resp.raise_for_status()

            data = resp.json()["data"]
            status = data["status"]

            if status == "completed":
                elapsed = time.time() - start_time
                image_url = data["outputs"][0]
                session["history"].append(image_url)
                yield f"🎉 生成成功! 耗时 {elapsed:.1f}s", image_url, image_url
                return
            elif status == "failed":
                raise Exception(data.get("error", "Unknown error"))
            else:
                yield f"⏳ 当前状态: {status.capitalize()}...", None, None

        raise Exception("生成超时")

    except Exception as e:
        error_img = create_error_image(str(e))
        yield f"❌ 生成失败: {str(e)}", error_img, ""


# Background cleanup task
def cleanup_task():
    """Purge expired sessions once an hour, forever (run in a daemon thread)."""
    while True:
        session_manager.cleanup_sessions()
        time.sleep(3600)


# UI construction
with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as app:
    # The initial value is computed once at build time and would be shared by
    # every visitor; the app.load handler below replaces it per page load.
    session_id = gr.State(str(uuid.uuid4()))

    gr.Markdown("# 🖼️Hidream-E1-Full Live On Wavespeed Ai")
    gr.Markdown("HiDream-E1 is an image editing model built on HiDream-I1.")

    with gr.Row():
        with gr.Column(scale=1):
            image_file = gr.Image(label="Upload Image",
                                  type="filepath",
                                  sources=["upload"],
                                  interactive=True,
                                  image_mode="RGB")
            prompt = gr.Textbox(
                label="prompt",
                placeholder="Please enter an English prompt...",
                lines=3)
            seed = gr.Number(label="seed",
                             value=-1,
                             minimum=-1,
                             maximum=999999,
                             step=1)
            random_btn = gr.Button("random🎲seed", variant="secondary")
            enable_safety = gr.Checkbox(label="🔒 Enable Safety Checker",
                                        value=True,
                                        interactive=False)
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Result")
            output_url = gr.Textbox(label="image url",
                                    interactive=True,
                                    visible=False)
            status = gr.Textbox(label="Status", elem_classes=["status-box"])
            submit_btn = gr.Button("开始生成", variant="primary")

    gr.Examples(examples=[
        [
            "Convert the image into Claymation style.",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
        ],
        [
            "Convert the image into a Ghibli style.",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg"
        ],
        [
            "Add sunglasses to the face of the girl.",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl2.png"
        ],
        [
            'Convert the image into an ink sketch style.',
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
        ],
        [
            'Add a butterfly to the scene.',
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_depth_result.png"
        ]
    ],
                inputs=[prompt, image_file],
                label="Examples")

    # Give each browser session its own id so rate limits and history are
    # tracked per visitor instead of globally.
    app.load(fn=lambda: str(uuid.uuid4()), outputs=session_id)

    random_btn.click(fn=lambda: random.randint(0, 999999), outputs=seed)
    submit_btn.click(
        generate_image,
        inputs=[image_file, prompt, seed, session_id, enable_safety],
        outputs=[status, output_image, output_url])

if __name__ == "__main__":
    threading.Thread(target=cleanup_task, daemon=True).start()
    app.queue(max_size=4).launch(server_name="0.0.0.0",
                                 server_port=7860,
                                 show_error=True,
                                 share=False)