Saad0KH commited on
Commit
c6be1d4
Β·
verified Β·
1 Parent(s): 02c3409

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -123
app.py CHANGED
@@ -1,173 +1,123 @@
1
- from flask import Flask, request, jsonify ,send_file
2
  from PIL import Image
3
  import base64
4
- import spaces
5
- from loadimg import load_img
 
6
  from io import BytesIO
7
  import numpy as np
8
- import insightface
9
- import onnxruntime as ort
10
- import huggingface_hub
11
- from SegCloth import segment_clothing
12
- from transparent_background import Remover
13
- import threading
14
- import logging
15
  import uuid
16
  from transformers import AutoModelForImageSegmentation
17
- import torch
18
  from torchvision import transforms
 
 
 
19
 
20
-
21
  app = Flask(__name__)
22
 
23
  # Configure logging
24
  logging.basicConfig(level=logging.INFO)
25
 
26
- # Load the model lazily
27
- model = None
28
- detector = None
 
 
 
29
 
30
  def load_model():
31
- global model, detector
32
- path = huggingface_hub.hf_hub_download("public-data/insightface", "models/scrfd_person_2.5g.onnx")
33
- options = ort.SessionOptions()
34
- options.intra_op_num_threads = 8
35
- options.inter_op_num_threads = 8
36
- session = ort.InferenceSession(
37
- path, sess_options=options, providers=["CPUExecutionProvider", "CUDAExecutionProvider"]
38
  )
39
- model = insightface.model_zoo.retinaface.RetinaFace(model_file=path, session=session)
40
- model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
41
- detector = model
42
- logging.info("Model loaded successfully.")
43
-
44
- torch.set_float32_matmul_precision(["high", "highest"][0])
45
-
46
- birefnet = AutoModelForImageSegmentation.from_pretrained(
47
- "ZhengPeng7/BiRefNet", trust_remote_code=True
48
- )
49
- birefnet.to("cuda")
50
- transform_image = transforms.Compose(
51
- [
52
  transforms.Resize((1024, 1024)),
53
  transforms.ToTensor(),
54
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
55
- ]
56
- )
57
-
58
-
59
 
60
- def save_image(img):
61
- unique_name = str(uuid.uuid4()) + ".png"
62
- img.save(unique_name)
63
- return unique_name
64
 
65
- # Function to decode a base64 image to PIL.Image.Image
66
  def decode_image_from_base64(image_data):
67
  image_data = base64.b64decode(image_data)
68
  image = Image.open(BytesIO(image_data)).convert("RGB")
69
  return image
70
 
71
- # Function to encode a PIL image to base64
72
  def encode_image_to_base64(image):
73
  buffered = BytesIO()
74
- image.save(buffered, format="PNG") # Use PNG for compatibility with RGBA
75
- return base64.b64encode(buffered.getvalue()).decode('utf-8')
76
- @spaces.GPU
77
- def rm_background(image):
78
- im = load_img(image, output_type="pil")
79
- im = im.convert("RGB")
80
- image_size = im.size
81
- origin = im.copy()
82
- image = load_img(im)
 
 
 
 
 
 
 
 
83
  input_images = transform_image(image).unsqueeze(0).to("cuda")
84
- # Prediction
 
85
  with torch.no_grad():
86
  preds = birefnet(input_images)[-1].sigmoid().cpu()
 
 
87
  pred = preds[0].squeeze()
88
  pred_pil = transforms.ToPILImage()(pred)
89
- mask = pred_pil.resize(image_size)
90
  image.putalpha(mask)
91
- return (image)
92
-
93
- @spaces.GPU
94
- def remove_background(image):
95
- remover = Remover()
96
- if isinstance(image, Image.Image):
97
- output = remover.process(image)
98
- elif isinstance(image, np.ndarray):
99
- image_pil = Image.fromarray(image)
100
- output = remover.process(image_pil)
101
- else:
102
- raise TypeError("Unsupported image type")
103
- return output
104
-
105
- def detect_and_segment_persons(image, clothes):
106
- img = np.array(image)
107
- img = img[:, :, ::-1] # RGB -> BGR
108
-
109
- if detector is None:
110
- load_model() # Ensure the model is loaded
111
-
112
- bboxes, kpss = detector.detect(img)
113
- if bboxes.shape[0] == 0:
114
- return [save_image(rm_background(image))]
115
-
116
- height, width, _ = img.shape
117
- bboxes = np.round(bboxes[:, :4]).astype(int)
118
- bboxes[:, 0] = np.clip(bboxes[:, 0], 0, width)
119
- bboxes[:, 1] = np.clip(bboxes[:, 1], 0, height)
120
- bboxes[:, 2] = np.clip(bboxes[:, 2], 0, width)
121
- bboxes[:, 3] = np.clip(bboxes[:, 3], 0, height)
122
-
123
- all_segmented_images = []
124
- for i in range(bboxes.shape[0]):
125
- bbox = bboxes[i]
126
- x1, y1, x2, y2 = bbox
127
- person_img = img[y1:y2, x1:x2]
128
- pil_img = Image.fromarray(person_img[:, :, ::-1])
129
-
130
- img_rm_background = rm_background(pil_img)
131
- segmented_result = segment_clothing(img_rm_background, clothes)
132
- image_paths = [save_image(img) for img in segmented_result]
133
- print(image_paths)
134
- all_segmented_images.extend(image_paths)
135
-
136
- return all_segmented_images
137
-
138
- @app.route('/', methods=['GET'])
139
- def welcome():
140
- return "Welcome to Clothing Segmentation API"
141
 
