File size: 2,642 Bytes
9dc758e
945b29b
9dc758e
945b29b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dc758e
945b29b
 
9dc758e
 
 
945b29b
 
 
9dc758e
945b29b
 
9dc758e
945b29b
 
9dc758e
945b29b
 
 
 
 
 
1d87f91
945b29b
9dc758e
 
 
 
 
 
 
945b29b
9dc758e
 
 
945b29b
aa8b76a
 
945b29b
 
 
 
 
aa8b76a
9dc758e
 
 
 
 
945b29b
9dc758e
 
945b29b
 
 
 
 
 
 
 
9dc758e
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr  
import requests  
from PIL import Image  
from io import BytesIO  
from diffusers import LDMSuperResolutionPipeline  
import torch  
import numpy as np  

# Model checkpoint and target device (GPU when CUDA is available, else CPU)
model_id = "CompVis/ldm-super-resolution-4x-openimages"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pipeline once at import time. Any failure is reported and leaves
# `pipeline` as None so that callers can detect the missing model.
try:
    pipeline = LDMSuperResolutionPipeline.from_pretrained(model_id).to(device)
except Exception as e:
    print(f"Model loading error: {e}")
    pipeline = None

def super_resolve_image(input_image, num_steps=50):
    """Run the module-level super-resolution pipeline on a single image.

    The input is converted to RGB, resized to 128x128 (the aspect ratio is
    not preserved) and handed to ``pipeline``.

    Args:
        input_image: A PIL image, or array data accepted by
            ``Image.fromarray``. ``None`` yields ``None``.
        num_steps: Number of inference steps. UI sliders may deliver this
            as a float, so it is coerced to an integer of at least 1.

    Returns:
        The upscaled image as a NumPy array, or ``None`` when there is no
        input, the model failed to load, or processing raised an error.
    """
    if input_image is None or pipeline is None:
        return None

    try:
        # Ensure input is a PIL Image
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)

        # Convert to RGB *before* resizing: Pillow falls back to
        # nearest-neighbour resampling for "1" and "P" mode images, so
        # resizing first would silently bypass the LANCZOS filter.
        input_image = input_image.convert("RGB")

        # Always resize to the fixed 128x128 input used by this app
        input_image = input_image.resize((128, 128), Image.LANCZOS)

        # The step count must be a positive integer
        steps = max(1, int(num_steps))

        # Run super resolution
        upscaled_image = pipeline(
            input_image,
            num_inference_steps=steps,
            eta=1,
        ).images[0]

        return np.array(upscaled_image)

    except Exception as e:
        # Best-effort: report the failure and show no output in the UI
        print(f"Super-resolution error: {e}")
        return None

def create_gradio_interface():
    """Build the Gradio Blocks UI for the super-resolution demo.

    The layout has an input/output image pair, an inference-steps slider,
    a button that calls ``super_resolve_image``, and one example image.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# 🖼️ LDM Super Resolution")

        with gr.Row():
            input_image = gr.Image(label="Input Image", type="pil")
            output_image = gr.Image(label="Super-Resolved Image")

        with gr.Row():
            num_steps = gr.Slider(
                minimum=10,
                maximum=200,
                value=50,
                # Pin whole-number increments explicitly instead of relying
                # on the step Gradio derives from the min/max range.
                step=1,
                label="Inference Steps",
            )

        enhance_btn = gr.Button("Enhance Image Resolution")

        # Inputs are passed positionally: (input_image, num_steps)
        enhance_btn.click(
            fn=super_resolve_image,
            inputs=[input_image, num_steps],
            outputs=output_image,
        )

        # Example images for quick testing
        gr.Examples(
            examples=[
                "https://user-images.githubusercontent.com/38061659/199705896-b48e17b8-b231-47cd-a270-4ffa5a93fa3e.png"
            ],
            inputs=input_image,
        )

    return demo

# Script entry point: build the UI and serve it
if __name__ == "__main__":
    demo = create_gradio_interface()
    # Listen on all network interfaces, port 7860
    demo.launch(server_name="0.0.0.0", server_port=7860)