aromidvar commited on
Commit
aa8b76a
·
verified ·
1 Parent(s): 93b81d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -27
app.py CHANGED
@@ -1,24 +1,9 @@
1
  import gradio as gr
2
  import numpy as np
3
  from PIL import Image
4
- import torch
5
- from basicsr.archs.rrdbnet_arch import RRDBNet
6
- from realesrgan import RealESRGANer
7
 
8
- def load_model():
9
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
10
- upsampler = RealESRGANer(
11
- scale=4,
12
- model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
13
- model=model,
14
- tile=0,
15
- tile_pad=10,
16
- pre_pad=0,
17
- half=False
18
- )
19
- return upsampler
20
-
21
- def super_resolve_image(input_image):
22
  if input_image is None:
23
  return None
24
 
@@ -27,17 +12,30 @@ def super_resolve_image(input_image):
27
  if isinstance(input_image, Image.Image):
28
  input_image = np.array(input_image)
29
 
30
- # Ensure the image is in RGB format
31
  if len(input_image.shape) == 2:
32
- input_image = np.stack([input_image]*3, axis=-1)
33
 
34
- # Load model
35
- upsampler = load_model()
 
 
 
 
36
 
37
- # Super-resolve
38
- output, _ = upsampler.enhance(input_image)
 
 
 
 
 
 
 
 
 
39
 
40
- return output
41
 
42
  except Exception as e:
43
  print(f"Super-resolution error: {e}")
@@ -45,17 +43,24 @@ def super_resolve_image(input_image):
45
 
46
  def create_gradio_interface():
47
  with gr.Blocks() as demo:
48
- gr.Markdown("# 🖼️ Image Super-Resolution with Real-ESRGAN")
49
 
50
  with gr.Row():
51
  input_image = gr.Image(label="Input Image", type="pil")
52
- output_image = gr.Image(label="Super-Resolved Image")
 
 
 
 
 
 
 
53
 
54
  enhance_btn = gr.Button("Enhance Image Resolution")
55
 
56
  enhance_btn.click(
57
  fn=super_resolve_image,
58
- inputs=input_image,
59
  outputs=output_image
60
  )
61
 
 
1
  import gradio as gr
2
  import numpy as np
3
  from PIL import Image
4
+ import cv2
 
 
5
 
6
+ def super_resolve_image(input_image, scale_factor=4):
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  if input_image is None:
8
  return None
9
 
 
12
  if isinstance(input_image, Image.Image):
13
  input_image = np.array(input_image)
14
 
15
+ # Ensure image is in RGB
16
  if len(input_image.shape) == 2:
17
+ input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB)
18
 
19
+ # Upscaling methods
20
+ upscaling_methods = [
21
+ cv2.INTER_LINEAR, # Bilinear interpolation
22
+ cv2.INTER_CUBIC, # Bicubic interpolation
23
+ cv2.INTER_LANCZOS4 # Lanczos interpolation
24
+ ]
25
 
26
+ # Try different interpolation methods
27
+ for method in upscaling_methods:
28
+ try:
29
+ upscaled = cv2.resize(
30
+ input_image,
31
+ (input_image.shape[1] * scale_factor, input_image.shape[0] * scale_factor),
32
+ interpolation=method
33
+ )
34
+ return upscaled
35
+ except Exception as e:
36
+ print(f"Interpolation method failed: {method}")
37
 
38
+ return None
39
 
40
  except Exception as e:
41
  print(f"Super-resolution error: {e}")
 
43
 
44
  def create_gradio_interface():
45
  with gr.Blocks() as demo:
46
+ gr.Markdown("# 🖼️ Simple Image Upscaling")
47
 
48
  with gr.Row():
49
  input_image = gr.Image(label="Input Image", type="pil")
50
+ output_image = gr.Image(label="Upscaled Image")
51
+
52
+ with gr.Row():
53
+ scale_dropdown = gr.Dropdown(
54
+ choices=[2, 4, 8],
55
+ value=4,
56
+ label="Upscale Factor"
57
+ )
58
 
59
  enhance_btn = gr.Button("Enhance Image Resolution")
60
 
61
  enhance_btn.click(
62
  fn=super_resolve_image,
63
+ inputs=[input_image, scale_dropdown],
64
  outputs=output_image
65
  )
66