achisingh06 commited on
Commit
66453cf
·
verified ·
1 Parent(s): 675368d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +10 -17
main.py CHANGED
@@ -5,13 +5,13 @@ from flask import Flask, request, jsonify
5
  import torch
6
  import torch.nn.functional as F
7
  from facenet_pytorch import MTCNN, InceptionResnetV1
8
- import os
9
  import numpy as np
10
  from PIL import Image
11
  import cv2
12
  from pytorch_grad_cam import GradCAM
13
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
14
  from pytorch_grad_cam.utils.image import show_cam_on_image
 
15
 
16
  app = Flask(__name__)
17
 
@@ -39,23 +39,11 @@ model.load_state_dict(checkpoint['model_state_dict'])
39
  model.to(DEVICE)
40
  model.eval()
41
 
42
- UPLOAD_FOLDER = os.getcwd() # Save the uploaded image to the current working directory
43
-
44
- def predict(input_image):
45
  """Predict the label of the input_image"""
46
  true_label = request.form['true_label']
47
- if 'file' not in request.files:
48
- return jsonify({"error": "No file part"}), 400
49
- file = request.files['file']
50
- if file.filename == '':
51
- return jsonify({"error": "No selected file"}), 400
52
-
53
- # Save the uploaded image to the uploads folder
54
- input_image_path = os.path.join(UPLOAD_FOLDER, file.filename)
55
- file.save(input_image_path)
56
-
57
- input_image = Image.open(input_image_path)
58
- face = mtcnn(input_image)
59
  if face is None:
60
  return jsonify({"error": "No face detected"}), 400
61
 
@@ -99,7 +87,12 @@ def index():
99
 
100
  @app.route("/predict", methods=['POST'])
101
  def predict_endpoint():
102
- return predict(request)
 
 
 
 
 
103
 
104
  if __name__ == "__main__":
105
  app.run(debug=True, port=7860) # Change the port to 7860
 
5
  import torch
6
  import torch.nn.functional as F
7
  from facenet_pytorch import MTCNN, InceptionResnetV1
 
8
  import numpy as np
9
  from PIL import Image
10
  import cv2
11
  from pytorch_grad_cam import GradCAM
12
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
13
  from pytorch_grad_cam.utils.image import show_cam_on_image
14
+ from io import BytesIO
15
 
16
  app = Flask(__name__)
17
 
 
39
  model.to(DEVICE)
40
  model.eval()
41
 
42
+ def predict(image_bytes):
 
 
43
  """Predict the label of the input_image"""
44
  true_label = request.form['true_label']
45
+ image = Image.open(BytesIO(image_bytes))
46
+ face = mtcnn(image)
 
 
 
 
 
 
 
 
 
 
47
  if face is None:
48
  return jsonify({"error": "No face detected"}), 400
49
 
 
87
 
88
  @app.route("/predict", methods=['POST'])
89
  def predict_endpoint():
90
+ if 'file' not in request.files:
91
+ return jsonify({"error": "No file part"}), 400
92
+ file = request.files['file']
93
+ if file.filename == '':
94
+ return jsonify({"error": "No selected file"}), 400
95
+ return predict(file.read())
96
 
97
  if __name__ == "__main__":
98
  app.run(debug=True, port=7860) # Change the port to 7860