aromidvar commited on
Commit
1d87f91
·
verified ·
1 Parent(s): 5daba90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -32
app.py CHANGED
@@ -1,42 +1,41 @@
1
  import gradio as gr
2
- import torch
 
3
  from PIL import Image
4
- from transformers import AutoFeatureExtractor, AutoModelForImageUpscaling
5
 
6
- def load_model():
7
- try:
8
- feature_extractor = AutoFeatureExtractor.from_pretrained("keras-io/super-resolution")
9
- model = AutoModelForImageUpscaling.from_pretrained("keras-io/super-resolution")
10
- return feature_extractor, model
11
- except Exception as e:
12
- print(f"Model loading error: {e}")
13
- return None, None
14
-
15
- def super_resolve_image(input_image):
16
- # Validate input
17
  if input_image is None:
18
  return None
19
 
20
- # Load model (do this once, not in every function call for efficiency)
21
- feature_extractor, model = load_model()
22
-
23
- if model is None:
24
- return "Error: Could not load model"
25
-
26
  try:
27
- # Convert to PIL Image if not already
28
- if not isinstance(input_image, Image.Image):
29
- input_image = Image.fromarray(input_image)
30
 
31
- # Prepare image
32
- inputs = feature_extractor(images=input_image, return_tensors="pt")
 
33
 
34
- # Super-resolve
35
- with torch.no_grad():
36
- outputs = model(**inputs)
37
- enhanced_image = feature_extractor.post_process_image(outputs.image)[0]
 
 
38
 
39
- return enhanced_image
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  except Exception as e:
42
  print(f"Super-resolution error: {e}")
@@ -44,17 +43,24 @@ def super_resolve_image(input_image):
44
 
45
  def create_gradio_interface():
46
  with gr.Blocks() as demo:
47
- gr.Markdown("# 🖼️ Simple Image Super-Resolution")
48
 
49
  with gr.Row():
50
  input_image = gr.Image(label="Input Image", type="pil")
51
- output_image = gr.Image(label="Super-Resolved Image")
 
 
 
 
 
 
 
52
 
53
  enhance_btn = gr.Button("Enhance Image Resolution")
54
 
55
  enhance_btn.click(
56
  fn=super_resolve_image,
57
- inputs=input_image,
58
  outputs=output_image
59
  )
60
 
 
1
  import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
  from PIL import Image
 
5
 
6
+ def super_resolve_image(input_image, scale_factor=4):
 
 
 
 
 
 
 
 
 
 
7
  if input_image is None:
8
  return None
9
 
 
 
 
 
 
 
10
  try:
11
+ # Convert to numpy array if it's a PIL Image
12
+ if isinstance(input_image, Image.Image):
13
+ input_image = np.array(input_image)
14
 
15
+ # Convert to grayscale if needed
16
+ if len(input_image.shape) == 2:
17
+ input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB)
18
 
19
+ # Different interpolation methods
20
+ methods = [
21
+ cv2.INTER_LINEAR,
22
+ cv2.INTER_CUBIC,
23
+ cv2.INTER_LANCZOS4
24
+ ]
25
 
26
+ # Try different interpolation methods
27
+ for method in methods:
28
+ try:
29
+ upscaled = cv2.resize(
30
+ input_image,
31
+ (input_image.shape[1] * scale_factor, input_image.shape[0] * scale_factor),
32
+ interpolation=method
33
+ )
34
+ return upscaled
35
+ except Exception as e:
36
+ print(f"Interpolation method failed: {method}")
37
+
38
+ return None
39
 
40
  except Exception as e:
41
  print(f"Super-resolution error: {e}")
 
43
 
44
  def create_gradio_interface():
45
  with gr.Blocks() as demo:
46
+ gr.Markdown("# 🖼️ Simple Image Upscaling")
47
 
48
  with gr.Row():
49
  input_image = gr.Image(label="Input Image", type="pil")
50
+ output_image = gr.Image(label="Upscaled Image")
51
+
52
+ with gr.Row():
53
+ scale_dropdown = gr.Dropdown(
54
+ choices=[2, 4, 8],
55
+ value=4,
56
+ label="Upscale Factor"
57
+ )
58
 
59
  enhance_btn = gr.Button("Enhance Image Resolution")
60
 
61
  enhance_btn.click(
62
  fn=super_resolve_image,
63
+ inputs=[input_image, scale_dropdown],
64
  outputs=output_image
65
  )
66