[email protected] commited on
Commit
edd48c9
·
1 Parent(s): d2207c4
Files changed (4) hide show
  1. Dockerfile +11 -0
  2. app.py +90 -0
  3. requirements.txt +7 -0
  4. resnetinceptionv1_epoch_32.pth +3 -0
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
+
9
+ COPY . .
10
+
11
+ CMD ["uvicorn", "-b", "0.0.0.0:7860", "main:app"]
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from fastapi import FastAPI, UploadFile, Form, File
3
+ from fastapi.responses import JSONResponse
4
+ from PIL import Image
5
+ import io
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from facenet_pytorch import MTCNN, InceptionResnetV1
9
+ import numpy as np
10
+ import cv2
11
+ from pytorch_grad_cam import GradCAM
12
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
13
+ from pytorch_grad_cam.utils.image import show_cam_on_image
14
+
15
+ app = FastAPI()
16
+
17
+ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
18
+
19
+ mtcnn = MTCNN(
20
+ select_largest=False,
21
+ post_process=False,
22
+ device=DEVICE
23
+ ).to(DEVICE).eval()
24
+
25
+ model = InceptionResnetV1(
26
+ pretrained="vggface2",
27
+ classify=True,
28
+ num_classes=1,
29
+ device=DEVICE
30
+ )
31
+
32
+ checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
33
+ model.load_state_dict(checkpoint['model_state_dict'])
34
+ model.to(DEVICE)
35
+ model.eval()
36
+
37
+ @app.get("/")
38
+ async def read_root():
39
+ return {"message": "Welcome to DeepFake Detection API"}
40
+
41
+ @app.post("/predict/")
42
+ async def predict(image: UploadFile = File(...), true_label: str = Form(...)):
43
+ try:
44
+ contents = await image.read()
45
+ pil_image = Image.open(io.BytesIO(contents))
46
+
47
+ face = mtcnn(pil_image)
48
+ if face is None:
49
+ return JSONResponse(status_code=400, content={"message": "No face detected"})
50
+
51
+ face = face.unsqueeze(0)
52
+ face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
53
+ face = face.to(DEVICE, dtype=torch.float32) / 255.0
54
+
55
+ target_layers=[model.block8.branch1[-1]]
56
+ cam = GradCAM(model=model, target_layers=target_layers)
57
+ targets = [ClassifierOutputTarget(0)]
58
+
59
+ grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
60
+ grayscale_cam = grayscale_cam[0, :]
61
+ visualization = show_cam_on_image(face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy(), grayscale_cam, use_rgb=True)
62
+ face_with_mask = cv2.addWeighted(face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy().astype('uint8'), 1, visualization, 0.5, 0)
63
+
64
+ with torch.no_grad():
65
+ output = torch.sigmoid(model(face).squeeze(0))
66
+ prediction = "real" if output.item() < 0.5 else "fake"
67
+
68
+ real_prediction = 1 - output.item()
69
+ fake_prediction = output.item()
70
+
71
+ confidences = {
72
+ 'real': real_prediction,
73
+ 'fake': fake_prediction
74
+ }
75
+
76
+ # Determine final prediction based on confidence scores
77
+ final_prediction = "real" if real_prediction > fake_prediction else "fake"
78
+
79
+ return {
80
+ 'confidences': confidences,
81
+ 'true_label': true_label,
82
+ 'final_prediction': final_prediction,
83
+ 'face_with_mask': face_with_mask.tolist() # Convert numpy array to list for JSON serialization
84
+ }
85
+ except Exception as e:
86
+ return JSONResponse(status_code=500, content={"message": str(e)})
87
+
88
+ if __name__ == "__main__":
89
+ import uvicorn
90
+ uvicorn.run(app, host="0.0.0.0", port=8000)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ Pillow
3
+ facenet-pytorch
4
+ torch
5
+ opencv-python
6
+ grad-cam
7
+ fastapi
resnetinceptionv1_epoch_32.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:794ebe83c6a7d7959c30c175030b4885e2b9fa175f1cc3e582236595d119f52b
3
+ size 282395989