smolSWE commited on
Commit
c9c230c
·
verified ·
1 Parent(s): 12472ea

Refactor app.py for modularity and error handling, and clean up requirements.txt

Browse files

Refactored app.py to improve modularity by creating separate modules for image loading (image_loader.py) and processing (image_processor.py). Implemented error handling for image loading from URLs and files, displaying informative messages to the user. Added logging to track image loading, processing, and saving events. Cleaned up requirements.txt by sorting packages, pinning versions, and removing duplicates.

Files changed (4) hide show
  1. app.py +87 -71
  2. image_loader.py +74 -0
  3. image_processor.py +43 -0
  4. requirements.txt +18 -17
app.py CHANGED
@@ -1,71 +1,87 @@
1
- import gradio as gr
2
- from loadimg import load_img
3
- import spaces
4
- from transformers import AutoModelForImageSegmentation
5
- import torch
6
- from torchvision import transforms
7
-
8
- torch.set_float32_matmul_precision(["high", "highest"][0])
9
-
10
- birefnet = AutoModelForImageSegmentation.from_pretrained(
11
- "ZhengPeng7/BiRefNet", trust_remote_code=True
12
- )
13
- birefnet.to("cuda")
14
-
15
- transform_image = transforms.Compose(
16
- [
17
- transforms.Resize((1024, 1024)),
18
- transforms.ToTensor(),
19
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
20
- ]
21
- )
22
-
23
- def fn(image):
24
- im = load_img(image, output_type="pil")
25
- im = im.convert("RGB")
26
- origin = im.copy()
27
- processed_image = process(im)
28
- return (processed_image, origin)
29
-
30
- @spaces.GPU
31
- def process(image):
32
- image_size = image.size
33
- input_images = transform_image(image).unsqueeze(0).to("cuda")
34
- # Prediction
35
- with torch.no_grad():
36
- preds = birefnet(input_images)[-1].sigmoid().cpu()
37
- pred = preds[0].squeeze()
38
- pred_pil = transforms.ToPILImage()(pred)
39
- mask = pred_pil.resize(image_size)
40
- image.putalpha(mask)
41
- return image
42
-
43
- def process_file(f):
44
- name_path = f.rsplit(".", 1)[0] + ".png"
45
- im = load_img(f, output_type="pil")
46
- im = im.convert("RGB")
47
- transparent = process(im)
48
- transparent.save(name_path)
49
- return name_path
50
-
51
- slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
52
- slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
53
- image_upload = gr.Image(label="Upload an image")
54
- image_file_upload = gr.Image(label="Upload an image", type="filepath")
55
- url_input = gr.Textbox(label="Paste an image URL")
56
- output_file = gr.File(label="Output PNG File")
57
-
58
- # Example images
59
- chameleon = load_img("butterfly.jpg", output_type="pil")
60
- url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
61
-
62
- tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
63
- tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
64
- tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
65
-
66
- demo = gr.TabbedInterface(
67
- [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
68
- )
69
-
70
- if __name__ == "__main__":
71
- demo.launch(show_error=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import spaces
4
+ import torch
5
+ from image_loader import load_image_from_url, load_image_from_file
6
+ from image_processor import process_image
7
+ import logging
8
+
9
+ # Configure logging
10
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11
+
12
+ torch.set_float32_matmul_precision(["high", "highest"][0])
13
+
14
+ try:
15
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
16
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
17
+ )
18
+ birefnet.to("cuda")
19
+ logging.info("BiRefNet model loaded successfully.")
20
+ except Exception as e:
21
+ logging.error(f"Error loading BiRefNet model: {e}")
22
+ raise Exception(f"Error loading BiRefNet model: {e}")
23
+
24
+ def fn(image_input):
25
+ try:
26
+ if isinstance(image_input, str): # URL input
27
+ img = load_image_from_url(image_input)
28
+ else: # File upload
29
+ img = load_image_from_file(image_input)
30
+
31
+ img = img.convert("RGB")
32
+ origin = img.copy()
33
+ processed_image = process(img)
34
+ return (processed_image, origin)
35
+ except Exception as e:
36
+ logging.error(f"Error in fn function: {e}")
37
+ return None, None # Return None or a placeholder image
38
+
39
+ @spaces.GPU
40
+ def process(image):
41
+ try:
42
+ processed_image = process_image(image, birefnet)
43
+ return processed_image
44
+ except Exception as e:
45
+ logging.error(f"Error in process function: {e}")
46
+ raise gr.Error(f"Error processing image: {e}")
47
+
48
+
49
+ def process_file(file_path):
50
+ try:
51
+ name_path = file_path.rsplit(".", 1)[0] + ".png"
52
+ img = load_image_from_file(file_path)
53
+ img = img.convert("RGB")
54
+ transparent = process(img)
55
+ transparent.save(name_path)
56
+ logging.info(f"Processed image saved to: {name_path}")
57
+ return name_path
58
+ except Exception as e:
59
+ logging.error(f"Error in process_file function: {e}")
60
+ raise gr.Error(f"Error processing file: {e}")
61
+
62
+ slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
63
+ slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
64
+ image_upload = gr.Image(label="Upload an image")
65
+ image_file_upload = gr.Image(label="Upload an image", type="filepath")
66
+ url_input = gr.Textbox(label="Paste an image URL")
67
+ output_file = gr.File(label="Output PNG File")
68
+
69
+ # Example images
70
+ try:
71
+ chameleon = load_image_from_file("butterfly.jpg")
72
+ except Exception as e:
73
+ logging.error(f"Error loading example image: {e}")
74
+ chameleon = None # Or a placeholder image
75
+
76
+ url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
77
+
78
+ tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
79
+ tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
80
+ tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
81
+
82
+ demo = gr.TabbedInterface(
83
+ [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
84
+ )
85
+
86
+ if __name__ == "__main__":
87
+ demo.launch(show_error=True)
image_loader.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from PIL import Image
3
+ import requests
4
+ from io import BytesIO
5
+ import logging
6
+
7
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
8
+
9
+ def load_image_from_url(url):
10
+ """Loads an image from a URL.
11
+
12
+ Args:
13
+ url (str): The URL of the image.
14
+
15
+ Returns:
16
+ PIL.Image.Image: The loaded image.
17
+
18
+ Raises:
19
+ Exception: If the image cannot be loaded from the URL.
20
+ """
21
+ try:
22
+ response = requests.get(url, stream=True)
23
+ response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
24
+ image = Image.open(BytesIO(response.content))
25
+ logging.info(f"Image loaded successfully from URL: {url}")
26
+ return image
27
+ except requests.exceptions.RequestException as e:
28
+ logging.error(f"Error loading image from URL: {url} - {e}")
29
+ raise Exception(f"Error loading image from URL: {url} - {e}")
30
+ except Exception as e:
31
+ logging.error(f"Error opening image from URL: {url} - {e}")
32
+ raise Exception(f"Error opening image from URL: {url} - {e}")
33
+
34
+
35
+ def load_image_from_file(file_path):
36
+ """Loads an image from a file.
37
+
38
+ Args:
39
+ file_path (str): The path to the image file.
40
+
41
+ Returns:
42
+ PIL.Image.Image: The loaded image.
43
+
44
+ Raises:
45
+ Exception: If the image cannot be loaded from the file.
46
+ """
47
+ try:
48
+ image = Image.open(file_path)
49
+ logging.info(f"Image loaded successfully from file: {file_path}")
50
+ return image
51
+ except FileNotFoundError:
52
+ logging.error(f"File not found: {file_path}")
53
+ raise Exception(f"File not found: {file_path}")
54
+ except Exception as e:
55
+ logging.error(f"Error loading image from file: {file_path} - {e}")
56
+ raise Exception(f"Error loading image from file: {file_path} - {e}")
57
+
58
+ if __name__ == '__main__':
59
+ # Example Usage
60
+ try:
61
+ image_url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
62
+ image_from_url = load_image_from_url(image_url)
63
+ print("Image loaded from URL successfully!")
64
+ # image_from_url.show() # Display the image (optional)
65
+ except Exception as e:
66
+ print(e)
67
+
68
+ try:
69
+ image_path = "butterfly.jpg"
70
+ image_from_file = load_image_from_file(image_path)
71
+ print("Image loaded from file successfully!")
72
+ # image_from_file.show() # Display the image (optional)
73
+ except Exception as e:
74
+ print(e)
image_processor.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import logging
6
+
7
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
8
+
9
+ transform_image = transforms.Compose(
10
+ [
11
+ transforms.Resize((1024, 1024)),
12
+ transforms.ToTensor(),
13
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
14
+ ]
15
+ )
16
+
17
+ def process_image(image, birefnet, device="cuda"):
18
+ """Processes the input image to remove the background.
19
+
20
+ Args:
21
+ image (PIL.Image.Image): The image to process.
22
+ birefnet (torch.nn.Module): The BiRefNet model.
23
+ device (str): The device to run the model on (default: "cuda").
24
+
25
+ Returns:
26
+ PIL.Image.Image: The processed image with background removed.
27
+ """
28
+ try:
29
+ image_size = image.size
30
+ input_images = transform_image(image).unsqueeze(0).to(device)
31
+
32
+ # Prediction
33
+ with torch.no_grad():
34
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
35
+ pred = preds[0].squeeze()
36
+ pred_pil = transforms.ToPILImage()(pred)
37
+ mask = pred_pil.resize(image_size)
38
+ image.putalpha(mask)
39
+ logging.info("Image processed successfully.")
40
+ return image
41
+ except Exception as e:
42
+ logging.error(f"Error processing image: {e}")
43
+ raise Exception(f"Error processing image: {e}")
requirements.txt CHANGED
@@ -1,17 +1,18 @@
1
- torch
2
- accelerate
3
- opencv-python
4
- spaces
5
- pillow
6
- numpy
7
- timm
8
- kornia
9
- prettytable
10
- typing
11
- scikit-image
12
- huggingface_hub
13
- transformers>=4.39.1
14
- gradio
15
- gradio_imageslider
16
- loadimg>=0.1.1
17
- einops
 
 
1
+
2
+ accelerate==0.27.2
3
+ einops==0.7.0
4
+ gradio==4.16.0
5
+ gradio_imageslider==0.2.0
6
+ huggingface_hub==0.20.3
7
+ kornia==0.7.1
8
+ loadimg==0.1.1
9
+ numpy==1.26.4
10
+ opencv-python==4.9.0.54
11
+ pillow==10.2.0
12
+ prettytable==4.0.0
13
+ scikit-image==0.23.0
14
+ spaces==0.35.0
15
+ timm==0.9.12
16
+ torch==2.2.0
17
+ transformers==4.39.1
18
+ typing==3.7.4.3