# Page-header residue from the source listing ("Spaces: Sleeping"); not part of the program.
from flask import Flask, request, jsonify, render_template | |
from detectron2.config import get_cfg | |
from detectron2.engine import DefaultPredictor | |
from detectron2.data import MetadataCatalog | |
from detectron2.utils.visualizer import Visualizer, ColorMode | |
import numpy as np | |
from PIL import Image | |
import io | |
import os | |
import requests | |
import gdown | |
from skimage import io as skio | |
from torchvision.ops import box_iou | |
import torch | |
# Flask application object; it is started by app.run() in the __main__ guard.
app = Flask(import_name=__name__)
# Detectron2 config: assigned by setup_model() and read when mapping class IDs to names.
cfg = None
# Google Drive file ID of the model weights.
GDRIVE_FILE_ID = "1fzKneepaRt_--dzamTcDBM-9d3_dLX7z"
# Direct-download URL, derived from the ID so the two can never drift apart.
GDRIVE_MODEL_URL = f"https://drive.google.com/uc?id={GDRIVE_FILE_ID}"
# Local path the weights are downloaded to and loaded from.
LOCAL_MODEL_PATH = "model_final.pth"
def download_file_from_google_drive(id, destination):
    """Download the Google Drive file with ID *id* to the local path *destination*.

    Args:
        id: Google Drive file ID (the ``id=`` value of a share/download link).
            The name shadows the builtin but is kept for keyword callers.
        destination: Local file path to write to.

    Returns:
        The path returned by ``gdown.download``.

    Raises:
        RuntimeError: if ``gdown.download`` returns ``None`` (how some gdown
            versions signal failure; newer versions raise on their own).
    """
    url = f"https://drive.google.com/uc?id={id}"
    written = gdown.download(url, destination, quiet=False)
    if written is None:
        raise RuntimeError(f"Failed to download {url} to {destination}")
    return written


# Fetch the weights once and reuse the local copy on later starts. They must
# land at LOCAL_MODEL_PATH because that is the path passed to setup_model().
file_id = "1fzKneepaRt_--dzamTcDBM-9d3_dLX7z"
destination = LOCAL_MODEL_PATH
if not os.path.exists(destination):
    download_file_from_google_drive(file_id, destination)
# Download model from Google Drive if not already present locally
def download_model():
    """Ensure the model weights exist at LOCAL_MODEL_PATH, downloading them if absent.

    Uses gdown rather than a bare ``requests.get``. For large files, Google
    Drive may answer a ``uc?id=`` URL with an HTML confirmation page (HTTP 200).
    A plain GET would then save that page in place of the weights. The old
    code also buffered the whole body in memory despite ``stream=True``.

    Raises:
        RuntimeError: if the download does not produce LOCAL_MODEL_PATH.
    """
    if os.path.exists(LOCAL_MODEL_PATH):
        return
    written = gdown.download(GDRIVE_MODEL_URL, LOCAL_MODEL_PATH, quiet=False)
    if written is None or not os.path.exists(LOCAL_MODEL_PATH):
        raise RuntimeError(
            f"Failed to download model from Google Drive: {GDRIVE_MODEL_URL}"
        )
# Configuration and model setup
def setup_model(model_path, *, config_path="config.yaml", score_thresh=0.5, device="cpu"):
    """Build a Detectron2 ``DefaultPredictor`` for the weights at *model_path*.

    Side effect: stores the built config in the module-level ``cfg``, which the
    request handlers read to look up class names.

    Args:
        model_path: Path to the trained weights (assigned to ``MODEL.WEIGHTS``).
        config_path: Detectron2 YAML config merged into the default config.
        score_thresh: Value for ``MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
        device: Value for ``MODEL.DEVICE``, e.g. ``"cpu"`` or ``"cuda"``.

    Returns:
        A ``DefaultPredictor`` built from the resulting config.
    """
    global cfg
    cfg = get_cfg()
    cfg.merge_from_file(config_path)
    cfg.MODEL.WEIGHTS = model_path
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.DEVICE = device
    return DefaultPredictor(cfg)
# Ensure model is available (download_model() is a no-op if the file already
# exists), then build the predictor shared by the request handlers.
download_model()
predictor = setup_model(LOCAL_MODEL_PATH)
# Repair cost in USD per part; "other" is the price for any part not listed.
cost_dict = {
    "headlamp": 300,
    "rear_bumper": 250,
    "door": 200,
    "hood": 220,
    "front_bumper": 250,
    "other": 150,
}
# Parts that have their own price; any other detected class is billed as "other".
expected_parts = [part for part in cost_dict if part != "other"]
# Without a route registration Flask never dispatches requests to this view.
@app.route("/")
def home():
    """Render the ``index.html`` landing page."""
    return render_template("index.html")
# Exceptions PIL raises for data it cannot decode (UnidentifiedImageError and
# truncated-file errors are OSError subclasses; DecompressionBombError is not).
_IMAGE_ERRORS = (OSError, ValueError, Image.DecompressionBombError)


