Saad0KH commited on
Commit
dabea56
·
verified ·
1 Parent(s): bc7a022

Create app_multiple.py

Browse files
Files changed (1) hide show
  1. app_multiple.py +276 -0
app_multiple.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify ,send_file
2
+ from PIL import Image
3
+ import base64
4
+ import spaces
5
+ from loadimg import load_img
6
+ from io import BytesIO
7
+ import numpy as np
8
+ import insightface
9
+ import onnxruntime as ort
10
+ import huggingface_hub
11
+ from SegCloth import segment_clothing
12
+ from transparent_background import Remover
13
+ import threading
14
+ import logging
15
+ import uuid
16
+ from transformers import AutoModelForImageSegmentation,AutoModelForCausalLM, AutoProcessor
17
+ import torch
18
+ from torchvision import transforms
19
+ import subprocess
20
+ import logging
21
+ import json
22
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
23
+
24
+ app = Flask(__name__)
25
+
26
+ kwargs = {}
27
+ kwargs['torch_dtype'] = torch.bfloat16
28
+
29
+ models = {
30
+ "microsoft/Phi-3-vision-128k-instruct": AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True, torch_dtype="auto", _attn_implementation="flash_attention_2").cuda().eval()
31
+ }
32
+
33
+ processors = {
34
+ "microsoft/Phi-3-vision-128k-instruct": AutoProcessor.from_pretrained("microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True)
35
+ }
36
+
37
+ subprocess.run(
38
+ "pip install flash-attn --no-build-isolation",
39
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
40
+ shell=True,
41
+ )
42
+
43
+ user_prompt = '<|user|>\n'
44
+ assistant_prompt = '<|assistant|>\n'
45
+ prompt_suffix = "<|end|>\n"
46
+
47
+ def get_image_from_url(url):
48
+ try:
49
+ response = requests.get(url)
50
+ response.raise_for_status() # Vérifie les erreurs HTTP
51
+ img = Image.open(BytesIO(response.content))
52
+ return img
53
+ except Exception as e:
54
+ logging.error(f"Error fetching image from URL: {e}")
55
+ raise
56
+
57
+
58
+ # Function to decode a base64 image to PIL.Image.Image
59
+ def decode_image_from_base64(image_data):
60
+ image_data = base64.b64decode(image_data)
61
+ image = Image.open(BytesIO(image_data)).convert("RGB")
62
+ return image
63
+
64
+ # Function to encode a PIL image to base64
65
+ def encode_image_to_base64(image):
66
+ buffered = BytesIO()
67
+ image.save(buffered, format="PNG") # Use PNG for compatibility with RGBA
68
+ return base64.b64encode(buffered.getvalue()).decode('utf-8')
69
+
70
+ def get_image(image_data):
71
+ # Vérifie si l'image est en base64 ou URL
72
+ if image_data.startswith('http://') or image_data.startswith('https://'):
73
+ return get_image_from_url(image_data) # Télécharge l'image depuis l'URL
74
+ else:
75
+ return decode_image_from_base64(image_data) # Décode l'image base64
76
+
77
+ @spaces.GPU
78
+ def process_vision(image, text_input=None, model_id="microsoft/Phi-3-vision-128k-instruct"):
79
+ model = models[model_id]
80
+ processor = processors[model_id]
81
+
82
+ prompt = f"{user_prompt}<|image_1|>\n{text_input}{prompt_suffix}{assistant_prompt}"
83
+ image = image.convert("RGB")
84
+
85
+ inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
86
+ generate_ids = model.generate(**inputs,
87
+ max_new_tokens=4128,
88
+ eos_token_id=processor.tokenizer.eos_token_id,
89
+ )
90
+ generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
91
+ response = processor.batch_decode(generate_ids,
92
+ skip_special_tokens=True,
93
+ clean_up_tokenization_spaces=False)[0]
94
+ return response
95
+
96
+
97
+ @app.route('/api/vision', methods=['POST'])
98
+ def detect():
99
+ try:
100
+ data = request.json
101
+ image = data['image']
102
+ prompt = data['prompt']
103
+ image = get_image(image)
104
+ result = process_vision(image,prompt)
105
+
106
+ # Remove ```json and ``` markers
107
+ if result.startswith("```json"):
108
+ result = result[7:] # Remove the leading ```json
109
+ if result.endswith("```"):
110
+ result = result[:-3] # Remove the trailing ```
111
+
112
+ # Convert the string result to a Python dictionary
113
+ try:
114
+ logging.info(result)
115
+ result_dict = json.loads(result)
116
+ except json.JSONDecodeError as e:
117
+ logging.error(f"JSON decoding error: {e}")
118
+ return jsonify({'error': 'Invalid JSON format in the response'}), 500
119
+
120
+
121
+ return jsonify(result_dict)
122
+ except Exception as e:
123
+ logging.error(f"Error occurred: {e}")
124
+ return jsonify({'error': str(e)}), 500
125
+
126
+ # Configure logging
127
+ logging.basicConfig(level=logging.INFO)
128
+
129
+ # Load the model lazily
130
+ model = None
131
+ detector = None
132
+
133
+ def load_model():
134
+ global model, detector
135
+ path = huggingface_hub.hf_hub_download("public-data/insightface", "models/scrfd_person_2.5g.onnx")
136
+ options = ort.SessionOptions()
137
+ options.intra_op_num_threads = 8
138
+ options.inter_op_num_threads = 8
139
+ session = ort.InferenceSession(
140
+ path, sess_options=options, providers=["CPUExecutionProvider", "CUDAExecutionProvider"]
141
+ )
142
+ model = insightface.model_zoo.retinaface.RetinaFace(model_file=path, session=session)
143
+ model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
144
+ detector = model
145
+ logging.info("Model loaded successfully.")
146
+
147
+ torch.set_float32_matmul_precision(["high", "highest"][0])
148
+
149
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
150
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
151
+ )
152
+ birefnet.to("cuda")
153
+ transform_image = transforms.Compose(
154
+ [
155
+ transforms.Resize((1024, 1024)),
156
+ transforms.ToTensor(),
157
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
158
+ ]
159
+ )
160
+
161
+
162
+
163
+ def save_image(img):
164
+ unique_name = str(uuid.uuid4()) + ".png"
165
+ img.save(unique_name)
166
+ return unique_name
167
+
168
+ # Function to decode a base64 image to PIL.Image.Image
169
+ def decode_image_from_base64(image_data):
170
+ image_data = base64.b64decode(image_data)
171
+ image = Image.open(BytesIO(image_data)).convert("RGB")
172
+ return image
173
+
174
+ # Function to encode a PIL image to base64
175
+ def encode_image_to_base64(image):
176
+ buffered = BytesIO()
177
+ image.save(buffered, format="PNG") # Use PNG for compatibility with RGBA
178
+ return base64.b64encode(buffered.getvalue()).decode('utf-8')
179
+ @spaces.GPU
180
+ def rm_background(image):
181
+ im = load_img(image, output_type="pil")
182
+ im = im.convert("RGB")
183
+ image_size = im.size
184
+ origin = im.copy()
185
+ image = load_img(im)
186
+ input_images = transform_image(image).unsqueeze(0).to("cuda")
187
+ # Prediction
188
+ with torch.no_grad():
189
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
190
+ pred = preds[0].squeeze()
191
+ pred_pil = transforms.ToPILImage()(pred)
192
+ mask = pred_pil.resize(image_size)
193
+ image.putalpha(mask)
194
+ return (image)
195
+
196
+ @spaces.GPU
197
+ def remove_background(image):
198
+ remover = Remover()
199
+ if isinstance(image, Image.Image):
200
+ output = remover.process(image)
201
+ elif isinstance(image, np.ndarray):
202
+ image_pil = Image.fromarray(image)
203
+ output = remover.process(image_pil)
204
+ else:
205
+ raise TypeError("Unsupported image type")
206
+ return output
207
+
208
+ def detect_and_segment_persons(image, clothes):
209
+ img = np.array(image)
210
+ img = img[:, :, ::-1] # RGB -> BGR
211
+
212
+ if detector is None:
213
+ load_model() # Ensure the model is loaded
214
+
215
+ bboxes, kpss = detector.detect(img)
216
+ if bboxes.shape[0] == 0:
217
+ return [save_image(rm_background(image))]
218
+
219
+ height, width, _ = img.shape
220
+ bboxes = np.round(bboxes[:, :4]).astype(int)
221
+ bboxes[:, 0] = np.clip(bboxes[:, 0], 0, width)
222
+ bboxes[:, 1] = np.clip(bboxes[:, 1], 0, height)
223
+ bboxes[:, 2] = np.clip(bboxes[:, 2], 0, width)
224
+ bboxes[:, 3] = np.clip(bboxes[:, 3], 0, height)
225
+
226
+ all_segmented_images = []
227
+ for i in range(bboxes.shape[0]):
228
+ bbox = bboxes[i]
229
+ x1, y1, x2, y2 = bbox
230
+ person_img = img[y1:y2, x1:x2]
231
+ pil_img = Image.fromarray(person_img[:, :, ::-1])
232
+
233
+ img_rm_background = rm_background(pil_img)
234
+ segmented_result = segment_clothing(img_rm_background, clothes)
235
+ image_paths = [save_image(img) for img in segmented_result]
236
+ print(image_paths)
237
+ all_segmented_images.extend(image_paths)
238
+
239
+ return all_segmented_images
240
+
241
+ @app.route('/', methods=['GET'])
242
+ def welcome():
243
+ return "Welcome to Clothing Segmentation API"
244
+
245
+ @app.route('/api/detect', methods=['POST'])
246
+ def detect():
247
+ try:
248
+ data = request.json
249
+ image_base64 = data['image']
250
+ image = decode_image_from_base64(image_base64)
251
+
252
+ clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
253
+
254
+
255
+ result = detect_and_segment_persons(image, clothes)
256
+
257
+
258
+ return jsonify({'images': result})
259
+ except Exception as e:
260
+ logging.error(f"Error occurred: {e}")
261
+ return jsonify({'error': str(e)}), 500
262
+
263
+ # Route pour récupérer l'image générée
264
+ @app.route('/api/get_image/<image_id>', methods=['GET'])
265
+ def get_image(image_id):
266
+ # Construire le chemin complet de l'image
267
+ image_path = image_id # Assurez-vous que le nom de fichier correspond à celui que vous avez utilisé lors de la sauvegarde
268
+
269
+ # Renvoyer l'image
270
+ try:
271
+ return send_file(image_path, mimetype='image/png')
272
+ except FileNotFoundError:
273
+ return jsonify({'error': 'Image not found'}), 404
274
+
275
+ if __name__ == "__main__":
276
+ app.run(debug=True, host="0.0.0.0", port=7860)