import base64
import os
from io import BytesIO

import gradio as gr
import requests
from openai import OpenAI
from PIL import Image

# Shared OpenAI-compatible client. The endpoint and credentials are taken from
# the BASE_URL and API_KEY environment variables; a missing variable raises
# KeyError at import time, so misconfiguration surfaces immediately.
client = OpenAI(
    base_url=os.environ["BASE_URL"],
    api_key=os.environ["API_KEY"],
)

# Seconds requests may wait on connect/read while downloading the generated
# image before raising requests.Timeout.
_DOWNLOAD_TIMEOUT_S = 60


def generate(prompt):
    """Generate an image for *prompt* with SDXL Lightning and return it as a PIL image.

    Uses the module-level OpenAI-compatible ``client``. The first entry of the
    response is used: it is downloaded when the service returns a ``url``, or
    decoded from ``b64_json`` when the service returns inline image data.

    Args:
        prompt: Text prompt entered by the user.

    Returns:
        PIL.Image.Image: The generated image.

    Raises:
        gr.Error: If the prompt is empty/whitespace, or the response carries no image.
        requests.HTTPError: If downloading the image URL returns an HTTP error status.
        requests.Timeout: If connecting to or reading from the image URL stalls
            longer than ``_DOWNLOAD_TIMEOUT_S`` seconds.
    """
    if not prompt or not prompt.strip():
        # Fail fast with a message shown in the UI instead of a remote API error.
        raise gr.Error("Please enter a prompt.")

    response = client.images.generate(
        model="sdxl-lightning-4step",
        prompt=prompt,
    )
    if not response.data:
        raise gr.Error("The image service returned no image data.")
    result = response.data[0]

    if result.url:
        # Without a timeout, requests can block this worker indefinitely.
        download = requests.get(result.url, timeout=_DOWNLOAD_TIMEOUT_S)
        # Report HTTP failures directly rather than letting PIL choke on an error page.
        download.raise_for_status()
        image_bytes = download.content
    elif result.b64_json:
        # Some OpenAI-compatible backends return inline base64 data instead of a URL.
        image_bytes = base64.b64decode(result.b64_json)
    else:
        raise gr.Error("The image service returned no image data.")

    return Image.open(BytesIO(image_bytes))

# Single-page UI: a prompt textbox with a Generate button, and an image output
# that shows whatever `generate` returns for the entered prompt.
with gr.Blocks() as demo:
    gr.Markdown("## SDXL Lightning Image Generator")
    with gr.Row(equal_height=True):
        promptbox = gr.Textbox(label="Prompt", placeholder="Enter your prompt")
        generatebtn = gr.Button(value="Generate", variant="primary")
    outputimg = gr.Image(width=1024, height=512)

    # Generation runs both from the button and from pressing Enter in the textbox.
    generatebtn.click(
        fn=generate,
        inputs=promptbox,
        outputs=outputimg,
    )
    promptbox.submit(
        fn=generate,
        inputs=promptbox,
        outputs=outputimg,
    )

# Launch only when run as a script, so importing this module (e.g. from tests
# or other tooling) does not start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()