import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
from transformers import pipeline, set_seed
from PIL import Image

# TTI: shared settings for text-to-image generation
class TTI:
    """Class-level settings for the Stable Diffusion text-to-image app."""

    # Model identifiers passed to the loaders.
    image_gen_model_id = "stabilityai/stable-diffusion-2"
    prompt_gen_model_id = "gpt2"

    # Use the GPU when torch can see one, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Seeded RNG handed to the pipeline; manual_seed() seeds it in place.
    seed = 42
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    # Default generation parameters (the UI may overwrite these).
    image_gen_steps = 35
    image_gen_size = (400, 400)  # (width, height) the output is resized to
    # NOTE: attribute name is misspelled ("guidence"); kept for compatibility.
    image_gen_guidence_scale = 9

# Load Stable Diffusion Model
@st.cache_resource
def load_image_gen_model():
    """Load the Stable Diffusion pipeline and move it to ``TTI.device``.

    Decorated with ``st.cache_resource`` so the pipeline object is created
    once and reused across Streamlit reruns instead of reloading weights on
    every interaction.

    Half precision (float16) is only requested on CUDA. On CPU the pipeline
    is loaded in float32, because many CPU kernels do not implement float16.

    Returns:
        The ``StableDiffusionPipeline`` placed on ``TTI.device``.
    """
    if TTI.device == "cuda":
        # `variant="fp16"` selects the repo's fp16 weight files; the older
        # `revision="fp16"` branch form is deprecated in diffusers.
        model = StableDiffusionPipeline.from_pretrained(
            TTI.image_gen_model_id,
            torch_dtype=torch.float16,
            variant="fp16",
        )
    else:
        model = StableDiffusionPipeline.from_pretrained(
            TTI.image_gen_model_id,
            torch_dtype=torch.float32,
        )
    return model.to(TTI.device)


image_gen_model = load_image_gen_model()

# Function to Generate Images
def generate_image(prompt, model, *, steps=None, guidance_scale=None, size=None):
    """Run the diffusion pipeline on *prompt* and return a resized image.

    Args:
        prompt: Text prompt passed to the pipeline.
        model: A loaded pipeline; called as ``model(prompt, ...)`` and
            expected to return an object with an ``images`` sequence.
        steps: Number of inference steps. Defaults to ``TTI.image_gen_steps``.
        guidance_scale: Guidance scale passed to the pipeline. Defaults to
            ``TTI.image_gen_guidence_scale``.
        size: ``(width, height)`` of the returned image. Defaults to
            ``TTI.image_gen_size``.

    Returns:
        The first generated image, resized to *size* with Lanczos resampling.
    """
    # Resolve defaults at call time (not in the signature) so that changes
    # made to TTI before the call are honored.
    if steps is None:
        steps = TTI.image_gen_steps
    if guidance_scale is None:
        guidance_scale = TTI.image_gen_guidence_scale
    if size is None:
        size = TTI.image_gen_size

    image = model(
        prompt,
        num_inference_steps=steps,
        generator=TTI.generator,
        guidance_scale=guidance_scale,
    ).images[0]
    # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS,
    # which is available in every Pillow version.
    return image.resize(tuple(size), Image.LANCZOS)

# Streamlit UI
st.title("Text-to-Image Generator")
st.write("Generate images from text prompts using Stable Diffusion.")

# User Input: Prompt
prompt = st.text_input("Enter a text prompt", value="A monkey on a tree")

# User Input: Inference Steps
image_gen_steps = st.slider(
    "Number of inference steps (Higher = Better quality but slower)",
    min_value=10,
    max_value=100,
    value=TTI.image_gen_steps,
    step=5
)

# User Input: Guidance Scale
guidance_scale = st.slider(
    "Guidance scale (Higher = Closer to prompt, but less creative)",
    min_value=1.0,
    max_value=20.0,
    # st.slider needs value/min/max of the same type, so convert the int default.
    value=float(TTI.image_gen_guidence_scale),
    step=0.5
)

# User Input: Image Size
image_width = st.number_input("Image Width", min_value=64, max_value=1024, value=TTI.image_gen_size[0], step=64)
image_height = st.number_input("Image Height", min_value=64, max_value=1024, value=TTI.image_gen_size[1], step=64)

# Generate Image Button
if st.button("Generate Image"):
    if not prompt.strip():
        # Don't spend a full (slow) diffusion run on a blank prompt.
        st.warning("Please enter a text prompt before generating an image.")
    else:
        # Push the widget values into TTI, which generate_image reads.
        TTI.image_gen_steps = image_gen_steps
        TTI.image_gen_guidence_scale = guidance_scale
        TTI.image_gen_size = (image_width, image_height)
        with st.spinner("Generating image..."):
            image = generate_image(prompt, image_gen_model)
            st.image(image, caption=f"Generated Image for Prompt: '{prompt}'", use_column_width=True)

st.write("Adjust parameters to customize the image generation!")