# Spaces: Paused
import base64
import json
import logging
import os
import random
import threading
import time
import uuid
from pathlib import Path

import gradio as gr
import requests
import torch
from dotenv import load_dotenv
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load environment variables (python-dotenv) and require the WaveSpeed API key.
load_dotenv()
API_KEY = os.environ.get("WAVESPEED_API_KEY")
if not API_KEY:
    # Missing or empty key: abort at import time.
    raise ValueError("WAVESPEED_API_KEY 未在环境变量中设置")
# Safety-classifier configuration: model identifier and class-index -> label.
MODEL_URL = "TostAI/nsfw-text-detection-large"
CLASS_NAMES = {0: "✅ SAFE", 1: "⚠️ QUESTIONABLE", 2: "🚫 UNSAFE"}

# Load the safety model once at import time; any failure aborts startup.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
except Exception as e:
    # Chain the original exception so the root cause stays in the traceback.
    raise RuntimeError(f"安全模型加载失败: {str(e)}") from e
# Session management
class SessionManager:
    """Registry of per-session state dicts, keyed by session id.

    All state lives on the class (``_instances`` guarded by ``_lock``), so
    the class itself and every instance share the same sessions. Both
    methods are classmethods and can be called either as
    ``SessionManager.get_session(...)`` or on an instance.
    """

    # session_id -> {'count': int, 'history': list, 'last_active': float}
    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_session(cls, session_id):
        """Return the state dict for *session_id*, creating it if needed.

        A new session starts with ``count`` 0, an empty ``history`` and
        ``last_active`` set to the current time. The returned dict is the
        stored object itself, so caller mutations are persistent.
        """
        with cls._lock:
            if session_id not in cls._instances:
                cls._instances[session_id] = {
                    'count': 0,
                    'history': [],
                    'last_active': time.time(),
                }
            return cls._instances[session_id]

    @classmethod
    def cleanup_sessions(cls):
        """Drop every session whose ``last_active`` is more than 3600 s old."""
        with cls._lock:
            now = time.time()
            # Collect keys first; deleting while iterating the dict would raise.
            expired = [
                k for k, v in cls._instances.items()
                if now - v['last_active'] > 3600
            ]
            for k in expired:
                del cls._instances[k]
