# cyberhack / main.py
# Last change by achisingh06: "Update main.py" (commit 66453cf, verified)
import os
os.environ["MPLCONFIGDIR"] = "/home/user/app"
from flask import Flask, request, jsonify
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import numpy as np
from PIL import Image
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from io import BytesIO
app = Flask(__name__)

# Run on the first GPU when available, otherwise on CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Face detector/cropper used by predict(). post_process=False keeps raw pixel
# values; predict() scales them to [0, 1] itself.
mtcnn = MTCNN(
    select_largest=False,
    post_process=False,
    device=DEVICE,
).to(DEVICE).eval()

# NOTE(review): MODEL_PATH is not read anywhere in this file; it is kept only so
# any external code importing it keeps working. The weights that are actually
# loaded come from CHECKPOINT_PATH below.
MODEL_PATH = "20180402-114759-vggface2.pt"
# Fine-tuned real/fake classifier checkpoint loaded at startup.
CHECKPOINT_PATH = "resnetinceptionv1_epoch_32.pth"

model = InceptionResnetV1(
    pretrained=None,  # no download; weights come from CHECKPOINT_PATH
    classify=True,
    num_classes=1,  # single logit; predict() treats sigmoid >= 0.5 as "fake"
    device=DEVICE,
)
# Map to CPU first so a checkpoint saved on GPU also loads on CPU-only hosts.
# torch.load unpickles the file, so only point this at a trusted checkpoint.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)
model.eval()
def _preprocess_face(image):
    """Detect a face in ``image`` and turn it into a batched model input.

    Returns a float32 tensor on ``DEVICE`` resized to 256x256 with pixel values
    scaled to [0, 1], or ``None`` when MTCNN finds no face.
    """
    face = mtcnn(image)
    if face is None:
        return None
    face = face.unsqueeze(0)  # add the batch dimension
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
    # MTCNN was built with post_process=False, so its output is treated as
    # 0-255 pixels here and normalised to [0, 1].
    return face.to(DEVICE, dtype=torch.float32) / 255.0


def _gradcam_overlay(face):
    """Blend ``face`` with its Grad-CAM heatmap and return an HxWx3 uint8 image.

    ``face`` is the batched [0, 1] tensor produced by ``_preprocess_face``.
    """
    target_layers = [model.block8.branch1[-1]]
    # The context manager removes the hooks GradCAM registers on the shared
    # global model as soon as we are done, instead of waiting for GC.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(
            input_tensor=face,
            targets=[ClassifierOutputTarget(0)],
            eigen_smooth=True,
        )[0, :]
    face_rgb = face.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()  # CHW -> HWC, [0, 1]
    visualization = show_cam_on_image(face_rgb, grayscale_cam, use_rgb=True)
    # Scale back to 0-255 before the uint8 cast. Casting the [0, 1] floats
    # directly truncates nearly every pixel to 0, which left only the heatmap.
    face_uint8 = np.clip(np.rint(face_rgb * 255.0), 0, 255).astype(np.uint8)
    return cv2.addWeighted(face_uint8, 1, visualization, 0.5, 0)


def predict(image_bytes):
    """Classify an uploaded image as real/fake and build a Grad-CAM overlay.

    Reads ``true_label`` from the current Flask request's form data and echoes
    it back unchanged, so this must be called inside a request context.

    Args:
        image_bytes: Raw bytes of the uploaded image file.

    Returns:
        A Flask JSON response. On success the body contains ``confidences``
        (``real`` and ``fake`` probabilities summing to 1), ``true_label``,
        ``prediction`` (``"real"`` or ``"fake"``) and ``face_with_mask`` (the
        overlay image as nested lists of uint8 values). On bad input it returns
        ``({"error": ...}, 400)``.
    """
    true_label = request.form.get('true_label')
    if true_label is None:
        # JSON error, consistent with the other 400s, instead of Flask's
        # default HTML BadRequestKeyError page.
        return jsonify({"error": "Missing true_label"}), 400

    try:
        # convert("RGB") forces a full decode and normalises RGBA, greyscale
        # and palette images to the 3-channel input the face detector needs.
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
    except OSError:  # covers PIL's UnidentifiedImageError and truncated files
        return jsonify({"error": "Invalid image file"}), 400

    face = _preprocess_face(image)
    if face is None:
        return jsonify({"error": "No face detected"}), 400

    face_with_mask = _gradcam_overlay(face)

    with torch.no_grad():
        fake_prediction = torch.sigmoid(model(face).squeeze(0)).item()
    prediction = "real" if fake_prediction < 0.5 else "fake"
    confidences = {
        'real': 1 - fake_prediction,
        'fake': fake_prediction,
    }
    return jsonify({
        'confidences': confidences,
        'true_label': true_label,
        'prediction': prediction,
        # numpy arrays are not JSON-serialisable; send nested lists instead.
        'face_with_mask': face_with_mask.tolist(),
    })
@app.route("/")
def index():
    """Root route: reply with a plain-text banner so clients can see the API is up."""
    banner = "Welcome to DeepFake Detection API"
    return banner
@app.route("/predict", methods=['POST'])
def predict_endpoint():
    """POST /predict: validate the multipart upload, then delegate to predict().

    Expects the image in the ``file`` form part; responds with a JSON error and
    status 400 when that part is missing or has an empty filename.
    """
    upload = request.files.get('file')
    if upload is None:
        return jsonify({"error": "No file part"}), 400
    if upload.filename == '':
        return jsonify({"error": "No selected file"}), 400
    payload = upload.read()
    return predict(payload)
if __name__ == "__main__":
    # The Werkzeug interactive debugger allows arbitrary code execution from
    # the browser, so it is opt-in (FLASK_DEBUG=1/true/yes) instead of always on.
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    # 7860 stays the default port; PORT overrides it for other deployments.
    port = int(os.environ.get("PORT", "7860"))
    app.run(debug=debug, port=port)