Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -5,13 +5,13 @@ from flask import Flask, request, jsonify
|
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
7 |
from facenet_pytorch import MTCNN, InceptionResnetV1
|
8 |
-
import os
|
9 |
import numpy as np
|
10 |
from PIL import Image
|
11 |
import cv2
|
12 |
from pytorch_grad_cam import GradCAM
|
13 |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
14 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
|
|
15 |
|
16 |
app = Flask(__name__)
|
17 |
|
@@ -39,23 +39,11 @@ model.load_state_dict(checkpoint['model_state_dict'])
|
|
39 |
model.to(DEVICE)
|
40 |
model.eval()
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
def predict(input_image):
|
45 |
"""Predict the label of the input_image"""
|
46 |
true_label = request.form['true_label']
|
47 |
-
|
48 |
-
|
49 |
-
file = request.files['file']
|
50 |
-
if file.filename == '':
|
51 |
-
return jsonify({"error": "No selected file"}), 400
|
52 |
-
|
53 |
-
# Save the uploaded image to the uploads folder
|
54 |
-
input_image_path = os.path.join(UPLOAD_FOLDER, file.filename)
|
55 |
-
file.save(input_image_path)
|
56 |
-
|
57 |
-
input_image = Image.open(input_image_path)
|
58 |
-
face = mtcnn(input_image)
|
59 |
if face is None:
|
60 |
return jsonify({"error": "No face detected"}), 400
|
61 |
|
@@ -99,7 +87,12 @@ def index():
|
|
99 |
|
100 |
@app.route("/predict", methods=['POST'])
|
101 |
def predict_endpoint():
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
if __name__ == "__main__":
|
105 |
app.run(debug=True, port=7860) # Change the port to 7860
|
|
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
7 |
from facenet_pytorch import MTCNN, InceptionResnetV1
|
|
|
8 |
import numpy as np
|
9 |
from PIL import Image
|
10 |
import cv2
|
11 |
from pytorch_grad_cam import GradCAM
|
12 |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
13 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
14 |
+
from io import BytesIO
|
15 |
|
16 |
app = Flask(__name__)
|
17 |
|
|
|
39 |
model.to(DEVICE)
|
40 |
model.eval()
|
41 |
|
42 |
+
def predict(image_bytes):
|
|
|
|
|
43 |
"""Predict the label of the input_image"""
|
44 |
true_label = request.form['true_label']
|
45 |
+
image = Image.open(BytesIO(image_bytes))
|
46 |
+
face = mtcnn(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
if face is None:
|
48 |
return jsonify({"error": "No face detected"}), 400
|
49 |
|
|
|
87 |
|
88 |
@app.route("/predict", methods=['POST'])
|
89 |
def predict_endpoint():
|
90 |
+
if 'file' not in request.files:
|
91 |
+
return jsonify({"error": "No file part"}), 400
|
92 |
+
file = request.files['file']
|
93 |
+
if file.filename == '':
|
94 |
+
return jsonify({"error": "No selected file"}), 400
|
95 |
+
return predict(file.read())
|
96 |
|
97 |
if __name__ == "__main__":
|
98 |
app.run(debug=True, port=7860) # Change the port to 7860
|