# Hugging Face Spaces status: Sleeping
import os | |
os.environ["MPLCONFIGDIR"] = "/home/user/app" | |
from flask import Flask, request, jsonify | |
import torch | |
import torch.nn.functional as F | |
from facenet_pytorch import MTCNN, InceptionResnetV1 | |
import numpy as np | |
from PIL import Image | |
import cv2 | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
from io import BytesIO | |
app = Flask(__name__)

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Face detector/cropper. post_process=False asks facenet-pytorch not to
# standardise the crop; the request handler rescales the pixels itself.
mtcnn = MTCNN(
    select_largest=False,
    post_process=False,
    device=DEVICE
).to(DEVICE).eval()

# NOTE(review): MODEL_PATH is not read anywhere in the visible code -- the
# classifier weights come from CHECKPOINT_PATH below. Kept only so any code
# importing this name keeps working; editing it has no effect on the model.
MODEL_PATH = "20180402-114759-vggface2.pt"

# Checkpoint holding the fine-tuned classifier; its 'model_state_dict' entry
# contains the InceptionResnetV1 weights. Edit this to load a different file.
CHECKPOINT_PATH = "resnetinceptionv1_epoch_32.pth"

# InceptionResnetV1 with a single-logit classification head (real vs. fake).
# pretrained=None: no weights are downloaded; they are loaded from
# CHECKPOINT_PATH instead.
model = InceptionResnetV1(
    pretrained=None,
    classify=True,
    num_classes=1,
    device=DEVICE,
)

# Deserialize onto the CPU first so a checkpoint saved from a GPU also loads on
# CPU-only hosts; the model is moved to DEVICE afterwards.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)
model.eval()
def _decode_image(image_bytes):
    """Decode raw upload bytes into an RGB PIL image; return None if PIL cannot read them."""
    try:
        # convert("RGB") forces three channels: RGBA, palette and grayscale
        # uploads would otherwise reach MTCNN with the wrong channel count.
        return Image.open(BytesIO(image_bytes)).convert("RGB")
    except OSError:
        # PIL.UnidentifiedImageError and truncated-file errors subclass OSError.
        return None


def _preprocess_face(face):
    """Resize an MTCNN face crop to 256x256 and scale it to [0, 1] for the model.

    ``face`` is the (C, H, W) tensor returned by ``mtcnn``. Because MTCNN is
    built with ``post_process=False``, its values are assumed to be raw 0-255
    pixels (TODO confirm). Returns a float32 (1, C, 256, 256) tensor on DEVICE.
    """
    batch = face.unsqueeze(0)  # add the batch dimension expected by interpolate/model
    batch = F.interpolate(batch, size=(256, 256), mode='bilinear', align_corners=False)
    return batch.to(DEVICE, dtype=torch.float32) / 255.0


def _gradcam_overlay(face):
    """Blend the Grad-CAM heatmap for the model's single logit onto the face.

    ``face`` is the normalised (1, C, 256, 256) model input produced by
    ``_preprocess_face``. Returns a uint8 (256, 256, C) RGB array.
    """
    target_layers = [model.block8.branch1[-1]]
    # The context manager releases the hooks GradCAM attaches to the target
    # layer as soon as the heatmap has been computed.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(
            input_tensor=face,
            targets=[ClassifierOutputTarget(0)],
            eigen_smooth=True,
        )[0, :]
    # (H, W, C) float image in [0, 1], as show_cam_on_image requires.
    face_rgb = face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
    visualization = show_cam_on_image(face_rgb, grayscale_cam, use_rgb=True)
    # Rescale to 0-255 before casting: casting the [0, 1] values directly would
    # truncate nearly every pixel to 0 and leave only the heatmap in the blend.
    face_uint8 = np.clip(np.rint(face_rgb * 255.0), 0, 255).astype(np.uint8)
    return cv2.addWeighted(face_uint8, 1, visualization, 0.5, 0)


def predict(image_bytes):
    """Classify the face in an uploaded image as real or fake, with a Grad-CAM overlay.

    Reads ``true_label`` from the current Flask request's form data and echoes
    it back. A missing field makes Flask abort with 400 (BadRequestKeyError).

    Args:
        image_bytes: Raw bytes of the uploaded image file.

    Returns:
        A JSON response with:

        * ``confidences``: ``real``/``fake`` scores that sum to 1.
        * ``true_label``: the echoed form field.
        * ``prediction``: ``"real"`` or ``"fake"``.
        * ``face_with_mask``: the Grad-CAM overlay as a nested 256 x 256 x C
          list of uint8 values.

        Returns a ``(json_error, 400)`` tuple when the bytes are not a
        readable image or no face is detected.
    """
    true_label = request.form['true_label']

    image = _decode_image(image_bytes)
    if image is None:
        return jsonify({"error": "Could not decode image"}), 400

    face = mtcnn(image)
    if face is None:
        return jsonify({"error": "No face detected"}), 400

    face = _preprocess_face(face)
    face_with_mask = _gradcam_overlay(face)

    with torch.no_grad():
        # Single-logit head: the sigmoid of the logit is used as the "fake" score.
        fake_prediction = torch.sigmoid(model(face).squeeze(0)).item()

    confidences = {
        'real': 1 - fake_prediction,
        'fake': fake_prediction,
    }
    return jsonify({
        'confidences': confidences,
        'true_label': true_label,
        'prediction': "real" if fake_prediction < 0.5 else "fake",
        # numpy arrays are not JSON-serialisable; send nested lists instead.
        'face_with_mask': face_with_mask.tolist(),
    })
# No route was registered for this view, so Flask never served it.
@app.route("/")
def index():
    """Return the API's plain-text welcome message."""
    return "Welcome to DeepFake Detection API"
# NOTE(review): no route was registered for this view in the visible code, so
# Flask never dispatched to it. The "/predict" path is an assumption --
# confirm it matches what clients call.
@app.route("/predict", methods=["POST"])
def predict_endpoint():
    """Validate the multipart upload and run ``predict`` on its bytes.

    Expects the image in the ``file`` form field; ``predict`` additionally
    reads a ``true_label`` form field. Returns a 400 JSON error when the
    ``file`` part is missing or no file was selected.
    """
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['file']
    # Browsers submit an empty part with an empty filename when nothing was chosen.
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400
    return predict(file.read())
if __name__ == "__main__":
    # Listen on all interfaces so the server is reachable from outside its
    # container; 7860 is the default app port for Hugging Face Docker Spaces.
    # The Werkzeug debugger allows code execution from the browser and must not
    # be exposed on a public upload service, so it is opt-in via FLASK_DEBUG=1.
    app.run(
        host="0.0.0.0",
        port=7860,
        debug=os.environ.get("FLASK_DEBUG") == "1",
    )