|
|
import json
import os
import tempfile
import time

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, JSON
from speciesnet import DEFAULT_MODEL, SUPPORTED_MODELS, SpeciesNet
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Report which SpeciesNet checkpoints are available, then load the default one
# once at start-up; `inference` below reuses this single module-level instance.
print(f"Default SpeciesNet model: {DEFAULT_MODEL}")
print(f"Supported SpeciesNet models: {SUPPORTED_MODELS}")
model = SpeciesNet(DEFAULT_MODEL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_predictions_structure(pred):
    """
    Validate one prediction entry (detections + classifications).

    Checks that ``pred`` is a dict containing ``filepath``, ``detections`` and
    ``classifications``; that ``detections`` is a list of dicts, each carrying
    ``bbox`` (a 4-element list/tuple ``[x, y, w, h]``), ``conf`` and ``label``;
    and that ``classifications`` is a dict whose ``classes`` and ``scores`` are
    lists/tuples of equal length.

    Every structural problem is reported as ``ValueError`` (rather than letting
    a ``TypeError``/``KeyError`` escape from indexing a wrongly-typed value).

    Args:
        pred: One element of the model output's ``"predictions"`` list.

    Returns:
        True when the entry is well-formed.

    Raises:
        ValueError: On the first structural problem found.
    """
    # A non-dict (e.g. a string) would otherwise pass the `in` checks below by
    # substring matching and then fail with TypeError on indexing.
    if not isinstance(pred, dict):
        raise ValueError(" prediction block must be a dictionary")

    for key in ("filepath", "detections", "classifications"):
        if key not in pred:
            raise ValueError(f" Missing key '{key}' in prediction block")

    detections = pred["detections"]
    if not isinstance(detections, list):
        raise ValueError(" detections must be a list")

    for det in detections:
        if not isinstance(det, dict) or not all(k in det for k in ("bbox", "conf", "label")):
            raise ValueError(" Each detection must contain bbox, conf, label")

        # Require a real sequence: `len()` alone would accept a 4-char string
        # and raise TypeError on None.
        bbox = det["bbox"]
        if not isinstance(bbox, (list, tuple)) or len(bbox) != 4:
            raise ValueError(" bbox must be [x, y, w, h]")

    cls = pred["classifications"]
    if not isinstance(cls, dict):
        raise ValueError(" classifications must be a dictionary")

    for key in ("classes", "scores"):
        if key not in cls:
            raise ValueError(f" classifications missing '{key}'")

    classes, scores = cls["classes"], cls["scores"]
    if not isinstance(classes, (list, tuple)) or not isinstance(scores, (list, tuple)):
        raise ValueError(" classes and scores must be lists")

    if len(classes) != len(scores):
        raise ValueError(" classes and scores length mismatch")

    return True
|
|
|
|
|
|
|
|
|
|
|
def validate_model_output(predictions_dict):
    """
    Validate the entire output returned by SpeciesNet before visualization.

    Checks that the output is a dict with a top-level ``"predictions"`` list,
    then validates every entry with ``validate_predictions_structure``,
    printing progress as it goes.

    Args:
        predictions_dict: The object returned by ``model.predict``.

    Returns:
        True when the whole output is well-formed (mirrors
        ``validate_predictions_structure``).

    Raises:
        ValueError: If the output or any prediction entry is malformed.
    """
    # NOTE(review): model.predict may return something other than a dict (e.g.
    # None) in some configurations — confirm. Either way, reject it with a clear
    # ValueError instead of a TypeError from the `in` / indexing below.
    if not isinstance(predictions_dict, dict):
        raise ValueError(" Output must be a dictionary")

    if "predictions" not in predictions_dict:
        raise ValueError(" Output missing top-level 'predictions' key")

    predictions = predictions_dict["predictions"]
    if not isinstance(predictions, list):
        raise ValueError(" 'predictions' must be a list")

    print(f" Total prediction entries: {len(predictions)}")

    for number, pred in enumerate(predictions, start=1):
        print(f"\n--- Checking prediction #{number} ---")
        validate_predictions_structure(pred)

    print("\n Output format validated successfully!\n")
    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _top_classification(pred):
    """Return ``(class_string, score)`` for the first (top) class of *pred*, or None.

    A missing/None ``classifications`` dict, or an empty ``classes`` or
    ``scores`` list, yields None instead of raising.
    """
    classifications = pred.get("classifications") or {}
    classes = classifications.get("classes") or []
    scores = classifications.get("scores") or []
    if not classes or not scores:
        return None
    return classes[0], scores[0]


def _is_mammal_or_bird(class_string):
    """Return True if the ';'-separated class string has a ``mammalia`` or ``aves`` field.

    Whole fields are compared case-insensitively, so ``aves`` occurring merely
    as a substring of some other field does not count as a match.
    """
    fields = {field.strip().lower() for field in class_string.split(";")}
    return "mammalia" in fields or "aves" in fields


def _label_band(y1, band_height):
    """Return ``(top, bottom)`` pixel rows for a label band attached to a box's top edge *y1*.

    The band sits directly above the box when there is room; otherwise it is
    placed just inside the box (from its top edge, clamped to row 0) so the
    label text stays on the image instead of being drawn at negative rows.
    """
    if y1 >= band_height:
        return y1 - band_height, y1
    top = max(y1, 0)
    return top, top + band_height


def draw_predictions(image_path, predictions_dict):
    """
    Draw detection boxes with classification/detection labels onto an image.

    Only predictions whose top class is a mammal or bird (a ``mammalia`` or
    ``aves`` field in the ';'-separated class string) are drawn. For those,
    every detection box is outlined in green and labelled with the top class's
    last ';' field and score, plus the detection's label and confidence.

    Args:
        image_path: Path of the image to annotate (read with ``cv2.imread``).
        predictions_dict: Model output with a top-level ``"predictions"`` list.
            Each entry may carry ``"detections"`` (dicts with ``bbox`` =
            ``[x, y, w, h]`` as fractions of image width/height, ``conf``,
            ``label``) and ``"classifications"`` (``classes``/``scores`` lists).

    Returns:
        The annotated image converted to RGB channel order.

    Raises:
        ValueError: If the image cannot be loaded.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale, thickness = 0.6, 2
    green, black = (0, 255, 0), (0, 0, 0)
    line_gap = 5  # vertical gap between text lines; also the horizontal text inset
    band_padding = 10  # extra width/height around the text block

    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Could not load image: {image_path}")

    img_h, img_w = img.shape[:2]

    for pred in predictions_dict.get("predictions", []):
        top = _top_classification(pred)
        if top is None:
            continue
        class_string, score = top
        if not _is_mammal_or_bird(class_string):
            continue

        common_name = class_string.split(";")[-1]
        # The classification belongs to the prediction, not to a single box,
        # so its label text is built once and reused for every detection.
        class_text = f"{common_name} ({score:.2f})" if common_name else ""

        for det in pred.get("detections") or []:
            x, y, w, h = det["bbox"]
            x1, y1 = int(x * img_w), int(y * img_h)
            x2, y2 = int((x + w) * img_w), int((y + h) * img_h)
            cv2.rectangle(img, (x1, y1), (x2, y2), green, 3)

            text_lines = [class_text] if class_text else []
            text_lines.append(f"{det['label']} ({det['conf']:.2f})")

            # Measure each line once; reused for both band size and line spacing.
            sizes = [cv2.getTextSize(line, font, font_scale, thickness)[0] for line in text_lines]

            band_height = sum(text_h + line_gap for _, text_h in sizes) + band_padding
            band_width = max(text_w for text_w, _ in sizes) + band_padding
            band_top, band_bottom = _label_band(y1, band_height)
            cv2.rectangle(img, (x1, band_top), (x1 + band_width, band_bottom), green, -1)

            # Draw bottom-up: the detection line sits just above the band's bottom edge.
            baseline = band_bottom - line_gap
            for line, (_, text_h) in zip(reversed(text_lines), reversed(sizes)):
                cv2.putText(
                    img,
                    line,
                    (x1 + line_gap, baseline),
                    font,
                    font_scale,
                    black,
                    thickness,
                    cv2.LINE_AA,
                )
                baseline -= text_h + line_gap

    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference(image):
    """
    Run SpeciesNet on one uploaded image; return the annotated image and raw JSON.

    The PIL image is written to a per-call temporary JPEG (the model is given a
    file path). The model output is validated, saved to ``last_output.json`` and
    drawn onto the image with ``draw_predictions``. The temporary file is always
    removed afterwards, even if a step fails.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        ``(annotated_image, pretty_json)``: the RGB annotated image and the
        model output serialized as indented JSON text.

    Raises:
        gr.Error: If no image was provided (shown to the user by Gradio).
        ValueError: If the model output fails validation.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # A unique temp file per call: calls cannot overwrite each other's input,
    # and the upload is not left behind in the working directory.
    fd, filepath = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        # JPEG cannot store alpha/palette modes (e.g. RGBA, P); normalise first.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")
        image.save(filepath)

        start = time.perf_counter()
        predictions_dict = model.predict(
            instances_dict={
                "instances": [
                    {
                        "filepath": filepath,
                    }
                ]
            }
        )
        elapsed = time.perf_counter() - start
        print(f"\n⏱ Inference Time: {elapsed:.2f} sec")

        validate_model_output(predictions_dict)

        with open("last_output.json", "w", encoding="utf-8") as f:
            json.dump(predictions_dict, f, indent=4)
        print(" Saved JSON to last_output.json\n")

        # Must run before the temp file is deleted: it re-reads the image from disk.
        annotated_image = draw_predictions(filepath, predictions_dict)
    finally:
        os.remove(filepath)

    pretty_json = json.dumps(predictions_dict, indent=4)
    return annotated_image, pretty_json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Web UI: one PIL image in; the annotated image and the raw model JSON out.
# Title/description are plain literals, so listing them first leaves the
# construction order of the Gradio components unchanged.
iface = gr.Interface(
    title=" SpeciesNet Wildlife Detector + Classifier",
    description="Upload a wildlife camera image.",
    fn=inference,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="Detection + Classification Output"),
        gr.JSON(label="Raw Model Output"),
    ],
)
iface.launch()
|
|
|