# deepscanAPI / app.py
# kautilya286's picture
# first commit
# b5b2f19
import os
import torch
import numpy as np
import joblib
from PIL import Image
from flask import Flask, request, jsonify
from transformers import CLIPProcessor, CLIPModel
from io import BytesIO
from flask_cors import CORS
import base64
import io
# --- Web application --------------------------------------------------------
# Create the Flask application and enable cross-origin requests on it.
app = Flask(__name__)
CORS(app)

# --- Model loading (performed once, when the module is imported) ------------
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"[INFO] Using device: {device}")

# CLIP model and its matching preprocessor, both from the same checkpoint.
# The model is moved to the selected device.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
model = model.to(device)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# Classifier applied to the CLIP embeddings. The path is relative to the
# current working directory.
# NOTE(review): joblib.load unpickles the file, so only a trusted model file
# should ever be placed at this path.
ensemble_clf = joblib.load("model/random_forest_tuned_aug.pkl")

# Maps an index into the classifier's probability vector to a label name.
label_map = dict(enumerate(("real", "deepfake", "ai_gen")))
def extract_features(image):
    """Compute the CLIP image embedding for one image.

    Args:
        image: Image object exposing ``resize`` (a PIL image in this file's
            only call site).

    Returns:
        The result of ``model.get_image_features`` as a NumPy array on the
        CPU, with size-1 dimensions squeezed out.
    """
    # Bring the image to 224x224 before handing it to the CLIP processor.
    resized = image.resize((224, 224))
    batch = processor(images=resized, return_tensors="pt").to(device)

    # Inference only: no gradients need to be tracked.
    with torch.no_grad():
        embedding = model.get_image_features(**batch)
        vector = embedding.cpu().numpy().squeeze()
    return vector
@app.route("/predict", methods=["POST"])
def predict():
    """Classify a base64-encoded image sent as JSON.

    Expects a JSON object body with an ``"image"`` key holding the image
    bytes encoded as base64. A ``data:<mediatype>;base64,`` prefix, as
    produced by browsers, is accepted and stripped.

    Returns:
        On success, a JSON object with ``"prediction"`` (a label from
        ``label_map``) and ``"probabilities"`` (the classifier's
        probability vector as a list).
        On a missing/invalid body or an undecodable image, a JSON object
        with an ``"error"`` message and HTTP status 400.
    """
    # silent=True yields None instead of raising when the body is absent,
    # malformed, or not sent as JSON, so every bad request ends in a 400.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'image' not in payload:
        return jsonify({"error": "No image provided"}), 400

    encoded = payload['image']
    # Drop a data-URL header; a valid base64 string can never start with
    # "data:" because ':' is outside the base64 alphabet.
    if isinstance(encoded, str) and encoded.startswith("data:"):
        encoded = encoded.partition(",")[2]

    try:
        image_data = base64.b64decode(encoded)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (TypeError, ValueError, OSError, Image.DecompressionBombError):
        # TypeError: non-string payload; ValueError: bad base64
        # (binascii.Error); OSError: unreadable or truncated image.
        return jsonify({"error": "Invalid image data"}), 400

    # Extract features and predict
    features = extract_features(image)
    probs = ensemble_clf.predict_proba([features])[0]
    # Convert the NumPy integer to a plain int before using it as a dict key.
    top_idx = int(np.argmax(probs))

    # Prepare response
    response = {
        "prediction": label_map[top_idx],
        "probabilities": probs.tolist()
    }
    return jsonify(response)
if __name__ == "__main__":
    # Run the Flask development server.
    # Debug mode exposes the interactive Werkzeug debugger, which lets a
    # client execute arbitrary code, so it is opt-in through the FLASK_DEBUG
    # environment variable rather than always enabled.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1", "true", "yes", "on",
    }
    app.run(debug=debug_enabled)