# Hugging Face Space (status: Running)
import logging
from io import BytesIO

import gradio as gr
import numpy as np
import requests
import torch
from diffusers import LDMSuperResolutionPipeline
from PIL import Image
# Device configuration: use the GPU when PyTorch reports CUDA is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face Hub repository the super-resolution pipeline is loaded from.
model_id = "CompVis/ldm-super-resolution-4x-openimages"
# Load the pretrained pipeline once, at import time, and move it to `device`.
# Any failure is reported and leaves `pipeline` as None, so the rest of the
# module (and the UI) still loads; callers must check for None.
pipeline = None
try:
    pipeline = LDMSuperResolutionPipeline.from_pretrained(model_id).to(device)
except Exception as e:
    print(f"Model loading error: {e}")
def super_resolve_image(input_image, num_steps=50): | |
if input_image is None or pipeline is None: | |
return None | |
try: | |
# Ensure input is PIL Image | |
if not isinstance(input_image, Image.Image): | |
input_image = Image.fromarray(input_image) | |
# Resize input to 128x128 if needed | |
input_image = input_image.resize((128, 128), Image.LANCZOS) | |
# Ensure image is RGB | |
input_image = input_image.convert("RGB") | |
# Run super resolution | |
upscaled_image = pipeline( | |
input_image, | |
num_inference_steps=num_steps, | |
eta=1 | |
).images[0] | |
return np.array(upscaled_image) | |
except Exception as e: | |
print(f"Super-resolution error: {e}") | |
return None | |
def create_gradio_interface():
    """Assemble the Gradio Blocks UI for the super-resolution demo.

    Layout, top to bottom: a title, a row with the input and output images,
    a row with the inference-steps slider, a button that runs
    ``super_resolve_image``, and a clickable example input.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    example_url = "https://user-images.githubusercontent.com/38061659/199705896-b48e17b8-b231-47cd-a270-4ffa5a93fa3e.png"

    with gr.Blocks() as app:
        gr.Markdown("# 🖼️ LDM Super Resolution")

        # Side by side: the upload (handed to the callback as a PIL image)
        # and the upscaled result.
        with gr.Row():
            source_img = gr.Image(label="Input Image", type="pil")
            result_img = gr.Image(label="Super-Resolved Image")

        with gr.Row():
            steps_slider = gr.Slider(
                label="Inference Steps",
                minimum=10,
                maximum=200,
                value=50,
            )

        run_button = gr.Button("Enhance Image Resolution")
        run_button.click(
            inputs=[source_img, steps_slider],
            outputs=result_img,
            fn=super_resolve_image,
        )

        # One-click sample input for quick testing.
        gr.Examples(inputs=source_img, examples=[example_url])

    return app
# Build and serve the UI only when this file is executed directly.
if __name__ == "__main__":
    demo = create_gradio_interface()
    # Bind to all network interfaces (not just localhost) on port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")