"""Gradio demo: run SpeciesNet on an uploaded image and show the annotated result.

The script loads the default SpeciesNet model at import time, validates the
structure of the model output, draws detections and the top classification on
the image, and serves everything through a small Gradio interface.
"""

import json
import os
import tempfile
import time

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import JSON, display
from speciesnet import DEFAULT_MODEL, SUPPORTED_MODELS, SpeciesNet

# ------------------------------------------------------
# LOAD MODEL
# ------------------------------------------------------
print("Default SpeciesNet model:", DEFAULT_MODEL)
print("Supported SpeciesNet models:", SUPPORTED_MODELS)

model = SpeciesNet(DEFAULT_MODEL)

# Text rendering settings shared by every label drawn on the image.
_FONT = cv2.FONT_HERSHEY_SIMPLEX
_FONT_SCALE = 0.6
_FONT_THICKNESS = 2
_BOX_COLOR = (0, 255, 0)
_TEXT_COLOR = (0, 0, 0)
_LINE_GAP = 5  # vertical pixels added below each text line
_LABEL_PADDING = 10  # extra pixels around the text block


# ------------------------------------------------------
# VALIDATION FUNCTIONS
# ------------------------------------------------------
def validate_predictions_structure(pred):
    """Validate one prediction entry (detections and classifications).

    Args:
        pred: Mapping expected to contain the keys ``filepath``,
            ``detections`` and ``classifications``.

    Returns:
        ``True`` when the entry is well-formed.

    Raises:
        ValueError: If a required key is missing, ``detections`` is not a
            list of dicts holding ``bbox``/``conf``/``label`` with a
            4-element ``bbox``, or ``classifications`` is not a dict with
            equally long ``classes`` and ``scores``.
    """
    for key in ("filepath", "detections", "classifications"):
        if key not in pred:
            raise ValueError(f" Missing key '{key}' in prediction block")

    # --- Validate detections (list of dicts) ---
    detections = pred["detections"]
    if not isinstance(detections, list):
        raise ValueError(" detections must be a list")

    for det in detections:
        # Reject non-dict entries explicitly; otherwise the key lookups below
        # would fail with an unrelated TypeError.
        if not isinstance(det, dict):
            raise ValueError(" Each detection must be a dictionary")
        if not all(k in det for k in ("bbox", "conf", "label")):
            raise ValueError(" Each detection must contain bbox, conf, label")
        if len(det["bbox"]) != 4:
            raise ValueError(" bbox must be [x, y, w, h]")

    # --- Validate classifications ---
    cls = pred["classifications"]
    if not isinstance(cls, dict):
        raise ValueError(" classifications must be a dictionary")

    for key in ("classes", "scores"):
        if key not in cls:
            raise ValueError(f" classifications missing '{key}'")

    if len(cls["classes"]) != len(cls["scores"]):
        raise ValueError(" classes and scores length mismatch")

    return True


def validate_model_output(predictions_dict):
    """Validate the entire output returned by SpeciesNet before visualization.

    Prints progress for each entry and a final confirmation message.

    Args:
        predictions_dict: Mapping with a top-level ``predictions`` list.

    Raises:
        ValueError: If ``predictions`` is missing, is not a list, or any
            entry fails ``validate_predictions_structure``.
    """
    if "predictions" not in predictions_dict:
        raise ValueError(" Output missing top-level 'predictions' key")

    predictions = predictions_dict["predictions"]
    if not isinstance(predictions, list):
        raise ValueError(" 'predictions' must be a list")

    print(f" Total prediction entries: {len(predictions)}")

    # Validate each prediction block
    for number, pred in enumerate(predictions, start=1):
        print(f"\n--- Checking prediction #{number} ---")
        validate_predictions_structure(pred)

    print("\n Output format validated successfully!\n")


# ------------------------------------------------------
# VISUALIZATION
# ------------------------------------------------------
def _draw_label(img, text_lines, x1, y1):
    """Draw a filled label with ``text_lines`` stacked above the point (x1, y1)."""
    sizes = [
        cv2.getTextSize(line, _FONT, _FONT_SCALE, _FONT_THICKNESS)[0]
        for line in text_lines
    ]
    total_text_height = sum(text_h + _LINE_GAP for _, text_h in sizes)
    max_text_width = max(text_w for text_w, _ in sizes)

    # Normally the label sits directly above the box. When the box touches the
    # top of the image there is no room, so pin the label to the top edge and
    # let it extend downwards instead of clipping the text off-image.
    top = max(y1 - total_text_height - _LABEL_PADDING, 0)
    bottom = max(y1, top + total_text_height + _LABEL_PADDING)

    cv2.rectangle(
        img,
        (x1, top),
        (x1 + max_text_width + _LABEL_PADDING, bottom),
        _BOX_COLOR,
        -1,
    )

    # Lines are drawn bottom-up, so walk them in reverse order.
    y_text = bottom - _LINE_GAP
    for line, (_, text_h) in zip(reversed(text_lines), reversed(sizes)):
        cv2.putText(
            img,
            line,
            (x1 + 5, y_text),
            _FONT,
            _FONT_SCALE,
            _TEXT_COLOR,
            _FONT_THICKNESS,
            cv2.LINE_AA,
        )
        y_text -= text_h + _LINE_GAP


