Spaces:
Running
on
Zero
Running
on
Zero
Refactor app.py for modularity and error handling, and clean up requirements.txt
Browse filesRefactored app.py to improve modularity by creating separate modules for image loading (image_loader.py) and processing (image_processor.py). Implemented error handling for image loading from URLs and files, displaying informative messages to the user. Added logging to track image loading, processing, and saving events. Cleaned up requirements.txt by sorting packages, pinning versions, and removing duplicates.
- app.py +87 -71
- image_loader.py +74 -0
- image_processor.py +43 -0
- requirements.txt +18 -17
app.py
CHANGED
@@ -1,71 +1,87 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
import spaces
|
4 |
-
|
5 |
-
import
|
6 |
-
from
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
)
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
)
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
import spaces
|
4 |
+
import torch
|
5 |
+
from image_loader import load_image_from_url, load_image_from_file
|
6 |
+
from image_processor import process_image
|
7 |
+
import logging
|
8 |
+
|
9 |
+
# Configure logging
|
10 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
11 |
+
|
12 |
+
torch.set_float32_matmul_precision(["high", "highest"][0])
|
13 |
+
|
14 |
+
try:
|
15 |
+
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
16 |
+
"ZhengPeng7/BiRefNet", trust_remote_code=True
|
17 |
+
)
|
18 |
+
birefnet.to("cuda")
|
19 |
+
logging.info("BiRefNet model loaded successfully.")
|
20 |
+
except Exception as e:
|
21 |
+
logging.error(f"Error loading BiRefNet model: {e}")
|
22 |
+
raise Exception(f"Error loading BiRefNet model: {e}")
|
23 |
+
|
24 |
+
def fn(image_input):
|
25 |
+
try:
|
26 |
+
if isinstance(image_input, str): # URL input
|
27 |
+
img = load_image_from_url(image_input)
|
28 |
+
else: # File upload
|
29 |
+
img = load_image_from_file(image_input)
|
30 |
+
|
31 |
+
img = img.convert("RGB")
|
32 |
+
origin = img.copy()
|
33 |
+
processed_image = process(img)
|
34 |
+
return (processed_image, origin)
|
35 |
+
except Exception as e:
|
36 |
+
logging.error(f"Error in fn function: {e}")
|
37 |
+
return None, None # Return None or a placeholder image
|
38 |
+
|
39 |
+
@spaces.GPU
|
40 |
+
def process(image):
|
41 |
+
try:
|
42 |
+
processed_image = process_image(image, birefnet)
|
43 |
+
return processed_image
|
44 |
+
except Exception as e:
|
45 |
+
logging.error(f"Error in process function: {e}")
|
46 |
+
raise gr.Error(f"Error processing image: {e}")
|
47 |
+
|
48 |
+
|
49 |
+
def process_file(file_path):
|
50 |
+
try:
|
51 |
+
name_path = file_path.rsplit(".", 1)[0] + ".png"
|
52 |
+
img = load_image_from_file(file_path)
|
53 |
+
img = img.convert("RGB")
|
54 |
+
transparent = process(img)
|
55 |
+
transparent.save(name_path)
|
56 |
+
logging.info(f"Processed image saved to: {name_path}")
|
57 |
+
return name_path
|
58 |
+
except Exception as e:
|
59 |
+
logging.error(f"Error in process_file function: {e}")
|
60 |
+
raise gr.Error(f"Error processing file: {e}")
|
61 |
+
|
62 |
+
slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
|
63 |
+
slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
|
64 |
+
image_upload = gr.Image(label="Upload an image")
|
65 |
+
image_file_upload = gr.Image(label="Upload an image", type="filepath")
|
66 |
+
url_input = gr.Textbox(label="Paste an image URL")
|
67 |
+
output_file = gr.File(label="Output PNG File")
|
68 |
+
|
69 |
+
# Example images
|
70 |
+
try:
|
71 |
+
chameleon = load_image_from_file("butterfly.jpg")
|
72 |
+
except Exception as e:
|
73 |
+
logging.error(f"Error loading example image: {e}")
|
74 |
+
chameleon = None # Or a placeholder image
|
75 |
+
|
76 |
+
url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
|
77 |
+
|
78 |
+
tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
|
79 |
+
tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
|
80 |
+
tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
|
81 |
+
|
82 |
+
demo = gr.TabbedInterface(
|
83 |
+
[tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
|
84 |
+
)
|
85 |
+
|
86 |
+
if __name__ == "__main__":
|
87 |
+
demo.launch(show_error=True)
|
image_loader.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from PIL import Image
|
3 |
+
import requests
|
4 |
+
from io import BytesIO
|
5 |
+
import logging
|
6 |
+
|
7 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
8 |
+
|
9 |
+
def load_image_from_url(url):
|
10 |
+
"""Loads an image from a URL.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
url (str): The URL of the image.
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
PIL.Image.Image: The loaded image.
|
17 |
+
|
18 |
+
Raises:
|
19 |
+
Exception: If the image cannot be loaded from the URL.
|
20 |
+
"""
|
21 |
+
try:
|
22 |
+
response = requests.get(url, stream=True)
|
23 |
+
response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
|
24 |
+
image = Image.open(BytesIO(response.content))
|
25 |
+
logging.info(f"Image loaded successfully from URL: {url}")
|
26 |
+
return image
|
27 |
+
except requests.exceptions.RequestException as e:
|
28 |
+
logging.error(f"Error loading image from URL: {url} - {e}")
|
29 |
+
raise Exception(f"Error loading image from URL: {url} - {e}")
|
30 |
+
except Exception as e:
|
31 |
+
logging.error(f"Error opening image from URL: {url} - {e}")
|
32 |
+
raise Exception(f"Error opening image from URL: {url} - {e}")
|
33 |
+
|
34 |
+
|
35 |
+
def load_image_from_file(file_path):
|
36 |
+
"""Loads an image from a file.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
file_path (str): The path to the image file.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
PIL.Image.Image: The loaded image.
|
43 |
+
|
44 |
+
Raises:
|
45 |
+
Exception: If the image cannot be loaded from the file.
|
46 |
+
"""
|
47 |
+
try:
|
48 |
+
image = Image.open(file_path)
|
49 |
+
logging.info(f"Image loaded successfully from file: {file_path}")
|
50 |
+
return image
|
51 |
+
except FileNotFoundError:
|
52 |
+
logging.error(f"File not found: {file_path}")
|
53 |
+
raise Exception(f"File not found: {file_path}")
|
54 |
+
except Exception as e:
|
55 |
+
logging.error(f"Error loading image from file: {file_path} - {e}")
|
56 |
+
raise Exception(f"Error loading image from file: {file_path} - {e}")
|
57 |
+
|
58 |
+
if __name__ == '__main__':
|
59 |
+
# Example Usage
|
60 |
+
try:
|
61 |
+
image_url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
|
62 |
+
image_from_url = load_image_from_url(image_url)
|
63 |
+
print("Image loaded from URL successfully!")
|
64 |
+
# image_from_url.show() # Display the image (optional)
|
65 |
+
except Exception as e:
|
66 |
+
print(e)
|
67 |
+
|
68 |
+
try:
|
69 |
+
image_path = "butterfly.jpg"
|
70 |
+
image_from_file = load_image_from_file(image_path)
|
71 |
+
print("Image loaded from file successfully!")
|
72 |
+
# image_from_file.show() # Display the image (optional)
|
73 |
+
except Exception as e:
|
74 |
+
print(e)
|
image_processor.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
from torchvision import transforms
|
4 |
+
from PIL import Image
|
5 |
+
import logging
|
6 |
+
|
7 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
8 |
+
|
9 |
+
transform_image = transforms.Compose(
|
10 |
+
[
|
11 |
+
transforms.Resize((1024, 1024)),
|
12 |
+
transforms.ToTensor(),
|
13 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
14 |
+
]
|
15 |
+
)
|
16 |
+
|
17 |
+
def process_image(image, birefnet, device="cuda"):
|
18 |
+
"""Processes the input image to remove the background.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
image (PIL.Image.Image): The image to process.
|
22 |
+
birefnet (torch.nn.Module): The BiRefNet model.
|
23 |
+
device (str): The device to run the model on (default: "cuda").
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
PIL.Image.Image: The processed image with background removed.
|
27 |
+
"""
|
28 |
+
try:
|
29 |
+
image_size = image.size
|
30 |
+
input_images = transform_image(image).unsqueeze(0).to(device)
|
31 |
+
|
32 |
+
# Prediction
|
33 |
+
with torch.no_grad():
|
34 |
+
preds = birefnet(input_images)[-1].sigmoid().cpu()
|
35 |
+
pred = preds[0].squeeze()
|
36 |
+
pred_pil = transforms.ToPILImage()(pred)
|
37 |
+
mask = pred_pil.resize(image_size)
|
38 |
+
image.putalpha(mask)
|
39 |
+
logging.info("Image processed successfully.")
|
40 |
+
return image
|
41 |
+
except Exception as e:
|
42 |
+
logging.error(f"Error processing image: {e}")
|
43 |
+
raise Exception(f"Error processing image: {e}")
|
requirements.txt
CHANGED
@@ -1,17 +1,18 @@
|
|
1 |
-
|
2 |
-
accelerate
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
1 |
+
|
2 |
+
accelerate==0.27.2
|
3 |
+
einops==0.7.0
|
4 |
+
gradio==4.16.0
|
5 |
+
gradio_imageslider==0.2.0
|
6 |
+
huggingface_hub==0.20.3
|
7 |
+
kornia==0.7.1
|
8 |
+
loadimg==0.1.1
|
9 |
+
numpy==1.26.4
|
10 |
+
opencv-python==4.9.0.54
|
11 |
+
pillow==10.2.0
|
12 |
+
prettytable==4.0.0
|
13 |
+
scikit-image==0.23.0
|
14 |
+
spaces==0.35.0
|
15 |
+
timm==0.9.12
|
16 |
+
torch==2.2.0
|
17 |
+
transformers==4.39.1
|
18 |
+
typing==3.7.4.3
|