codewithRiz's picture
Update app.py
6ca6607 verified
from IPython.display import display, JSON
import matplotlib.pyplot as plt
from speciesnet import DEFAULT_MODEL, SUPPORTED_MODELS, SpeciesNet
import numpy as np
import time
import gradio as gr
import json
import cv2
import os
# ------------------------------------------------------
# LOAD MODEL
# ------------------------------------------------------
# Instantiate the SpeciesNet model once at import time; `inference` below
# reuses this module-level `model` for every request.
_default_model_name = DEFAULT_MODEL
print("Default SpeciesNet model:", _default_model_name)
print("Supported SpeciesNet models:", SUPPORTED_MODELS)
model = SpeciesNet(_default_model_name)
# ------------------------------------------------------
# VALIDATION FUNCTIONS
# ------------------------------------------------------
def _safe_len(value):
    """Return len(value), or None if value has no length (e.g. None, a number)."""
    try:
        return len(value)
    except TypeError:
        return None


def validate_predictions_structure(pred):
    """
    Validate the structure of a single prediction block (detections + classifications).

    Ensures the keys used by the visualization exist and have usable
    container types, so malformed output is reported as a ValueError here
    rather than surfacing later as a TypeError/KeyError.

    Args:
        pred: One entry of ``predictions_dict["predictions"]``.

    Returns:
        True when the block is well-formed.

    Raises:
        ValueError: If a required key is missing or has the wrong type/size.
    """
    if not isinstance(pred, dict):
        raise ValueError(" prediction block must be a dictionary")
    for key in ("filepath", "detections", "classifications"):
        if key not in pred:
            raise ValueError(f" Missing key '{key}' in prediction block")

    # --- Validate detections (list of dicts) ---
    detections = pred["detections"]
    if not isinstance(detections, list):
        raise ValueError(" detections must be a list")
    for det in detections:
        # Without the dict check, a string detection would turn the `in`
        # test into a substring check and then fail on det["bbox"].
        if not isinstance(det, dict) or not all(
            k in det for k in ("bbox", "conf", "label")
        ):
            raise ValueError(" Each detection must contain bbox, conf, label")
        if _safe_len(det["bbox"]) != 4:
            raise ValueError(" bbox must be [x, y, w, h]")

    # --- Validate classifications ---
    cls = pred["classifications"]
    if not isinstance(cls, dict):
        raise ValueError(" classifications must be a dictionary")
    for key in ("classes", "scores"):
        if key not in cls:
            raise ValueError(f" classifications missing '{key}'")
    n_classes = _safe_len(cls["classes"])
    n_scores = _safe_len(cls["scores"])
    if n_classes is None or n_scores is None:
        raise ValueError(" classes and scores must be lists")
    if n_classes != n_scores:
        raise ValueError(" classes and scores length mismatch")
    return True
def validate_model_output(predictions_dict):
    """
    Validate the entire output returned by SpeciesNet before visualization.

    Args:
        predictions_dict: Object returned by ``model.predict``; expected to be
            a dict with a top-level ``"predictions"`` list.

    Returns:
        True when every prediction block passes
        ``validate_predictions_structure``.

    Raises:
        ValueError: If the output, or any prediction block, is malformed.
    """
    # Reject None / non-dict results explicitly; otherwise the `in` test
    # below would raise an unhelpful TypeError.
    if not isinstance(predictions_dict, dict):
        raise ValueError(" Model output must be a dictionary")
    if "predictions" not in predictions_dict:
        raise ValueError(" Output missing top-level 'predictions' key")
    predictions = predictions_dict["predictions"]
    if not isinstance(predictions, list):
        raise ValueError(" 'predictions' must be a list")
    print(f" Total prediction entries: {len(predictions)}")
    # Validate each prediction block
    for i, pred in enumerate(predictions, start=1):
        print(f"\n--- Checking prediction #{i} ---")
        validate_predictions_structure(pred)
    print("\n Output format validated successfully!\n")
    return True
# ------------------------------------------------------
# VISUALIZATION
# ------------------------------------------------------
def _is_allowed_taxon(label, allowed_taxa):
    """Return True if any entry of ``allowed_taxa`` is a substring of the lower-cased label."""
    taxon = label.lower()
    return any(t in taxon for t in allowed_taxa)


def _bbox_to_pixels(bbox, img_w, img_h):
    """Convert a fractional [x, y, w, h] bbox into integer pixel corners (x1, y1, x2, y2)."""
    x, y, w, h = bbox
    return (
        int(x * img_w),
        int(y * img_h),
        int((x + w) * img_w),
        int((y + h) * img_h),
    )


def _draw_label(img, x1, y1, text_lines):
    """Draw ``text_lines`` stacked on a filled green box anchored at the bbox corner (x1, y1)."""
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2
    gap, pad = 5, 10
    img_h, img_w = img.shape[:2]

    # Measure every line once and reuse the sizes for layout and drawing.
    sizes = [cv2.getTextSize(line, font, scale, thickness)[0] for line in text_lines]
    total_text_height = sum(text_h + gap for _, text_h in sizes)
    max_text_width = max(text_w for text_w, _ in sizes)
    label_h = total_text_height + pad

    # The label normally sits just above the box's top edge. If the box is too
    # close to the top of the image, place it just below that edge instead;
    # otherwise the text would be drawn at negative y and be invisible.
    bottom = y1 if y1 >= label_h else y1 + label_h
    # Shift left if needed so the label is not clipped at the right border.
    left = max(0, min(x1, img_w - max_text_width - pad))

    cv2.rectangle(
        img,
        (left, bottom - label_h),
        (left + max_text_width + pad, bottom),
        (0, 255, 0),
        -1,
    )
    # Draw bottom-up: the last line sits lowest, earlier lines stack above.
    y_text = bottom - gap
    for line, (_, text_h) in zip(reversed(text_lines), reversed(sizes)):
        cv2.putText(
            img,
            line,
            (left + gap, y_text),
            font,
            scale,
            (0, 0, 0),
            thickness,
            cv2.LINE_AA,
        )
        y_text -= text_h + gap