def draw_predictions(image_path, predictions_dict):
    """Draw detection boxes and labels for animal predictions on an image.

    Only predictions whose top class string contains ``mammalia`` or ``aves``
    (case-insensitive) are drawn. Bounding boxes are read as
    ``[x, y, w, h]`` and scaled by the image width/height, i.e. they are
    treated as fractions of the image size.

    Args:
        image_path: Path of the image to load with OpenCV.
        predictions_dict: Model output with an optional ``predictions`` list.

    Returns:
        The annotated image converted from BGR to RGB.

    Raises:
        ValueError: If the image cannot be loaded.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Could not load image: {image_path}")

    img_h, img_w, _ = img.shape

    for pred in predictions_dict.get("predictions", []):
        detections = pred.get("detections", [])
        classifications = pred.get("classifications", {})

        classes = classifications.get("classes", [])
        scores = classifications.get("scores", [])

        # SKIP NON-ANIMALS (and entries without a usable top class/score)
        if not classes or not scores:
            continue

        taxon = classes[0].lower()
        if "mammalia" not in taxon and "aves" not in taxon:
            continue

        # The displayed name is the last ';'-separated field of the class string.
        top_class_name = classes[0].split(";")[-1]
        top_score = scores[0]
        classification_text = (
            f"{top_class_name} ({top_score:.2f})" if top_class_name else ""
        )

        for det in detections:
            x, y, w, h = det["bbox"]
            conf = det["conf"]
            label = det["label"]

            x1 = int(x * img_w)
            y1 = int(y * img_h)
            x2 = int((x + w) * img_w)
            y2 = int((y + h) * img_h)

            cv2.rectangle(img, (x1, y1), (x2, y2), _BOX_COLOR, 3)

            text_lines = []
            if classification_text:
                text_lines.append(classification_text)
            text_lines.append(f"{label} ({conf:.2f})")

            _draw_label(img, text_lines, x1, y1)

    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


# ------------------------------------------------------
# INFERENCE FUNCTION
# ------------------------------------------------------
def inference(image):
    """Run SpeciesNet on an uploaded image and return the annotated result.

    The image is written to a unique temporary JPEG (removed afterwards), the
    model output is validated and saved to ``last_output.json``, and the
    detections are drawn on the image.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when nothing was uploaded.

    Returns:
        A tuple ``(annotated_image, pretty_json)``: the RGB image array with
        drawn predictions and the model output as an indented JSON string.

    Raises:
        gr.Error: If no image was provided.
        ValueError: If the model output fails validation or the saved image
            cannot be read back.
    """
    if image is None:
        raise gr.Error("Please upload an image before running inference.")

    # JPEG cannot store alpha or palette data; saving e.g. an RGBA PNG upload
    # would raise OSError, so normalise everything except RGB/greyscale.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    # A unique temp file per request keeps concurrent requests from
    # overwriting each other's input image.
    fd, filepath = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)

    try:
        image.save(filepath)

        start = time.perf_counter()
        predictions_dict = model.predict(
            instances_dict={
                "instances": [
                    {
                        "filepath": filepath,
                        # "country": "VNM",
                    }
                ]
            }
        )
        end = time.perf_counter()

        print(f"\n⏱ Inference Time: {end - start:.2f} sec")

        # --- Validate format ---
        validate_model_output(predictions_dict)

        # --- Save JSON ---
        with open("last_output.json", "w", encoding="utf-8") as f:
            json.dump(predictions_dict, f, indent=4)

        print(" Saved JSON to last_output.json\n")

        # --- Draw Visualization ---
        annotated_image = draw_predictions(filepath, predictions_dict)
    finally:
        # Always clean up the temporary image, even when inference fails.
        if os.path.exists(filepath):
            os.remove(filepath)

    pretty_json = json.dumps(predictions_dict, indent=4)

    return annotated_image, pretty_json


# ------------------------------------------------------
# GRADIO UI
# ------------------------------------------------------
iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="Detection + Classification Output"),
        gr.JSON(label="Raw Model Output"),
    ],
    title=" SpeciesNet Wildlife Detector + Classifier",
    description="Upload a wildlife camera image.",
)

iface.launch()