def _load_image_bgr(source):
    """Decode *source* into an ``(H, W, 3)`` uint8 array in BGR channel order.

    Any PIL mode (grayscale, RGBA, palette, ...) is converted to RGB first, so
    the predictor always receives exactly three channels. The channels are then
    reversed because Detectron2's ``DefaultPredictor`` takes BGR input.

    Args:
        source: Binary file-like object (an upload stream or ``io.BytesIO``).

    Raises:
        One of ``_IMAGE_ERRORS`` if the data is not a decodable image.
    """
    with Image.open(source) as img:
        rgb = np.asarray(img.convert("RGB"))
    return np.ascontiguousarray(rgb[:, :, ::-1])


def _keep_non_overlapping(boxes, iou_threshold=0.8):
    """Return indices of boxes kept after greedy overlap suppression.

    Boxes are visited in order. Each box not yet suppressed is kept and
    suppresses every later box whose IoU with it exceeds *iou_threshold*.
    Class labels are not considered.

    Args:
        boxes: ``(N, 4)`` float tensor of ``(x1, y1, x2, y2)`` boxes, the
            format ``torchvision.ops.box_iou`` expects.
        iou_threshold: IoU above which a later box counts as a duplicate.

    Returns:
        Kept indices in ascending order (empty for no boxes).
    """
    n = len(boxes)
    if n == 0:
        return []
    # One (N, N) IoU matrix instead of building tensors for every pair.
    ious = box_iou(boxes, boxes).numpy()
    suppressed = np.zeros(n, dtype=bool)
    keep = []
    for i in range(n):
        if suppressed[i]:
            continue
        keep.append(i)
        suppressed[i + 1:] |= ious[i, i + 1:] > iou_threshold
    return keep


def _estimate_damage(image_bgr, iou_threshold=0.8):
    """Run the detector on *image_bgr* and price each distinct detected part.

    Returns:
        ``{"damages": [{"part": str, "cost_usd": int}, ...], "total_cost": int}``
    """
    instances = predictor(image_bgr)["instances"].to("cpu")
    class_names = MetadataCatalog.get(cfg.DATASETS.TEST[0]).thing_classes
    keep = _keep_non_overlapping(instances.pred_boxes.tensor, iou_threshold)
    class_ids = instances.pred_classes.tolist()

    damage_details = []
    for idx in keep:
        class_id = class_ids[idx]
        part = class_names[class_id] if class_id < len(class_names) else "unknown"
        if part not in expected_parts:
            # Anything without its own price ("unknown" included) is billed as "other".
            part = "other"
        repair_cost = cost_dict.get(part, cost_dict["other"])
        damage_details.append({"part": part, "cost_usd": repair_cost})

    total_cost = sum(item["cost_usd"] for item in damage_details)
    return {"damages": damage_details, "total_cost": total_cost}


# NOTE(review): the original route paths are not visible here; confirm these
# match the URLs that templates/index.html submits to.
@app.route("/upload", methods=["POST"])
def upload():
    """Estimate repair cost for an uploaded image (multipart field ``file``).

    Returns:
        JSON from ``_estimate_damage``. Returns a 400 JSON error when no file
        was sent or the file is not a decodable image.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400
    file = request.files["file"]
    if file.filename == "":
        return jsonify({"error": "No file selected"}), 400
    try:
        image_bgr = _load_image_bgr(file.stream)
    except _IMAGE_ERRORS:
        return jsonify({"error": "Could not decode image"}), 400
    return jsonify(_estimate_damage(image_bgr))


@app.route("/fetchImage", methods=["POST"])
def fetchImage():
    """Estimate repair cost for an image given by URL or by upload.

    The form field ``url`` takes precedence over the multipart field ``file``.
    The response has the same JSON shape as ``upload``. A 400 JSON error is
    returned when neither input is given, the URL cannot be fetched, or the
    data is not a decodable image.
    """
    from urllib.parse import urlparse

    if "url" in request.form:
        url = request.form["url"]
        # Server-side fetching of user-supplied URLs is an SSRF risk. This only
        # restricts the scheme and bounds the wait; it does not block internal hosts.
        if urlparse(url).scheme not in ("http", "https"):
            return jsonify({"error": "URL must use http or https"}), 400
        try:
            fetched = requests.get(url, timeout=10)
            fetched.raise_for_status()
        except requests.RequestException:
            return jsonify({"error": "Could not fetch image from URL"}), 400
        source = io.BytesIO(fetched.content)
    elif "file" in request.files and request.files["file"].filename != "":
        source = request.files["file"].stream
    else:
        return jsonify({"error": "No image URL or file provided"}), 400

    try:
        image_bgr = _load_image_bgr(source)
    except _IMAGE_ERRORS:
        return jsonify({"error": "Could not decode image"}), 400
    return jsonify(_estimate_damage(image_bgr))
if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0), port 7860.
    host, port = "0.0.0.0", 7860
    app.run(host=host, port=port)