def draw_predictions(
    image_path,
    predictions_dict,
    *,
    allowed_taxa=("mammalia", "aves"),
    min_conf=0.0,
):
    """
    Draw detections and the top classification onto the image at ``image_path``.

    Only prediction blocks whose top classification label contains one of
    ``allowed_taxa`` (case-insensitive substring match) are drawn; other
    blocks, and blocks with no classes, are skipped. Every detection in a
    kept block is boxed and labelled with the block's top class/score plus
    the detection's own ``label``/``conf``.

    Args:
        image_path: Path of the image file to annotate.
        predictions_dict: Model output with a ``"predictions"`` list. Bboxes
            are scaled by image width/height, i.e. treated as fractional
            [x, y, w, h].
        allowed_taxa: Substrings of the top class label that mark a block as
            drawable. The default keeps the original mammal/bird filter.
        min_conf: Detections whose ``conf`` is below this are not drawn.
            The default of 0.0 draws every detection.

    Returns:
        The annotated image as an RGB array (converted from OpenCV's BGR).

    Raises:
        ValueError: If the image cannot be loaded.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Could not load image: {image_path}")
    img_h, img_w = img.shape[:2]

    for pred in predictions_dict.get("predictions", []):
        classifications = pred.get("classifications", {})
        classes = classifications.get("classes", [])
        scores = classifications.get("scores", [])

        # SKIP NON-ANIMALS (and blocks without any classification).
        if not classes or not _is_allowed_taxon(classes[0], allowed_taxa):
            continue

        # The classification belongs to the whole prediction block, so its
        # label text is computed once and shared by every detection.
        top_class_name = classes[0].split(";")[-1]
        top_score = scores[0] if scores else None
        class_lines = []
        if top_class_name:
            class_lines.append(
                f"{top_class_name} ({top_score:.2f})"
                if top_score is not None
                else top_class_name
            )

        for det in pred.get("detections", []):
            conf = det["conf"]
            if conf < min_conf:
                continue
            x1, y1, x2, y2 = _bbox_to_pixels(det["bbox"], img_w, img_h)
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 3)
            _draw_label(img, x1, y1, class_lines + [f"{det['label']} ({conf:.2f})"])

    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# ------------------------------------------------------
# INFERENCE FUNCTION
# ------------------------------------------------------
def inference(image):
    """
    Run SpeciesNet on an uploaded image; return (annotated image, JSON text).

    The PIL image is written to a per-request temporary JPEG because the model
    is given a file path. The model runs on that file, the output is validated
    and drawn, and the raw output is also saved to ``last_output.json``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input; guarded
            against None (no upload).

    Returns:
        Tuple of (annotated RGB image array, pretty-printed JSON string).

    Raises:
        gr.Error: If no image was provided.
        ValueError: If the model output fails validation.
    """
    import tempfile  # only this function needs temporary files

    if image is None:
        raise gr.Error("Please upload an image.")

    # Use a unique file per request. A fixed name lets concurrent requests
    # overwrite each other's input, and it was never cleaned up.
    fd, filepath = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        # JPEG cannot store alpha/palette modes (e.g. RGBA PNG uploads).
        if image.mode != "RGB":
            image = image.convert("RGB")
        image.save(filepath)

        start = time.perf_counter()
        predictions_dict = model.predict(
            instances_dict={
                "instances": [
                    {
                        "filepath": filepath,
                        # "country": "VNM",
                    }
                ]
            }
        )
        print(f"\n⏱ Inference Time: {time.perf_counter() - start:.2f} sec")

        # --- Validate format ---
        validate_model_output(predictions_dict)

        # --- Save JSON ---
        with open("last_output.json", "w", encoding="utf-8") as f:
            json.dump(predictions_dict, f, indent=4)
        print(" Saved JSON to last_output.json\n")

        # --- Draw Visualization (must happen before the file is removed) ---
        annotated_image = draw_predictions(filepath, predictions_dict)
    finally:
        # Best-effort cleanup; a failed delete must not mask the real result.
        try:
            os.remove(filepath)
        except OSError:
            pass

    pretty_json = json.dumps(predictions_dict, indent=4)
    return annotated_image, pretty_json
# ------------------------------------------------------
# GRADIO UI
# ------------------------------------------------------
# Components are created in the same order as before (input first, then
# outputs). The outputs match the tuple `inference` returns:
# (annotated image, JSON text).
_input_component = gr.Image(type="pil")
_output_components = [
    gr.Image(label="Detection + Classification Output"),
    gr.JSON(label="Raw Model Output"),
]
iface = gr.Interface(
    fn=inference,
    inputs=_input_component,
    outputs=_output_components,
    title=" SpeciesNet Wildlife Detector + Classifier",
    description="Upload a wildlife camera image.",
)
iface.launch()