pujanpaudel commited on
Commit
8ce99dc
·
verified ·
1 Parent(s): b840426

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -19
app.py CHANGED
@@ -9,9 +9,11 @@ from PIL import Image
9
  import base64
10
  import torch
11
  import torch.nn.functional as F
12
- from transformers import ViTImageProcessor, SwinForImageClassification,AutoImageProcessor
13
  import lightning as L
14
  import uuid
 
 
15
  # Set device
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
@@ -27,11 +29,7 @@ hyper_params = {
27
  "label2id": label2id,
28
  }
29
 
30
-
31
-
32
- # Load the processor manually
33
- vit_img_processor = AutoImageProcessor.from_pretrained('microsoft/swin-small-patch4-window7-224')
34
-
35
 
36
  class DeepFakeModel(L.LightningModule):
37
  def __init__(self, hyperparams: dict):
@@ -51,7 +49,7 @@ class DeepFakeModel(L.LightningModule):
51
 
52
  # Load trained model
53
  model = DeepFakeModel(hyper_params)
54
- state_dict = torch.load("deepfake_new_trained.pth", map_location=torch.device(device))
55
  model.load_state_dict(state_dict)
56
  model.to(device)
57
  model.eval()
@@ -84,8 +82,51 @@ def preprocess_image(img):
84
  return img
85
 
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def predict_deepfake(image):
88
  try:
 
 
 
 
 
 
 
 
 
 
 
89
  img_tensor = preprocess_image(image)
90
  with torch.inference_mode():
91
  logits = model(img_tensor)
@@ -104,17 +145,19 @@ def predict_deepfake(image):
104
  raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
105
 
106
 
107
- @app.post("/api/analyze", response_model=AnalysisResult)
108
- async def analyze_image(file: UploadFile = File(...)):
109
- if not file.content_type.startswith("image/"):
110
- raise HTTPException(status_code=400, detail="File must be an image")
111
- try:
112
- contents = await file.read()
113
- image = Image.open(io.BytesIO(contents)).convert("RGB")
114
- result = predict_deepfake(image)
115
- return JSONResponse(content=result)
116
- except Exception as e:
117
- raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
 
 
118
 
119
 
120
  @app.post("/api/analyze-base64", response_model=AnalysisResult)
@@ -134,4 +177,7 @@ async def root():
134
  return {"message": "DeepFake Detector API is running"}
135
 
136
  if __name__ == "__main__":
137
- uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
 
 
 
9
  import base64
10
  import torch
11
  import torch.nn.functional as F
12
+ from transformers import ViTImageProcessor, SwinForImageClassification
13
  import lightning as L
14
  import uuid
15
+ import cv2
16
+
17
  # Set device
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
 
29
  "label2id": label2id,
30
  }
31
 
32
+ vit_img_processor = ViTImageProcessor.from_pretrained(hyper_params['MODEL_CKPT'])
 
 
 
 
33
 
34
  class DeepFakeModel(L.LightningModule):
35
  def __init__(self, hyperparams: dict):
 
49
 
50
  # Load trained model
51
  model = DeepFakeModel(hyper_params)
52
+ state_dict = torch.load("trained_model.pth", map_location=torch.device(device))
53
  model.load_state_dict(state_dict)
54
  model.to(device)
55
  model.eval()
 
82
  return img
83
 
84
 
85
+
86
+ # Load the face detector once
87
+ face_net = cv2.dnn.readNetFromCaffe(
88
+ "deploy.prototxt",
89
+ "res10_300x300_ssd_iter_140000.caffemodel"
90
+ )
91
+
92
+ def detect_face_opencv(image: Image.Image) -> bool:
93
+ """Detect face using OpenCV DNN"""
94
+ try:
95
+ # Convert PIL Image to OpenCV format
96
+ open_cv_image = np.array(image)
97
+ open_cv_image = open_cv_image[:, :, ::-1].copy() # RGB to BGR
98
+
99
+ (h, w) = open_cv_image.shape[:2]
100
+ blob = cv2.dnn.blobFromImage(open_cv_image, 1.0, (300, 300),
101
+ (104.0, 177.0, 123.0))
102
+
103
+ face_net.setInput(blob)
104
+ detections = face_net.forward()
105
+
106
+ # Check if any detection has confidence > 0.5
107
+ for i in range(detections.shape[2]):
108
+ confidence = detections[0, 0, i, 2]
109
+ if confidence > 0.5:
110
+ return True # Face detected
111
+
112
+ return False # No face detected
113
+ except Exception as e:
114
+ print(f"Face detection error: {e}")
115
+ return False # Fail safe: assume no face
116
+
117
  def predict_deepfake(image):
118
  try:
119
+ # Step 1: Face Detection
120
+ has_face = detect_face_opencv(image)
121
+ if not has_face:
122
+ return {
123
+ "id": str(uuid.uuid4()),
124
+ "isDeepfake": None,
125
+ "confidence": 0.0,
126
+ "details": "No face detected in the image. Cannot proceed with deepfake analysis."
127
+ }
128
+
129
+ # Step 2: Deepfake Prediction (your original logic)
130
  img_tensor = preprocess_image(image)
131
  with torch.inference_mode():
132
  logits = model(img_tensor)
 
145
  raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
146
 
147
 
148
+
149
+
150
+ # @app.post("/api/analyze", response_model=AnalysisResult)
151
+ # async def analyze_image(file: UploadFile = File(...)):
152
+ # if not file.content_type.startswith("image/"):
153
+ # raise HTTPException(status_code=400, detail="File must be an image")
154
+ # try:
155
+ # contents = await file.read()
156
+ # image = Image.open(io.BytesIO(contents)).convert("RGB")
157
+ # result = predict_deepfake(image)
158
+ # return JSONResponse(content=result)
159
+ # except Exception as e:
160
+ # raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
161
 
162
 
163
  @app.post("/api/analyze-base64", response_model=AnalysisResult)
 
177
  return {"message": "DeepFake Detector API is running"}
178
 
179
  if __name__ == "__main__":
180
+ # Remove uvicorn.run for Hugging Face Spaces
181
+ uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
182
+
183
+