aromidvar commited on
Commit
9dc758e
·
verified ·
1 Parent(s): 2aaa2f3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import AutoFeatureExtractor, AutoModelForImageUpscaling
5
+
6
+ def load_model():
7
+ try:
8
+ feature_extractor = AutoFeatureExtractor.from_pretrained("keras-io/super-resolution")
9
+ model = AutoModelForImageUpscaling.from_pretrained("keras-io/super-resolution")
10
+ return feature_extractor, model
11
+ except Exception as e:
12
+ print(f"Model loading error: {e}")
13
+ return None, None
14
+
15
+ def super_resolve_image(input_image):
16
+ # Validate input
17
+ if input_image is None:
18
+ return None
19
+
20
+ # Load model (do this once, not in every function call for efficiency)
21
+ feature_extractor, model = load_model()
22
+
23
+ if model is None:
24
+ return "Error: Could not load model"
25
+
26
+ try:
27
+ # Convert to PIL Image if not already
28
+ if not isinstance(input_image, Image.Image):
29
+ input_image = Image.fromarray(input_image)
30
+
31
+ # Prepare image
32
+ inputs = feature_extractor(images=input_image, return_tensors="pt")
33
+
34
+ # Super-resolve
35
+ with torch.no_grad():
36
+ outputs = model(**inputs)
37
+ enhanced_image = feature_extractor.post_process_image(outputs.image)[0]
38
+
39
+ return enhanced_image
40
+
41
+ except Exception as e:
42
+ print(f"Super-resolution error: {e}")
43
+ return None
44
+
45
+ def create_gradio_interface():
46
+ with gr.Blocks() as demo:
47
+ gr.Markdown("# 🖼️ Simple Image Super-Resolution")
48
+
49
+ with gr.Row():
50
+ input_image = gr.Image(label="Input Image", type="pil")
51
+ output_image = gr.Image(label="Super-Resolved Image")
52
+
53
+ enhance_btn = gr.Button("Enhance Image Resolution")
54
+
55
+ enhance_btn.click(
56
+ fn=super_resolve_image,
57
+ inputs=input_image,
58
+ outputs=output_image
59
+ )
60
+
61
+ return demo
62
+
63
+ # Launch the interface
64
+ if __name__ == "__main__":
65
+ demo = create_gradio_interface()
66
+ demo.launch(
67
+ server_name="0.0.0.0",
68
+ server_port=7860
69
+ )