142
  @app.route('/api/detect', methods=['POST'])
143
- def detect():
144
  try:
145
  data = request.json
146
- image_base64 = data['image']
 
 
 
 
 
147
  image = decode_image_from_base64(image_base64)
148
 
149
- clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
150
-
 
151
 
152
- result = detect_and_segment_persons(image, clothes)
153
-
 
 
154
 
155
- return jsonify({'images': result})
156
  except Exception as e:
157
- logging.error(f"Error occurred: {e}")
158
- return jsonify({'error': str(e)}), 500
159
-
160
- # Route pour rΓ©cupΓ©rer l'image gΓ©nΓ©rΓ©e
161
  @app.route('/api/get_image/<image_id>', methods=['GET'])
162
  def get_image(image_id):
163
- # Construire le chemin complet de l'image
164
- image_path = image_id # Assurez-vous que le nom de fichier correspond Γ  celui que vous avez utilisΓ© lors de la sauvegarde
165
-
166
- # Renvoyer l'image
167
  try:
168
- return send_file(image_path, mimetype='image/png')
169
  except FileNotFoundError:
170
- return jsonify({'error': 'Image not found'}), 404
171
 
172
  if __name__ == "__main__":
173
  app.run(debug=True, host="0.0.0.0", port=7860)
 
1
+ from flask import Flask, request, jsonify, send_file
2
  from PIL import Image
3
  import base64
4
+ import threading
5
+ import asyncio
6
+ import torch
7
  from io import BytesIO
8
  import numpy as np
 
 
 
 
 
 
 
9
  import uuid
10
  from transformers import AutoModelForImageSegmentation
 
11
  from torchvision import transforms
12
+ import logging
13
+ import tempfile
14
+ from concurrent.futures import ThreadPoolExecutor
15
 
16
+ # Initialize Flask app
17
  app = Flask(__name__)
18
 
19
  # Configure logging
20
  logging.basicConfig(level=logging.INFO)
21
 
22
+ # ThreadPool for async tasks
23
+ executor = ThreadPoolExecutor(max_workers=4)
24
+
25
+ # GPU model setup
26
+ birefnet = None
27
+ transform_image = None
28
 
29
  def load_model():
30
+ global birefnet, transform_image
31
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
32
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
 
 
 
 
33
  )
34
+ birefnet.to("cuda")
35
+ birefnet.eval()
36
+ transform_image = transforms.Compose([
 
 
 
 
 
 
 
 
 
 
37
  transforms.Resize((1024, 1024)),
38
  transforms.ToTensor(),
39
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
40
+ ])
 
 
 
41
 
42
+ # Lazy load the model on the first request
43
+ @app.before_first_request
44
+ def initialize():
45
+ threading.Thread(target=load_model).start()
46
 
47
+ # Helper functions
48
  def decode_image_from_base64(image_data):
49
  image_data = base64.b64decode(image_data)
50
  image = Image.open(BytesIO(image_data)).convert("RGB")
51
  return image
52
 
 
53
  def encode_image_to_base64(image):
54
  buffered = BytesIO()
55
+ image.save(buffered, format="PNG")
56
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
57
+
58
+ def save_image(img):
59
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
60
+ img.save(temp_file.name)
61
+ return temp_file.name
62
+
63
+ def cleanup_gpu_resources():
64
+ torch.cuda.empty_cache()
65
+ torch.cuda.ipc_collect()
66
+
67
+ async def process_image(image):
68
+ """Process the image asynchronously, including background removal."""
69
+ global birefnet, transform_image
70
+
71
+ # Convert image to tensor
72
  input_images = transform_image(image).unsqueeze(0).to("cuda")
73
+
74
+ # Run inference
75
  with torch.no_grad():
76
  preds = birefnet(input_images)[-1].sigmoid().cpu()
77
+
78
+ # Generate mask and apply to original image
79
  pred = preds[0].squeeze()
80
  pred_pil = transforms.ToPILImage()(pred)
81
+ mask = pred_pil.resize(image.size)
82
  image.putalpha(mask)
83
+
84
+ # Cleanup GPU resources
85
+ del input_images, preds, pred
86
+ cleanup_gpu_resources()
87
+
88
+ return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  @app.route('/api/detect', methods=['POST'])
91
+ async def detect():
92
  try:
93
  data = request.json
94
+ image_base64 = data.get('image')
95
+
96
+ if not image_base64:
97
+ return jsonify({"error": "No image provided."}), 400
98
+
99
+ # Decode the image
100
  image = decode_image_from_base64(image_base64)
101
 
102
+ # Process the image asynchronously
103
+ loop = asyncio.get_event_loop()
104
+ processed_image = await loop.run_in_executor(executor, asyncio.run, process_image(image))
105
 
106
+ # Save the processed image and encode it as base64
107
+ output_path = save_image(processed_image)
108
+
109
+ return jsonify({"image_url": f"/api/get_image/{uuid.uuid4()}", "path": output_path})
110
 
 
111
  except Exception as e:
112
+ logging.error(f"Error during detection: {e}")
113
+ return jsonify({"error": str(e)}), 500
114
+
 
115
  @app.route('/api/get_image/<image_id>', methods=['GET'])
116
  def get_image(image_id):
 
 
 
 
117
  try:
118
+ return send_file(image_id, mimetype='image/png')
119
  except FileNotFoundError:
120
+ return jsonify({"error": "Image not found"}), 404
121
 
122
  if __name__ == "__main__":
123
  app.run(debug=True, host="0.0.0.0", port=7860)