# Rate limiting
class RateLimiter:
    """Fixed-window request counter per client id.

    Each client gets a window that opens on its first request and lasts
    ``window_seconds``; at most ``max_requests`` calls are admitted within
    it. The defaults (20 requests per 3600 s) match the previously
    hard-coded values.
    """

    def __init__(self, max_requests=20, window_seconds=3600):
        """Create a limiter.

        Args:
            max_requests: Requests admitted per client per window.
            window_seconds: Window length in seconds.
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # client_id -> {'count': requests so far, 'reset': window end time}
        self.clients = {}
        self.lock = threading.Lock()

    def check(self, client_id):
        """Record one request for *client_id* and report whether it is allowed.

        Returns:
            True if the request fits in the current window (and was
            counted), False if the client has exhausted its quota.
        """
        with self.lock:
            now = time.time()
            entry = self.clients.get(client_id)
            # Unknown client or expired window: open a new window. The
            # request that opens a window is always admitted.
            if entry is None or now > entry['reset']:
                self.clients[client_id] = {
                    'count': 1,
                    'reset': now + self.window_seconds,
                }
                return True
            if entry['count'] >= self.max_requests:
                return False
            entry['count'] += 1
            return True
# Module-level singletons used by generate_image and cleanup_task.
session_manager = SessionManager()
rate_limiter = RateLimiter()
# Utility functions
def create_error_image(message):
    """Return a 512x512 pale-red image with *message* drawn on it in red.

    Messages longer than 60 characters are cut to 60 characters, prefixed
    with "Error: " and suffixed with "..."; shorter messages are drawn
    unchanged.
    """
    img = Image.new("RGB", (512, 512), "#ffdddd")
    try:
        font = ImageFont.truetype("arial.ttf", 24)
    except (OSError, ImportError):
        # Arial (or FreeType support) may be unavailable; use Pillow's
        # built-in font instead of failing.
        font = ImageFont.load_default()
    draw = ImageDraw.Draw(img)
    # NOTE(review): only the truncated branch gets the "Error: " prefix.
    # This looks like an operator-precedence slip, but the behavior is kept
    # as-is -- confirm the intended text before changing it.
    text = f"Error: {message[:60]}..." if len(message) > 60 else message
    draw.text((50, 200), text, fill="#ff0000", font=font)
    return img
def classify_prompt(prompt):
    """Classify *prompt* with the safety model and return the class index.

    The prompt is tokenized (truncated to 512 tokens) and passed through
    the module-level ``model``; the index of the largest logit is returned.
    Indices correspond to the keys of ``CLASS_NAMES``.
    """
    inputs = tokenizer(prompt,
                       return_tensors="pt",
                       truncation=True,
                       max_length=512)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.argmax(outputs.logits).item()
# File extension (lower-case, without the dot) -> MIME type for the data URI.
# "jpg" must map to "image/jpeg"; "image/jpg" is not a registered MIME type.
_MIME_TYPES = {
    "jpg": "image/jpeg",
    "jpeg": "image/jpeg",
    "png": "image/png",
}


def image_to_base64(file_path):
    """Read the image at *file_path* and return it as a base64 data URI.

    The MIME type is chosen from the file extension (case-insensitive);
    unknown or missing extensions fall back to ``image/jpeg``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    file_ext = Path(file_path).suffix.lower()[1:]
    mime_type = _MIME_TYPES.get(file_ext, "image/jpeg")
    with open(file_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode()
    return f"data:{mime_type};base64,{encoded}"
# Core generation logic
_LOGGER = logging.getLogger(__name__)
_SUBMIT_URL = "https://api.wavespeed.ai/api/v2/wavespeed-ai/hidream-e1-full"
_REQUEST_TIMEOUT = 30  # seconds, applied to every HTTP call
_MAX_POLLS = 60  # one poll per second
_ALLOWED_EXTENSIONS = ("jpg", "jpeg", "png")


def _validate_inputs(image_file, prompt):
    """Return the list of user-facing validation errors (empty if valid)."""
    problems = []
    if not image_file:
        problems.append("请上传图片文件")
    elif not Path(image_file).exists():
        problems.append("文件不存在")
    elif Path(image_file).suffix.lower()[1:] not in _ALLOWED_EXTENSIONS:
        problems.append("仅支持JPG/PNG格式")
    # Treat a missing prompt (None) the same as a blank one.
    if not (prompt or "").strip():
        problems.append("提示语不能为空")
    return problems


def generate_image(image_file, prompt, seed, session_id, enable_safety=True):
    """Submit an image-edit job and stream progress as a generator.

    Args:
        image_file: Path of the uploaded JPG/PNG image.
        prompt: Edit instruction text.
        seed: Seed for the job; ``-1`` or ``None`` picks a random seed in
            ``[0, 999999]``.
        session_id: Key used for rate limiting and session bookkeeping.
        enable_safety: When true, the prompt is classified locally first
            and the flag is forwarded as ``enable_safety_checker``.
            NOTE(review): this value is supplied by the client -- confirm
            whether it should be forced on server-side.

    Yields:
        ``(status_text, image, url)`` tuples. Intermediate polls yield
        ``(status, None, None)``; the final item carries either the result
        (the first entry of the API's ``outputs``, in both slots) or an
        error image with an empty url.

    Every failure, including HTTP errors and timeouts, is reported through
    a final yielded error tuple rather than raised.
    """
    try:
        # Validate before the safety check and rate limiter so that
        # malformed requests do not consume the caller's hourly quota.
        error_messages = _validate_inputs(image_file, prompt)
        if error_messages:
            error_img = create_error_image(" | ".join(error_messages))
            yield "❌ 输入验证失败", error_img, ""
            return

        # Safety check
        if enable_safety:
            safety_level = classify_prompt(prompt)
            if safety_level != 0:
                error_img = create_error_image(CLASS_NAMES[safety_level])
                yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img, ""
                return

        # Rate limiting
        if not rate_limiter.check(session_id):
            error_img = create_error_image(
                "Hourly limit exceeded (20 requests)")
            yield "❌ 请求过于频繁,请稍后再试", error_img, ""
            return

        # Session bookkeeping
        session = session_manager.get_session(session_id)
        session['last_active'] = time.time()
        session['count'] += 1

        # Encode the upload as a base64 data URI
        try:
            base64_image = image_to_base64(image_file)
        except Exception as e:
            error_img = create_error_image(f"文件处理失败: {str(e)}")
            yield "❌ 文件处理失败", error_img, ""
            return

        # Build the request
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_KEY}",
        }
        # A cleared seed field arrives as None; treat it like -1 (random).
        use_random_seed = seed is None or seed == -1
        payload = {
            "enable_base64_output": True,
            "enable_safety_checker": enable_safety,
            "image": base64_image,
            "prompt": prompt,
            "seed": random.randint(0, 999999) if use_random_seed else int(seed),
        }

        # Submit the job
        response = requests.post(_SUBMIT_URL,
                                 headers=headers,
                                 json=payload,
                                 timeout=_REQUEST_TIMEOUT)
        response.raise_for_status()

        request_id = response.json()["data"]["id"]
        result_url = f"https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
        start_time = time.time()

        # Poll for the result
        for _ in range(_MAX_POLLS):
            time.sleep(1)
            # Without a timeout a stalled connection would block this
            # worker indefinitely.
            resp = requests.get(result_url,
                                headers=headers,
                                timeout=_REQUEST_TIMEOUT)
            resp.raise_for_status()
            data = resp.json()["data"]
            status = data["status"]
            if status == "completed":
                elapsed = time.time() - start_time
                image_url = data["outputs"][0]
                session["history"].append(image_url)
                yield f"🎉 生成成功! 耗时 {elapsed:.1f}s", image_url, image_url
                return
            elif status == "failed":
                raise RuntimeError(data.get("error", "Unknown error"))
            else:
                yield f"⏳ 当前状态: {status.capitalize()}...", None, None
        raise RuntimeError("生成超时")
    except Exception as e:
        # Boundary handler: keep the UI responsive, but record the
        # traceback so operators can diagnose the failure.
        _LOGGER.exception("Image generation failed")
        error_img = create_error_image(str(e))
        yield f"❌ 生成失败: {str(e)}", error_img, ""
