import gradio as gr
import requests
import io
import random
import os
import time
from PIL import Image
from deep_translator import GoogleTranslator

# Project by Nymbo

# Serverless Inference API endpoint for FLUX.1-dev; HF_READ_TOKEN is supplied
# as a Space secret / environment variable.
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
API_TOKEN = os.getenv("HF_READ_TOKEN")
headers = {"Authorization": f"Bearer {API_TOKEN}"}
timeout = 9000000  # request timeout in seconds (effectively unlimited)
def convert_to_png(image):
    """Convert any image format to true PNG format"""
    png_buffer = io.BytesIO()
    if image.mode == 'RGBA':
        # If image has alpha channel, save as PNG with transparency
        image.save(png_buffer, format='PNG', optimize=True)
    else:
        # Convert to RGB first if not in RGB/RGBA mode
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image.save(png_buffer, format='PNG', optimize=True)
    png_buffer.seek(0)
    return Image.open(png_buffer)
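# Note: because convert_to_png() re-opens the image from an in-memory PNG buffer,
# the returned PIL.Image reports format "PNG" regardless of the format of the bytes
# the API returned, and it can be saved or displayed directly.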
def query(prompt, is_negative=False, steps=20, cfg_scale=7, sampler="DPM++ 2M Karras",
          seed=-1, strength=0.7, width=1024, height=1024):
    # Note: the UI wires the negative-prompt textbox into `is_negative`, and the
    # selected sampler is accepted here but not forwarded in the payload below.
    if not prompt:
        return None

    key = random.randint(0, 999)
    API_TOKEN = random.choice([os.getenv("HF_READ_TOKEN")])  # single-token pool
    headers = {"Authorization": f"Bearer {API_TOKEN}"}
    payload = {
        "inputs": prompt,
        "is_negative": is_negative,
        "steps": steps,
        "cfg_scale": cfg_scale,
        "seed": seed if seed != -1 else random.randint(1, 1000000000),
        "strength": strength,
        "parameters": {"width": width, "height": height}
    }

    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        # Convert directly to PNG without intermediate format
        img = Image.open(io.BytesIO(response.content))
        png_img = convert_to_png(img)
        print(f'\033[1mGeneration {key} completed as PNG!\033[0m')
        return png_img
    except requests.exceptions.RequestException as e:
        print(f"API Error: {e}")
        # An error Response is falsy (Response.__bool__ returns .ok), so compare
        # against None explicitly instead of relying on truthiness.
        if e.response is not None:
            if e.response.status_code == 503:
                raise gr.Error("503: Model is loading, please try again later")
            raise gr.Error(f"{e.response.status_code}: {e.response.text}")
        raise gr.Error("Network error occurred")
    except Exception as e:
        print(f"Image processing error: {e}")
        raise gr.Error(f"Image processing failed: {str(e)}")
# Light theme CSS
css = """
#app-container {
    max-width: 800px;
    margin: 0 auto;
    padding: 20px;
    background: #ffffff;
}
#prompt-text-input, #negative-prompt-text-input {
    font-size: 14px;
    background: #f9f9f9;
}
#gallery {
    min-height: 512px;
    background: #ffffff;
    border: 1px solid #e0e0e0;
}
#gen-button {
    margin: 10px 0;
    background: #4CAF50;
    color: white;
}
.accordion {
    background: #f5f5f5;
    border: 1px solid #e0e0e0;
}
h1 {
    color: #333333;
}
"""
with gr.Blocks(theme=gr.themes.Default(primary_hue="green"), css=css) as app:
    gr.HTML("<center><h1>BSP Dev Work</h1></center>")
    with gr.Column(elem_id="app-container"):
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                with gr.Row():
                    text_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Prompt",
                        lines=2,
                        elem_id="prompt-text-input"
                    )
                with gr.Accordion("Advanced Settings", open=False):
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
                        lines=3,
                        elem_id="negative-prompt-text-input"  # matches the CSS selector above
                    )
                    with gr.Row():
                        width = gr.Slider(value=1024, label="Width", minimum=512, maximum=2048, step=64)
                        height = gr.Slider(value=1024, label="Height", minimum=512, maximum=2048, step=64)
                    with gr.Row():
                        steps = gr.Slider(value=4, label="Steps", minimum=4, maximum=100, step=1)
                        cfg = gr.Slider(value=7.0, label="CFG Scale", minimum=1.0, maximum=20.0, step=0.5)
                    with gr.Row():
                        strength = gr.Slider(value=0.7, label="Strength", minimum=0.1, maximum=1.0, step=0.01)
                        seed = gr.Number(value=-1, label="Seed (-1 for random)")
                    method = gr.Radio(
                        ["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler", "Euler a", "Heun", "DDIM"],
                        value="DPM++ 2M Karras",
                        label="Sampling Method"
                    )
        generate_btn = gr.Button("Generate Image", variant="primary", elem_id="gen-button")  # matches the #gen-button CSS rule
        with gr.Row():
            output_image = gr.Image(
                type="pil",
                label="Generated PNG Image",
                format="png",  # Explicitly set output format
                elem_id="gallery"
            )
        generate_btn.click(
            fn=query,
            inputs=[text_prompt, negative_prompt, steps, cfg, method, seed, strength, width, height],
            outputs=output_image
        )

app.launch(server_name="0.0.0.0", server_port=7860, share=True)
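# Deployment note (assumption, not part of the original file): on a Hugging Face
# Space this script expects gradio, requests, Pillow and deep-translator to be
# listed in requirements.txt, and HF_READ_TOKEN to be configured as a Space secret.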