# BSPLow-Work / app.py
# Hugging Face Space file; last commit f94893e "Update app.py" by neo7team (verified).
import gradio as gr
import requests
import io
import random
import os
import time
from PIL import Image
from deep_translator import GoogleTranslator
# Project by Nymbo
# Hosted inference endpoint for the FLUX.1-schnell text-to-image model.
API_URL = (
    "https://api-inference.huggingface.co/models/"
    "black-forest-labs/FLUX.1-schnell"
)
# Read token pulled from the environment at import time (None when unset).
API_TOKEN = os.getenv("HF_READ_TOKEN")
headers = {"Authorization": "Bearer {}".format(API_TOKEN)}
# Request timeout handed to requests.post (seconds, per the requests API).
timeout = 100
def convert_to_png(image):
    """Convert any image format to true PNG format.

    The image is re-encoded as an optimized PNG in memory and re-opened, so the
    returned object is a PNG-backed PIL image.

    Mode handling:
      * ``RGB`` and ``RGBA`` are encoded unchanged.
      * Other modes that carry transparency (an ``A`` band, e.g. ``LA``/``PA``,
        or a ``transparency`` entry in ``image.info``, e.g. palette images) are
        converted to ``RGBA`` so the alpha channel survives.
      * All remaining modes are converted to ``RGB``.

    Args:
        image: Source ``PIL.Image.Image`` in any mode.

    Returns:
        A ``PIL.Image.Image`` opened from the in-memory PNG bytes.
    """
    if image.mode not in ("RGB", "RGBA"):
        has_alpha = "A" in image.getbands() or "transparency" in image.info
        image = image.convert("RGBA" if has_alpha else "RGB")

    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG", optimize=True)
    png_buffer.seek(0)
    # Image.open is lazy: the returned image reads from png_buffer on demand,
    # so the buffer is intentionally left open (the image keeps it alive).
    return Image.open(png_buffer)
def _build_payload(prompt, is_negative, steps, cfg_scale, seed, strength, width, height):
    """Assemble the JSON body posted to API_URL (same key layout as before).

    A seed of ``None`` (e.g. a cleared ``gr.Number``) or ``-1`` is replaced by a
    random seed in ``[1, 1000000000]``; any other seed is sent as an ``int`` so a
    float such as ``42.0`` from the UI is not serialized as ``42.0``.
    """
    if seed is None or seed == -1:
        seed = random.randint(1, 1000000000)
    # NOTE(review): apart from "inputs" and "parameters", these top-level keys may
    # be ignored by the hosted text-to-image task; confirm against the API spec.
    return {
        "inputs": prompt,
        "is_negative": is_negative,
        "steps": steps,
        "cfg_scale": cfg_scale,
        "seed": int(seed),
        "strength": strength,
        "parameters": {"width": width, "height": height},
    }


def _request_error(exc):
    """Translate a ``requests`` failure into the ``gr.Error`` shown to the user."""
    response = getattr(exc, "response", None)
    # Compare against None explicitly: requests.Response.__bool__ returns
    # response.ok, so any 4xx/5xx response is falsy and a truthiness test
    # would always fall through to the generic network message.
    if response is None:
        return gr.Error("Network error occurred")
    if response.status_code == 503:
        return gr.Error("503: Model is loading, please try again later")
    return gr.Error(f"{response.status_code}: {response.text}")


def query(prompt, is_negative=False, steps=4, cfg_scale=7, sampler="DPM++ 2M Karras",
          seed=-1, strength=0.7, width=1024, height=1024):
    """Generate one image for *prompt* through the Hugging Face Inference API.

    Args:
        prompt: Text prompt; an empty or ``None`` prompt returns ``None`` at once.
        is_negative: Negative-prompt text wired from the UI (``False`` when
            unused); sent verbatim under the ``"is_negative"`` payload key.
        steps: Sent as ``"steps"``.
        cfg_scale: Sent as ``"cfg_scale"``.
        sampler: Sampler chosen in the UI. NOTE(review): not included in the
            request payload, so it currently has no effect.
        seed: ``-1`` or ``None`` picks a random seed; otherwise sent as an int.
        strength: Sent as ``"strength"``.
        width, height: Sent under ``"parameters"``.

    Returns:
        The generated image as a PNG-backed ``PIL.Image.Image``, or ``None`` for
        an empty prompt.

    Raises:
        gr.Error: on HTTP errors (status code and response body, with a dedicated
            message for 503), network failures, or image decoding failures.
    """
    if not prompt:
        return None

    key = random.randint(0, 999)  # only used to tag the success log line
    # The token is re-read on every call, so an updated secret is picked up.
    token = os.getenv("HF_READ_TOKEN")
    headers = {"Authorization": f"Bearer {token}"}
    payload = _build_payload(prompt, is_negative, steps, cfg_scale, seed, strength, width, height)

    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        # Convert directly to PNG without intermediate format
        png_img = convert_to_png(Image.open(io.BytesIO(response.content)))
    except requests.exceptions.RequestException as e:
        print(f"API Error: {e}")
        raise _request_error(e) from e
    except Exception as e:
        print(f"Image processing error: {e}")
        raise gr.Error(f"Image processing failed: {str(e)}") from e
    print(f'\033[1mGeneration {key} completed as PNG!\033[0m')
    return png_img
