import streamlit as st
import torch
from diffusers import DiffusionPipeline

# Load both base & refiner.
#
# Streamlit re-executes this script from the top on every widget interaction,
# so the pipelines are built inside an `st.cache_resource` function: they are
# loaded and moved to the GPU once per server process and reused afterwards,
# instead of being re-loaded on every rerun.
@st.cache_resource
def _load_pipelines():
    """Build the base and refiner pipelines on CUDA and return them.

    Returns:
        A ``(base, refiner)`` tuple of ``DiffusionPipeline`` objects. The
        refiner is constructed with the base pipeline's ``text_encoder_2``
        and ``vae`` so those components are shared rather than loaded twice.
    """
    base_pipe = DiffusionPipeline.from_pretrained(
        "ageraustine/stable-diffusion",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
    base_pipe.to("cuda")

    # NOTE(review): the refiner is loaded from the same checkpoint id as the
    # base pipeline -- confirm this is intended.
    refiner_pipe = DiffusionPipeline.from_pretrained(
        "ageraustine/stable-diffusion",
        text_encoder_2=base_pipe.text_encoder_2,
        vae=base_pipe.vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )
    refiner_pipe.to("cuda")
    return base_pipe, refiner_pipe


# Module-level names used by the rest of the script.
base, refiner = _load_pipelines()

# Sampling schedule shared by both pipelines.
# n_steps is passed to each pipeline as `num_inference_steps`.
n_steps: int = 40
# Split point between the two experts (80/20): passed to the base pipeline as
# `denoising_end` and to the refiner as `denoising_start`.
high_noise_frac: float = 0.8

# Streamlit app
st.title("Text-to-Image Generation App")

# Text input
user_text = st.text_input("Enter a text prompt for image generation")

# Generate image based on user input
if st.button("Generate Image"):
    # Strip before checking so a whitespace-only prompt is treated as empty
    # and gets the warning instead of triggering a full generation run.
    prompt = (user_text or "").strip()
    if prompt:
        # Run both experts: the base pipeline stops at `high_noise_frac` and
        # returns latents (output_type="latent") rather than a decoded image.
        latents = base(
            prompt=prompt,
            num_inference_steps=n_steps,
            denoising_end=high_noise_frac,
            output_type="latent",
        ).images
        # The refiner resumes from the same fraction, taking the base
        # latents as its input; the first returned image is displayed.
        image = refiner(
            prompt=prompt,
            num_inference_steps=n_steps,
            denoising_start=high_noise_frac,
            image=latents,
        ).images[0]

        # Display the generated image
        st.image(image, caption="Generated Image", use_column_width=True)
    else:
        st.warning("Please enter a text prompt for image generation.")