Spaces:
Runtime error
Runtime error
import gradio as gr | |
import requests | |
from PIL import Image | |
import io | |
import os | |
from fal_client import submit | |
def set_fal_key(api_key):
    """Export *api_key* as the ``FAL_KEY`` environment variable.

    The variable is set for the whole process. Returns a confirmation
    message suitable for showing in the UI.
    """
    os.environ.update(FAL_KEY=api_key)
    return "FAL API key set successfully!"
def generate_image(api_key, model, prompt, image_size, num_inference_steps, guidance_scale, num_images, safety_tolerance, enable_safety_checker, seed):
    """Generate one or more images with a FLUX.1 model hosted on fal.ai.

    Args:
        api_key: fal.ai API key typed into the UI; exported via ``set_fal_key``.
        model: "Flux Pro" or "Flux Dev"; any other value is treated as
            "Flux Schnell".
        prompt: Text prompt for the model.
        image_size: fal image-size preset name (e.g. "landscape_4_3").
        num_inference_steps: Number of inference steps (slider value, cast to int).
        guidance_scale: Guidance scale; only sent for Pro and Dev.
        num_images: How many images to request (slider value, cast to int).
        safety_tolerance: Safety tolerance level; only sent for Pro.
        enable_safety_checker: Safety-checker flag; only sent for Dev and Schnell.
        seed: Seed value; -1 or None (cleared field) lets the service choose.

    Returns:
        A list of PIL images, one per image URL returned by fal.

    Raises:
        gr.Error: if the API key is empty, the fal request fails, no images
            are returned, or an image cannot be downloaded or decoded.
            Gradio displays the message to the user (previously failures were
            hidden behind a black placeholder image).
    """
    api_key = (api_key or "").strip()
    if not api_key:
        raise gr.Error("Please enter your FAL API key.")
    # NOTE(review): FAL_KEY is process-wide, so concurrent users of a shared
    # Space can race on it; a per-request client would avoid that.
    set_fal_key(api_key)

    arguments = {
        "prompt": prompt,
        "image_size": image_size,
        # Gradio sliders may deliver floats; these fields are integer counts.
        "num_inference_steps": int(num_inference_steps),
        "num_images": int(num_images),
    }
    if model == "Flux Pro":
        arguments["guidance_scale"] = guidance_scale
        arguments["safety_tolerance"] = safety_tolerance
        fal_model = "fal-ai/flux-pro"
    elif model == "Flux Dev":
        arguments["guidance_scale"] = guidance_scale
        arguments["enable_safety_checker"] = enable_safety_checker
        fal_model = "fal-ai/flux/dev"
    else:  # Flux Schnell
        arguments["enable_safety_checker"] = enable_safety_checker
        fal_model = "fal-ai/flux/schnell"

    # gr.Number can yield a float or None (cleared field); -1 means "random".
    if seed is not None and seed != -1:
        arguments["seed"] = int(seed)

    try:
        result = submit(fal_model, arguments=arguments).get()
    except Exception as exc:  # boundary with the remote service: surface any failure
        raise gr.Error(f"Image generation failed: {exc}") from exc

    image_infos = result.get("images") or []
    if not image_infos:
        raise gr.Error("The FAL API returned no images.")

    return [_download_image(info["url"]) for info in image_infos]


def _download_image(url, timeout=60):
    """Download *url* and return it as a fully decoded PIL image; raise gr.Error on failure."""
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; force decoding so corrupt data fails here, inside the try.
        image.load()
    except (requests.RequestException, OSError) as exc:
        raise gr.Error(f"Could not download generated image: {exc}") from exc
    return image
def update_visible_components(model):
    """Return Gradio updates that adapt the advanced settings to *model*.

    The four updates target, in order: inference steps, guidance scale,
    safety tolerance, and the safety-checker checkbox. Visible controls are
    also reset to per-model defaults. Any model other than "Flux Pro" or
    "Flux Dev" is treated as "Flux Schnell".
    """
    def shown(default):
        return gr.update(visible=True, value=default)

    def hidden():
        return gr.update(visible=False)

    if model == "Flux Pro":
        return [shown(28), shown(3.5), shown("2"), hidden()]
    if model == "Flux Dev":
        return [shown(28), shown(3.5), hidden(), shown(True)]
    # Flux Schnell: few steps, no guidance scale, no safety tolerance.
    return [shown(4), hidden(), hidden(), shown(True)]
# Build the Gradio UI: inputs and advanced settings on the left, results gallery
# on the right. Changing the model re-configures the advanced settings via
# update_visible_components.
with gr.Blocks(theme='Nymbo/Nymbo_Theme') as demo:
    gr.HTML("""
<h1 align="center">FLUX.1 Image Generation</h1>
<p align="center">
<a href="https://blackforestlabs.ai/" target="_blank">[Black Forest Labs]</a>
<a href="https://blackforestlabs.ai/announcing-black-forest-labs/" target="_blank">[Blog]</a>
<a href="https://fal.ai/models/fal-ai/flux-pro" target="_blank">[FLUX.1 [pro] Model FAL]</a>
<a href="https://fal.ai/dashboard/keys" target="_blank">[GET YOUR API KEY HERE]</a>
</p>
""")
    with gr.Row():
        with gr.Column(scale=1):
            api_key = gr.Textbox(type="password", label="FAL API Key")
            model = gr.Dropdown(
                label="Model",
                choices=["Flux Pro", "Flux Dev", "Flux Schnell"],
                value="Flux Pro"
            )
            prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Add your prompt here")
            image_size = gr.Dropdown(
                choices=["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"],
                label="Image Size",
                value="landscape_4_3"
            )
            with gr.Accordion("Advanced settings", open=False):
                num_inference_steps = gr.Slider(1, 100, 28, step=1, label="Number of Inference Steps")
                guidance_scale = gr.Slider(0, 20, 3.5, step=0.1, label="Guidance Scale")
                num_images = gr.Slider(1, 10, 1, step=1, label="Number of Images")
                safety_tolerance = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6"], label="Safety Tolerance", value="2")
                # Start hidden: the default model is Flux Pro, and
                # update_visible_components("Flux Pro") hides this checkbox, so the
                # initial page state now matches what a model change would produce.
                enable_safety_checker = gr.Checkbox(label="Enable Safety Checker", value=True, visible=False)
                # precision=0 makes Gradio return an int, so the seed is not sent as a float.
                seed = gr.Number(label="Seed", value=-1, precision=0)
            generate_btn = gr.Button("Generate Image")
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Generated Images", elem_id="gallery", show_label=False)

    # Output order must match the list returned by update_visible_components.
    model.change(update_visible_components, inputs=[model], outputs=[num_inference_steps, guidance_scale, safety_tolerance, enable_safety_checker])
    generate_btn.click(
        fn=generate_image,
        inputs=[
            api_key, model, prompt, image_size, num_inference_steps,
            guidance_scale, num_images, safety_tolerance, enable_safety_checker, seed
        ],
        outputs=[output_gallery]
    )
demo.launch()