# Light theme CSS
# Custom stylesheet passed to gr.Blocks(css=css) for the app layout.
# Each "#..." selector only takes effect on a component created with the
# matching elem_id=..., and ".accordion" only on one created with
# elem_classes="accordion".
css = """
#app-container {
max-width: 800px;
margin: 0 auto;
padding: 20px;
background: #ffffff;
}
#prompt-text-input, #negative-prompt-text-input {
font-size: 14px;
background: #f9f9f9;
}
#gallery {
min-height: 512px;
background: #ffffff;
border: 1px solid #e0e0e0;
}
#gen-button {
margin: 10px 0;
background: #4CAF50;
color: white;
}
.accordion {
background: #f5f5f5;
border: 1px solid #e0e0e0;
}
h1 {
color: #333333;
}
"""
# UI layout: prompt box, collapsible advanced settings, generate button and
# output image. Generation is delegated to query().
with gr.Blocks(theme=gr.themes.Default(primary_hue="green"), css=css) as app:
    gr.HTML("<center><h1>BSP Low Work</h1></center>")
    with gr.Column(elem_id="app-container"):
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                with gr.Row():
                    text_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Prompt",
                        lines=2,
                        elem_id="prompt-text-input",
                    )
                # elem_classes lets the ".accordion" CSS rule match this panel.
                with gr.Accordion("Advanced Settings", open=False, elem_classes="accordion"):
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
                        lines=3,
                        # Matches the "#negative-prompt-text-input" CSS rule.
                        elem_id="negative-prompt-text-input",
                    )
                    # Slider's first positional parameter is `minimum`, so the
                    # initial value must be passed as value= (passing it
                    # positionally alongside minimum= raises TypeError).
                    with gr.Row():
                        width = gr.Slider(value=1024, label="Width", minimum=512, maximum=2048, step=64)
                        height = gr.Slider(value=1024, label="Height", minimum=512, maximum=2048, step=64)
                    with gr.Row():
                        steps = gr.Slider(value=4, label="Steps", minimum=4, maximum=100, step=1)
                        cfg = gr.Slider(value=7.0, label="CFG Scale", minimum=1.0, maximum=20.0, step=0.5)
                    with gr.Row():
                        strength = gr.Slider(value=0.7, label="Strength", minimum=0.1, maximum=1.0, step=0.01)
                        # precision=0 makes the component return an int seed.
                        seed = gr.Number(-1, label="Seed (-1 for random)", precision=0)
                    method = gr.Radio(
                        ["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler", "Euler a", "Heun", "DDIM"],
                        value="DPM++ 2M Karras",
                        label="Sampling Method",
                    )
        # elem_id matches the "#gen-button" CSS rule.
        generate_btn = gr.Button("Generate Image", variant="primary", elem_id="gen-button")
        with gr.Row():
            output_image = gr.Image(
                type="pil",
                label="Generated Image",
                format="png",  # Explicitly set output format
                elem_id="gallery",
            )
    # Input order must match query()'s positional parameters:
    # (prompt, is_negative, steps, cfg_scale, sampler, seed, strength, width, height).
    generate_btn.click(
        fn=query,
        inputs=[text_prompt, negative_prompt, steps, cfg, method, seed, strength, width, height],
        outputs=output_image,
    )

# Start the server only when executed as a script (e.g. `python app.py`),
# not when the module is imported.
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860, share=True)