# Background cleanup task
def cleanup_task(interval_seconds=3600):
    """Purge expired sessions in an endless loop.

    Calls ``session_manager.cleanup_sessions()``, then sleeps
    *interval_seconds* (default 3600) before the next pass. Never returns,
    so it must run in its own thread.
    """
    while True:
        session_manager.cleanup_sessions()
        time.sleep(interval_seconds)
# UI construction
with gr.Blocks(theme=gr.themes.Soft(),
               css="""
    .status-box { padding: 10px; border-radius: 5px; margin: 5px; }
    .safe { background: #e8f5e9; border: 1px solid #a5d6a7; }
    .warning { background: #fff3e0; border: 1px solid #ffcc80; }
    .error { background: #ffebee; border: 1px solid #ef9a9a; }
    """) as app:
    # Pass a callable so a fresh id is generated on each page load. A plain
    # str(uuid.uuid4()) would be evaluated once at build time, making every
    # visitor share one session id and therefore one rate-limit bucket.
    session_id = gr.State(lambda: str(uuid.uuid4()))
    gr.Markdown("# 🖼️Hidream-E1-Full Live On Wavespeed Ai")
    gr.Markdown("HiDream-E1 is an image editing model built on HiDream-I1.")
    with gr.Row():
        # Left column: inputs
        with gr.Column(scale=1):
            image_file = gr.Image(label="Upload Image",
                                  type="filepath",
                                  sources=["upload"],
                                  interactive=True,
                                  image_mode="RGB")
            prompt = gr.Textbox(
                label="prompt",
                placeholder="Please enter an English prompt...",
                lines=3)
            seed = gr.Number(label="seed",
                             value=-1,
                             minimum=-1,
                             maximum=999999,
                             step=1)
            random_btn = gr.Button("random🎲seed", variant="secondary")
            # Shown checked and read-only in the UI.
            enable_safety = gr.Checkbox(label="🔒 Enable Safety Checker",
                                        value=True,
                                        interactive=False)
        # Right column: outputs
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Result")
            output_url = gr.Textbox(label="image url",
                                    interactive=True,
                                    visible=False)
            status = gr.Textbox(label="Status", elem_classes=["status-box"])
            submit_btn = gr.Button("开始生成", variant="primary")
    # Each example row is [prompt, image], matching `inputs` below.
    gr.Examples(examples=[
        [
            "Convert the image into Claymation style.",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
        ],
        [
            "Convert the image into a Ghibli style.",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg"
        ],
        [
            "Add sunglasses to the face of the girl.",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl2.png"
        ],
        [
            'Convert the image into an ink sketch style.',
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
        ],
        [
            'Add a butterfly to the scene.',
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_depth_result.png"
        ]
    ],
                inputs=[prompt, image_file],
                label="Examples")
    # Event wiring
    random_btn.click(fn=lambda: random.randint(0, 999999), outputs=seed)
    submit_btn.click(
        generate_image,
        inputs=[image_file, prompt, seed, session_id, enable_safety],
        outputs=[status, output_image, output_url])
if __name__ == "__main__":
    # Start the session janitor as a daemon thread so it never blocks exit.
    janitor = threading.Thread(target=cleanup_task, daemon=True)
    janitor.start()
    # Enable the request queue (max_size=4), then serve on all interfaces.
    queued_app = app.queue(max_size=4)
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False,
    )