Saad0KH committed · verified
Commit 637177b · 1 Parent(s): 954e47d

Update app.py

Files changed (1)
  1. app.py +130 -81
app.py CHANGED
@@ -1,124 +1,173 @@
- from flask import Flask, request, jsonify, send_file
  from PIL import Image
  import base64
- import threading
- import asyncio
- import torch
  from io import BytesIO
  import numpy as np
  import uuid
  from transformers import AutoModelForImageSegmentation
  from torchvision import transforms
- import logging
- import tempfile
- from concurrent.futures import ThreadPoolExecutor

- # Initialize Flask app
  app = Flask(__name__)

  # Configure logging
  logging.basicConfig(level=logging.INFO)

- # ThreadPool for async tasks
- executor = ThreadPoolExecutor(max_workers=4)
-
- # GPU model setup
- birefnet = None
- transform_image = None
- model_loaded = threading.Event()

  def load_model():
-     global birefnet, transform_image
-     if not model_loaded.is_set():
-         birefnet = AutoModelForImageSegmentation.from_pretrained(
-             "ZhengPeng7/BiRefNet", trust_remote_code=True
-         )
-         birefnet.to("cuda")
-         birefnet.eval()
-         transform_image = transforms.Compose([
-             transforms.Resize((1024, 1024)),
-             transforms.ToTensor(),
-             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-         ])
-         model_loaded.set()
-
- # Helper functions
  def decode_image_from_base64(image_data):
      image_data = base64.b64decode(image_data)
      image = Image.open(BytesIO(image_data)).convert("RGB")
      return image

  def encode_image_to_base64(image):
      buffered = BytesIO()
-     image.save(buffered, format="PNG")
-     return base64.b64encode(buffered.getvalue()).decode("utf-8")
-
- def save_image(img):
-     temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
-     img.save(temp_file.name)
-     return temp_file.name
-
- def cleanup_gpu_resources():
-     torch.cuda.empty_cache()
-     torch.cuda.ipc_collect()
-
- async def process_image(image):
-     """Process the image asynchronously, including background removal."""
-     global birefnet, transform_image
-
-     # Ensure the model is loaded
-     load_model()
-
-     # Convert image to tensor
      input_images = transform_image(image).unsqueeze(0).to("cuda")
-
-     # Run inference
      with torch.no_grad():
          preds = birefnet(input_images)[-1].sigmoid().cpu()
-
-     # Generate mask and apply to original image
      pred = preds[0].squeeze()
      pred_pil = transforms.ToPILImage()(pred)
-     mask = pred_pil.resize(image.size)
      image.putalpha(mask)
-
-     # Cleanup GPU resources
-     del input_images, preds, pred
-     cleanup_gpu_resources()
-
-     return image

  @app.route('/api/detect', methods=['POST'])
- async def detect():
      try:
          data = request.json
-         image_base64 = data.get('image')
-
-         if not image_base64:
-             return jsonify({"error": "No image provided."}), 400
-
-         # Decode the image
          image = decode_image_from_base64(image_base64)

-         # Process the image asynchronously
-         loop = asyncio.get_event_loop()
-         processed_image = await loop.run_in_executor(executor, asyncio.run, process_image(image))

-         # Save the processed image and encode it as base64
-         output_path = save_image(processed_image)
-
-         return jsonify({"image_url": f"/api/get_image/{uuid.uuid4()}", "path": output_path})

      except Exception as e:
-         logging.error(f"Error during detection: {e}")
-         return jsonify({"error": str(e)}), 500
-
  @app.route('/api/get_image/<image_id>', methods=['GET'])
  def get_image(image_id):
      try:
-         return send_file(image_id, mimetype='image/png')
      except FileNotFoundError:
-         return jsonify({"error": "Image not found"}), 404

  if __name__ == "__main__":
-     app.run(debug=True, host="0.0.0.0", port=7860)
 
+ from flask import Flask, request, jsonify ,send_file
  from PIL import Image
  import base64
+ import spaces
+ from loadimg import load_img
  from io import BytesIO
  import numpy as np
+ import insightface
+ import onnxruntime as ort
+ import huggingface_hub
+ from SegCloth import segment_clothing
+ from transparent_background import Remover
+ import threading
+ import logging
  import uuid
  from transformers import AutoModelForImageSegmentation
+ import torch
  from torchvision import transforms

+
  app = Flask(__name__)

  # Configure logging
  logging.basicConfig(level=logging.INFO)

+ # Load the model lazily
+ model = None
+ detector = None

  def load_model():
+     global model, detector
+     path = huggingface_hub.hf_hub_download("public-data/insightface", "models/scrfd_person_2.5g.onnx")
+     options = ort.SessionOptions()
+     options.intra_op_num_threads = 8
+     options.inter_op_num_threads = 8
+     session = ort.InferenceSession(
+         path, sess_options=options, providers=["CPUExecutionProvider", "CUDAExecutionProvider"]
+     )
+     model = insightface.model_zoo.retinaface.RetinaFace(model_file=path, session=session)
+     model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
+     detector = model
+     logging.info("Model loaded successfully.")
+
+ torch.set_float32_matmul_precision(["high", "highest"][0])
+
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
+     "ZhengPeng7/BiRefNet", trust_remote_code=True
+ )
+ birefnet.to("cuda")
+ transform_image = transforms.Compose(
+     [
+         transforms.Resize((1024, 1024)),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ]
+ )
+
+
+
+ def save_image(img):
+     unique_name = str(uuid.uuid4()) + ".png"
+     img.save(unique_name)
+     return unique_name
+
+ # Function to decode a base64 image to PIL.Image.Image
  def decode_image_from_base64(image_data):
      image_data = base64.b64decode(image_data)
      image = Image.open(BytesIO(image_data)).convert("RGB")
      return image

+ # Function to encode a PIL image to base64
  def encode_image_to_base64(image):
      buffered = BytesIO()
+     image.save(buffered, format="PNG")  # Use PNG for compatibility with RGBA
+     return base64.b64encode(buffered.getvalue()).decode('utf-8')
+ @spaces.GPU
+ def rm_background(image):
+     im = load_img(image, output_type="pil")
+     im = im.convert("RGB")
+     image_size = im.size
+     origin = im.copy()
+     image = load_img(im)
      input_images = transform_image(image).unsqueeze(0).to("cuda")
+     # Prediction
      with torch.no_grad():
          preds = birefnet(input_images)[-1].sigmoid().cpu()
      pred = preds[0].squeeze()
      pred_pil = transforms.ToPILImage()(pred)
+     mask = pred_pil.resize(image_size)
      image.putalpha(mask)
+     return (image)
+
+ @spaces.GPU
+ def remove_background(image):
+     remover = Remover()
+     if isinstance(image, Image.Image):
+         output = remover.process(image)
+     elif isinstance(image, np.ndarray):
+         image_pil = Image.fromarray(image)
+         output = remover.process(image_pil)
+     else:
+         raise TypeError("Unsupported image type")
+     return output
+
+ def detect_and_segment_persons(image, clothes):
+     img = np.array(image)
+     img = img[:, :, ::-1]  # RGB -> BGR
+
+     if detector is None:
+         load_model()  # Ensure the model is loaded
+
+     bboxes, kpss = detector.detect(img)
+     if bboxes.shape[0] == 0:
+         return [save_image(rm_background(image))]
+
+     height, width, _ = img.shape
+     bboxes = np.round(bboxes[:, :4]).astype(int)
+     bboxes[:, 0] = np.clip(bboxes[:, 0], 0, width)
+     bboxes[:, 1] = np.clip(bboxes[:, 1], 0, height)
+     bboxes[:, 2] = np.clip(bboxes[:, 2], 0, width)
+     bboxes[:, 3] = np.clip(bboxes[:, 3], 0, height)
+
+     all_segmented_images = []
+     for i in range(bboxes.shape[0]):
+         bbox = bboxes[i]
+         x1, y1, x2, y2 = bbox
+         person_img = img[y1:y2, x1:x2]
+         pil_img = Image.fromarray(person_img[:, :, ::-1])
+
+         img_rm_background = rm_background(pil_img)
+         segmented_result = segment_clothing(img_rm_background, clothes)
+         image_paths = [save_image(img) for img in segmented_result]
+         print(image_paths)
+         all_segmented_images.extend(image_paths)
+
+     return all_segmented_images
+
+ @app.route('/', methods=['GET'])
+ def welcome():
+     return "Welcome to Clothing Segmentation API"

  @app.route('/api/detect', methods=['POST'])
+ def detect():
      try:
          data = request.json
+         image_base64 = data['image']
          image = decode_image_from_base64(image_base64)

+         clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
+

+         result = detect_and_segment_persons(image, clothes)
+

+         return jsonify({'images': result})
      except Exception as e:
+         logging.error(f"Error occurred: {e}")
+         return jsonify({'error': str(e)}), 500
+
+ # Route to retrieve the generated image
  @app.route('/api/get_image/<image_id>', methods=['GET'])
  def get_image(image_id):
+     # Build the full path of the image
+     image_path = image_id  # Make sure the file name matches the one used when saving
+
+     # Return the image
      try:
+         return send_file(image_path, mimetype='image/png')
      except FileNotFoundError:
+         return jsonify({'error': 'Image not found'}), 404

  if __name__ == "__main__":
+     app.run(debug=True, host="0.0.0.0", port=7860)
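
For reference, a minimal client-side sketch of how the updated endpoints can be exercised. The base URL and the input file name are placeholder assumptions; the request and response shapes follow the handlers above (POST /api/detect with a base64 "image" field, a JSON reply listing saved PNG file names, and GET /api/get_image/<image_id> to fetch each file).

import base64
import requests  # assumed to be available in the client environment

BASE_URL = "http://localhost:7860"  # placeholder; substitute the deployed Space URL

# Encode a local test image as base64, the payload format /api/detect expects
with open("person.jpg", "rb") as f:  # hypothetical input file
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# POST the image; the handler responds with {"images": [<saved PNG file names>]}
resp = requests.post(f"{BASE_URL}/api/detect", json={"image": image_b64})
resp.raise_for_status()
image_ids = resp.json()["images"]

# Each returned file name can be fetched back through /api/get_image/<image_id>
for image_id in image_ids:
    png = requests.get(f"{BASE_URL}/api/get_image/{image_id}")
    with open(image_id, "wb") as out:
        out.write(png.content)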