aromidvar's picture
Update app.py
945b29b verified
raw
history blame contribute delete
2.64 kB
import gradio as gr
import requests
from PIL import Image
from io import BytesIO
from diffusers import LDMSuperResolutionPipeline
import torch
import numpy as np
# Pick the compute device: use CUDA when torch reports it, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Identifier of the pretrained checkpoint handed to from_pretrained().
model_id = "CompVis/ldm-super-resolution-4x-openimages"

# Load the model and move it onto the chosen device. Any failure is reported
# on stdout and leaves `pipeline` as None, which the request handler checks
# before doing any work.
pipeline = None
try:
    pipeline = LDMSuperResolutionPipeline.from_pretrained(model_id).to(device)
except Exception as load_error:
    print(f"Model loading error: {load_error}")
def super_resolve_image(input_image, num_steps=50):
    """Run the module-level super-resolution pipeline on one image.

    The input is converted to RGB, scaled to a fixed 128x128 (aspect ratio
    is not preserved), and passed to ``pipeline``.

    Args:
        input_image: A PIL image, or an array accepted by
            ``Image.fromarray``. ``None`` means no image was supplied.
        num_steps: Number of inference steps. Coerced to ``int`` because a
            UI slider may deliver the value as a float.

    Returns:
        The first image produced by the pipeline as a NumPy array, or
        ``None`` when there is no input, the model failed to load, or any
        step of the processing raised.
    """
    if input_image is None or pipeline is None:
        return None

    try:
        # Ensure input is a PIL Image
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)

        # Convert to RGB *before* resizing: Pillow ignores the requested
        # filter and uses NEAREST for "1" and "P" mode images, so LANCZOS
        # only takes effect once the image is in a full-colour mode.
        input_image = input_image.convert("RGB")

        # Always scale the input to 128x128.
        input_image = input_image.resize((128, 128), Image.LANCZOS)

        # Run super resolution
        upscaled_image = pipeline(
            input_image,
            num_inference_steps=int(num_steps),
            eta=1,
        ).images[0]
        return np.array(upscaled_image)
    except Exception as e:
        # Best-effort handler: report the failure and return no image
        # rather than propagating the exception to the caller.
        print(f"Super-resolution error: {e}")
        return None
def create_gradio_interface():
    """Build the Blocks UI for the super-resolution demo and return it.

    The page contains a title, a row holding the input and output image
    components, a row holding the inference-step slider, a button wired to
    ``super_resolve_image``, and a single example image for the input.

    Returns:
        The ``gr.Blocks`` object; the caller is responsible for launching it.
    """
    # Example images for quick testing
    example_urls = [
        "https://user-images.githubusercontent.com/38061659/199705896-b48e17b8-b231-47cd-a270-4ffa5a93fa3e.png"
    ]

    with gr.Blocks() as demo:
        gr.Markdown("# 🖼️ LDM Super Resolution")

        with gr.Row():
            source_pane = gr.Image(label="Input Image", type="pil")
            result_pane = gr.Image(label="Super-Resolved Image")

        with gr.Row():
            steps_slider = gr.Slider(
                label="Inference Steps",
                minimum=10,
                maximum=200,
                value=50,
            )

        # NOTE(review): the button and examples are placed below the slider
        # row rather than inside it — confirm this matches the intended layout.
        run_button = gr.Button("Enhance Image Resolution")
        run_button.click(
            fn=super_resolve_image,
            inputs=[source_pane, steps_slider],
            outputs=result_pane,
        )

        gr.Examples(examples=example_urls, inputs=source_pane)

    return demo
# Script entry point: build the UI and start serving it.
if __name__ == "__main__":
    # "0.0.0.0" binds the server to every network interface; 7860 is the
    # port it listens on.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo = create_gradio_interface()
    demo.launch(**launch_options)