Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,9 +9,11 @@ from PIL import Image
|
|
9 |
import base64
|
10 |
import torch
|
11 |
import torch.nn.functional as F
|
12 |
-
from transformers import ViTImageProcessor, SwinForImageClassification
|
13 |
import lightning as L
|
14 |
import uuid
|
|
|
|
|
15 |
# Set device
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
|
@@ -27,11 +29,7 @@ hyper_params = {
|
|
27 |
"label2id": label2id,
|
28 |
}
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
# Load the processor manually
|
33 |
-
vit_img_processor = AutoImageProcessor.from_pretrained('microsoft/swin-small-patch4-window7-224')
|
34 |
-
|
35 |
|
36 |
class DeepFakeModel(L.LightningModule):
|
37 |
def __init__(self, hyperparams: dict):
|
@@ -51,7 +49,7 @@ class DeepFakeModel(L.LightningModule):
|
|
51 |
|
52 |
# Load trained model
|
53 |
model = DeepFakeModel(hyper_params)
|
54 |
-
state_dict = torch.load("
|
55 |
model.load_state_dict(state_dict)
|
56 |
model.to(device)
|
57 |
model.eval()
|
@@ -84,8 +82,51 @@ def preprocess_image(img):
|
|
84 |
return img
|
85 |
|
86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
def predict_deepfake(image):
|
88 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
img_tensor = preprocess_image(image)
|
90 |
with torch.inference_mode():
|
91 |
logits = model(img_tensor)
|
@@ -104,17 +145,19 @@ def predict_deepfake(image):
|
|
104 |
raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
|
105 |
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
118 |
|
119 |
|
120 |
@app.post("/api/analyze-base64", response_model=AnalysisResult)
|
@@ -134,4 +177,7 @@ async def root():
|
|
134 |
return {"message": "DeepFake Detector API is running"}
|
135 |
|
136 |
if __name__ == "__main__":
|
137 |
-
|
|
|
|
|
|
|
|
9 |
import base64
|
10 |
import torch
|
11 |
import torch.nn.functional as F
|
12 |
+
from transformers import ViTImageProcessor, SwinForImageClassification
|
13 |
import lightning as L
|
14 |
import uuid
|
15 |
+
import cv2
|
16 |
+
|
17 |
# Set device
|
18 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
|
|
|
29 |
"label2id": label2id,
|
30 |
}
|
31 |
|
32 |
+
vit_img_processor = ViTImageProcessor.from_pretrained(hyper_params['MODEL_CKPT'])
|
|
|
|
|
|
|
|
|
33 |
|
34 |
class DeepFakeModel(L.LightningModule):
|
35 |
def __init__(self, hyperparams: dict):
|
|
|
49 |
|
50 |
# Load trained model
|
51 |
model = DeepFakeModel(hyper_params)
|
52 |
+
state_dict = torch.load("trained_model.pth", map_location=torch.device(device))
|
53 |
model.load_state_dict(state_dict)
|
54 |
model.to(device)
|
55 |
model.eval()
|
|
|
82 |
return img
|
83 |
|
84 |
|
85 |
+
|
86 |
+
# Load the face detector once
|
87 |
+
face_net = cv2.dnn.readNetFromCaffe(
|
88 |
+
"deploy.prototxt",
|
89 |
+
"res10_300x300_ssd_iter_140000.caffemodel"
|
90 |
+
)
|
91 |
+
|
92 |
+
def detect_face_opencv(image: Image.Image, *, min_confidence: float = 0.5) -> bool:
    """Report whether the OpenCV DNN face detector finds a face in ``image``.

    Args:
        image: PIL image in any mode; it is converted to RGB before detection
            so grayscale, palette and RGBA inputs are handled like RGB ones.
        min_confidence: A detection counts as a face only when its confidence
            score is strictly greater than this value.

    Returns:
        True if at least one detection exceeds ``min_confidence``; False if
        none does or if detection raised an error.
    """
    try:
        # Normalise the mode first: a 2-D (grayscale) array would break the
        # channel slice below, and a 4-channel (RGBA) array would not be the
        # 3-channel BGR input built for the detector.
        rgb_pixels = np.array(image.convert("RGB"))
        bgr_pixels = rgb_pixels[:, :, ::-1].copy()  # RGB -> BGR for OpenCV

        # blobFromImage resizes to 300x300 and subtracts the per-channel mean.
        # NOTE(review): the mean triple is presumably the one this Caffe model
        # was trained with -- confirm against the model's documentation.
        blob = cv2.dnn.blobFromImage(
            bgr_pixels, 1.0, (300, 300), (104.0, 177.0, 123.0)
        )

        face_net.setInput(blob)
        detections = face_net.forward()

        # Index 2 of the last axis holds the confidence of each detection.
        confidences = detections[0, 0, :, 2]
        return bool((confidences > min_confidence).any())
    except Exception as e:
        # Deliberate best-effort: a detector failure is reported as "no face"
        # rather than propagated to the caller.
        print(f"Face detection error: {e}")
        return False
|
116 |
+
|
117 |
def predict_deepfake(image):
|
118 |
try:
|
119 |
+
# Step 1: Face Detection
|
120 |
+
has_face = detect_face_opencv(image)
|
121 |
+
if not has_face:
|
122 |
+
return {
|
123 |
+
"id": str(uuid.uuid4()),
|
124 |
+
"isDeepfake": None,
|
125 |
+
"confidence": 0.0,
|
126 |
+
"details": "No face detected in the image. Cannot proceed with deepfake analysis."
|
127 |
+
}
|
128 |
+
|
129 |
+
# Step 2: Deepfake Prediction (your original logic)
|
130 |
img_tensor = preprocess_image(image)
|
131 |
with torch.inference_mode():
|
132 |
logits = model(img_tensor)
|
|
|
145 |
raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
|
146 |
|
147 |
|
148 |
+
|
149 |
+
|
150 |
+
# @app.post("/api/analyze", response_model=AnalysisResult)
|
151 |
+
# async def analyze_image(file: UploadFile = File(...)):
|
152 |
+
# if not file.content_type.startswith("image/"):
|
153 |
+
# raise HTTPException(status_code=400, detail="File must be an image")
|
154 |
+
# try:
|
155 |
+
# contents = await file.read()
|
156 |
+
# image = Image.open(io.BytesIO(contents)).convert("RGB")
|
157 |
+
# result = predict_deepfake(image)
|
158 |
+
# return JSONResponse(content=result)
|
159 |
+
# except Exception as e:
|
160 |
+
# raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
|
161 |
|
162 |
|
163 |
@app.post("/api/analyze-base64", response_model=AnalysisResult)
|
|
|
177 |
return {"message": "DeepFake Detector API is running"}
|
178 |
|
179 |
if __name__ == "__main__":
|
180 |
+
# TODO: remove this uvicorn.run call for Hugging Face Spaces (it is still active below)
|
181 |
+
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
|
182 |
+
|
183 |
+
|