import os
import io
import json
import logging
import math  # Distance calculation for proximity filtering
import tempfile  # Safer temporary file handling
import traceback  # Detailed error logging

import gradio as gr
import numpy as np
import requests
from PIL import Image, ImageDraw, ImageFont  # Drawing boxes/labels on images
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List, Dict, Any

# Utility modules
from detection_utils import PREDEFINED_CLASSES, run_yoloworld_detection, expand_synonyms
from speech_utils import process_audio  # Speech transcription/translation

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Legacy globals kept for the commented-out Gradio UI further below; the active
# /api/detect_objects endpoint delegates model handling to detection_utils.
profile_models = {}
profile_class_maps = {}  # Name mapping for each profile model
dynamic_model = None

# Create FastAPI app
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Proximity Filtering Configuration ---
PROXIMITY_THRESHOLD = 50.0  # Pixels; adjust as needed.
PREFERRED_LABEL_FOR_FILTERING = "auto rickshaw"
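# Detections whose centres fall within PROXIMITY_THRESHOLD pixels of each other are
# chained into a single cluster; only one detection per cluster is kept, preferring
# PREFERRED_LABEL_FOR_FILTERING and otherwise the highest-confidence member.
# See filter_close_detections() below.
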
# --- Helper for Proximity Filtering ---
def _calculate_distance(center1: List[float], center2: List[float]) -> float:
    """Euclidean distance between two [x, y] centre points."""
    if len(center1) != 2 or len(center2) != 2:
        # Should not happen with valid detection data, but guard anyway.
        return float('inf')
    return math.sqrt((center1[0] - center2[0]) ** 2 + (center1[1] - center2[1]) ** 2)

def filter_close_detections(
    detections: List[Dict[str, Any]],
    distance_threshold: float,
    preferred_label: str = "auto rickshaw"
) -> List[Dict[str, Any]]:
    """
    Collapse detections whose centres lie within `distance_threshold` pixels of
    each other into a single result. Clusters are built with a BFS over the
    "closer than threshold" relation (single-link chaining); from each cluster,
    a `preferred_label` match is kept if present, otherwise the member with the
    highest confidence.
    """
    if not detections:
        return []

    num_detections = len(detections)
    processed_mask = [False] * num_detections
    final_filtered_list = []

    for i in range(num_detections):
        if processed_mask[i]:
            continue

        # Start a new cluster, using a list as a BFS queue seeded with the
        # current unprocessed detection.
        current_cluster_indices = []
        queue = [i]
        processed_mask[i] = True
        head = 0

        while head < len(queue):
            current_idx = queue[head]
            head += 1
            current_cluster_indices.append(current_idx)  # Add to current cluster

            # Check against all other detections
            for j in range(num_detections):
                if not processed_mask[j]:
                    # Ensure 'centre' data exists and is valid before measuring distance.
                    center1 = detections[current_idx].get('centre')
                    center2 = detections[j].get('centre')
                    if not (isinstance(center1, list) and len(center1) == 2 and
                            isinstance(center2, list) and len(center2) == 2):
                        # Skip pairs with missing/malformed centre data.
                        continue
                    dist = _calculate_distance(center1, center2)
                    if dist < distance_threshold:
                        processed_mask[j] = True
                        queue.append(j)  # Add to queue to explore its neighbours

        # We have a cluster (all indices in current_cluster_indices).
        # Select the best detection from it based on preference and confidence.
        if not current_cluster_indices:  # Should not happen once the loop has started
            continue

        cluster_detections = [detections[k] for k in current_cluster_indices]
        preferred_detections_in_cluster = [
            d for d in cluster_detections if d.get('label_en', '').lower() == preferred_label.lower()
        ]

        chosen_detection = None
        if preferred_detections_in_cluster:
            # Preferred label present: pick its highest-confidence instance.
            chosen_detection = max(preferred_detections_in_cluster, key=lambda d: d.get('confidence', 0.0))
        elif cluster_detections:
            # Otherwise pick the highest-confidence detection in the whole cluster.
            chosen_detection = max(cluster_detections, key=lambda d: d.get('confidence', 0.0))

        if chosen_detection:
            final_filtered_list.append(chosen_detection)

    return final_filtered_list
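
# Illustrative behaviour (the detection dict shape is assumed from the keys
# accessed above: 'centre', 'label_en', 'confidence'):
#
#   dets = [
#       {"label_en": "car",           "confidence": 0.92, "centre": [110.0, 125.0]},
#       {"label_en": "auto rickshaw", "confidence": 0.81, "centre": [100.0, 120.0]},
#       {"label_en": "person",        "confidence": 0.70, "centre": [400.0, 300.0]},
#   ]
#   filter_close_detections(dets, PROXIMITY_THRESHOLD)
#   # -> the first two centres are ~11 px apart, so they collapse into one cluster
#   #    and the preferred "auto rickshaw" wins despite its lower confidence;
#   #    the "person" detection is far away and survives as its own cluster.
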
# --- Pydantic Models ---
class DetectionResponse(BaseModel):
    objects: List[Dict[str, Any]]
    count: int
    profile_used: str
    classes_used: List[str]  # The actual class list used for detection
    status: str = "success"
    message: Optional[str] = None
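
# Example serialized response (a sketch; the per-object dict shape comes from
# detection_utils and is assumed here):
# {
#   "objects": [{"label_en": "auto rickshaw", "confidence": 0.81, "centre": [100.0, 120.0]}],
#   "count": 1,
#   "profile_used": "casual",
#   "classes_used": ["auto rickshaw", "bus", "car"],
#   "status": "success",
#   "message": null
# }
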
# --- API Endpoints ---

# Consolidated speech endpoint
@app.post("/api/speech")
async def handle_speech(
    audio: UploadFile = File(...),
    lang1: str = Form(...),
    lang2: str = Form(...)
):
""" | |
Receives an audio file, transcribes it using Whisper (detecting between lang1 and lang2), | |
and translates the result to the other language using googletrans. | |
""" | |
try: | |
logger.info(f"Received speech processing request. Lang1: {lang1}, Lang2: {lang2}, File: {audio.filename}") | |
# Read audio file content | |
audio_bytes = await audio.read() | |
if not audio_bytes: | |
raise HTTPException(status_code=400, detail="Received empty audio file.") | |
# Process using the utility function | |
result = await process_audio(audio_bytes, lang1, lang2) | |
if result is None: | |
raise HTTPException(status_code=500, detail="Failed to process audio.") | |
if "error" in result: # Handle specific errors returned by process_audio | |
raise HTTPException(status_code=400, detail=result["error"]) | |
logger.info(f"Speech processing successful. Detected: {result.get('detected_language')}") | |
return result | |
except HTTPException as http_exc: | |
# Re-raise HTTPExceptions directly | |
raise http_exc | |
except Exception as e: | |
logger.error(f"Error in /api/speech endpoint: {e}", exc_info=True) | |
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") | |
finally: | |
# Ensure the uploaded file stream is closed | |
if audio: | |
await audio.close() | |
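
# Example client call for /api/speech (a sketch; the filename is hypothetical and
# the host/port matches the Uvicorn settings at the bottom of this file):
#
#   import requests
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/api/speech",
#           files={"audio": f},
#           data={"lang1": "en", "lang2": "es"},
#       )
#   print(resp.json())  # detected language, transcription, translation
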
# Object detection endpoint
@app.post("/api/detect_objects", response_model=DetectionResponse)
async def detect_objects_yolo_world(
    request: Request,  # Kept for request-level logging if needed
    image: UploadFile = File(...),
    profile: str = Form("casual"),
    extra_words: Optional[str] = Form(None),
    confidence: float = Form(0.2, ge=0.0, le=1.0),  # Default confidence threshold
    iou: float = Form(0.50, ge=0.0, le=1.0)  # Default IoU (NMS) threshold
):
    """
    Detect objects with YOLO-World using a predefined profile plus optional
    extra class names, then collapse near-duplicate detections by proximity.
    """
    profile_lower = profile.lower()
    if profile_lower not in PREDEFINED_CLASSES:
        logger.warning(f"Invalid profile '{profile}' received, defaulting to 'casual'.")
        profile_lower = "casual"

    # --- Image Loading ---
    try:
        logger.info("Reading image bytes...")
        image_bytes = await image.read()
        if not image_bytes:
            raise ValueError("Received empty image file.")
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        logger.info(f"Image loaded: {img.width}x{img.height}")
    except Exception as e:
        logger.error(f"Image reading/loading error: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=f"Invalid or unreadable image file: {e}")
    finally:
        if image:
            await image.close()

    detections_raw = []
    final_class_list_for_response = []
    try:
        # Build class list from profile and extra_words
        initial_classes = set(PREDEFINED_CLASSES[profile_lower])
        if extra_words and extra_words.strip():
            try:
                words = json.loads(extra_words) if extra_words.strip().startswith('[') else extra_words.split(',')
                initial_classes |= set(str(w).lower().strip() for w in words if w and str(w).strip())
            except Exception as e:
                logger.warning(f"Could not parse extra_words '{extra_words}', ignoring. Error: {e}")

        # Expand with synonyms
        expanded_classes = set(expand_synonyms(list(initial_classes)))
        final_class_list_for_response = sorted(list(expanded_classes))

        # Run YOLO-World detection
        detections_raw = run_yoloworld_detection(
            img,
            expanded_classes,
            confidence_threshold=confidence,
            iou_threshold=iou,
            profile=profile_lower
        )

        # Filter detections by proximity
        filtered_detections = filter_close_detections(
            detections_raw,
            PROXIMITY_THRESHOLD,
            preferred_label=PREFERRED_LABEL_FOR_FILTERING
        )
        logger.info(f"Detection complete. Found {len(detections_raw)} raw objects, {len(filtered_detections)} after proximity filtering.")

        return DetectionResponse(
            objects=filtered_detections,
            count=len(filtered_detections),
            profile_used=profile_lower,
            classes_used=final_class_list_for_response,
            status="success"
        )
    except Exception as e:
        logger.error(f"Error during detection: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error during detection: {e}")
# --- Gradio UI Functions ---
# # Keep detect_objects_ui function
# def detect_objects_ui(image_pil: Image.Image, profile: str, confidence: float, iou: float):  # Add iou parameter
#     """Gradio function for YOLO-World object detection."""
#     # Create a placeholder image for errors or no input
#     placeholder_img = Image.new('RGB', (640, 480), color=(150, 150, 150))
#     draw_placeholder = ImageDraw.Draw(placeholder_img)
#     if image_pil is None:
#         draw_placeholder.text((10, 10), "Please upload an image.", fill=(255, 255, 255))
#         return placeholder_img, "Please upload an image."
#     # Check if the correct model structure is available
#     profile_lower = profile.lower()
#     if profile_lower not in profile_models:
#         error_msg = f"Error: Model for profile '{profile_lower}' not loaded."
#         logger.error(f"UI requested profile '{profile_lower}' but model not loaded.")
#         # Return original image with error drawn on it
#         try:
#             error_img_out = image_pil.copy()
#             draw_error = ImageDraw.Draw(error_img_out)
#             draw_error.text((10, 10), error_msg, fill="red", font=ImageFont.load_default())
#             return error_img_out, error_msg
#         except Exception:  # Fallback if drawing on input fails
#             draw_placeholder.text((10, 10), error_msg, fill="red")
#             return placeholder_img, error_msg
#     model_to_use = profile_models[profile_lower]
#     name_map_to_use = profile_class_maps[profile_lower]
#     try:
#         # Ensure image is PIL Image and in RGB
#         if not isinstance(image_pil, Image.Image):
#             if isinstance(image_pil, np.ndarray):
#                 image_pil = Image.fromarray(image_pil).convert("RGB")
#             else:
#                 error_msg = "Error: Invalid image input type."
#                 draw_placeholder.text((10, 10), error_msg, fill="red")
#                 return placeholder_img, error_msg
#         else:
#             image_pil = image_pil.convert("RGB")
#         # Run detection using the pre-configured model
#         logger.info(f"Running YOLO-World detection (UI) with profile: {profile_lower}, confidence: {confidence}, iou: {iou}")
#         results = model_to_use.predict(image_pil, conf=confidence, iou=iou, verbose=False)
#         # Process results using the helper and the stored map
#         original_w, original_h = image_pil.width, image_pil.height
#         if results and results[0] and results[0].orig_shape:
#             original_h, original_w = results[0].orig_shape[:2]
#         detections = process_prediction_results(
#             results, original_w, original_h, name_map_to_use
#         )
#         # Draw boxes on a copy of the image for Gradio output
#         output_image = image_pil.copy()
#         draw = ImageDraw.Draw(output_image)
#         try:
#             font = ImageFont.truetype("arial.ttf", 15)
#         except IOError:
#             font = ImageFont.load_default()
#         labels = []
#         if not detections:
#             labels.append("No objects detected.")
#         else:
#             for det in detections:
#                 box = det['box']
#                 label = f"{det['class_name']}: {det['confidence']:.2f}"
#                 labels.append(label)
#                 color = "red"
#                 draw.rectangle(
#                     [(box['x1'], box['y1']), (box['x2'], box['y2'])],
#                     outline=color, width=3
#                 )
#                 text_position = (box['x1'], box['y1'] - 15 if box['y1'] > 15 else box['y1'])
#                 # Use textbbox for better background calculation
#                 try:
#                     text_bbox = draw.textbbox(text_position, label, font=font)
#                     # Adjust background size slightly
#                     bg_coords = (text_bbox[0]-1, text_bbox[1]-1, text_bbox[2]+1, text_bbox[3]+1)
#                     draw.rectangle(bg_coords, fill=color)
#                     draw.text(text_position, label, fill="white", font=font)
#                 except AttributeError:  # Fallback for older Pillow versions without textbbox
#                     draw.text(text_position, label, fill=color, font=font)
#         logger.info(f"UI Detection Results: {labels}")
#         return output_image, "\n".join(labels)
#     except Exception as e:
#         error_msg = f"Error: {str(e)}"
#         logger.error(f"Error in detect_objects_ui: {e}", exc_info=True)
#         # Return original image with error message drawn
#         try:
#             error_img_out = image_pil.copy()
#             draw_error = ImageDraw.Draw(error_img_out)
#             draw_error.text((10, 10), error_msg, fill="red", font=ImageFont.load_default())
#             return error_img_out, error_msg
#         except Exception:  # Fallback if drawing on input fails
#             draw_placeholder.text((10, 10), error_msg, fill="red")
#             return placeholder_img, error_msg
# # --- Create Gradio Interface ---
# # Add theme and descriptions
# theme = gr.themes.Soft()  # Example theme
# with gr.Blocks(title="IPD-Lingual API", theme=theme) as demo:
#     gr.Markdown("# IPD-Lingual: Speech & Vision API")
#     gr.Markdown("An API providing speech transcription/translation and object detection capabilities.")
#     with gr.Tab("Home / About"):
#         gr.Markdown("## Welcome!")
#         gr.Markdown(
#             """
#             This application provides two main functionalities accessible via API endpoints and a demonstration UI:
#             1. **Speech Processing (`/api/speech`):**
#                 * Accepts an audio file and two language codes (e.g., 'en', 'es').
#                 * Uses **OpenAI Whisper (base model)** to transcribe the audio, automatically detecting which of the two provided languages is spoken.
#                 * Uses the **googletrans library** (unofficial Google Translate API) to translate the transcribed text into the *other* provided language.
#                 * Returns the detected language, original transcription, and translation.
#             2. **Object Detection (`/api/detect_objects`):**
#                 * Accepts an image file, a detection profile (e.g., 'casual', 'vehicles'), optional extra object names, confidence threshold, and IoU threshold.
#                 * Uses **YOLO-World (yolov8l-worldv2.pt)**, a powerful zero-shot object detection model from Ultralytics.
#                 * It can detect objects based on predefined profiles or dynamically based on user-provided text prompts (extra words).
#                 * Returns a list of detected objects with their bounding boxes, class names, and confidence scores.
#             Use the tabs above to try out the object detection functionality or see the API endpoint details below.
#             *(Note: The speech processing functionality is currently only available via the API endpoint).*
#             """
#         )
#         gr.Markdown("---")
#         gr.Markdown("### API Endpoint Summary")
#         gr.Markdown("- **POST `/api/speech`**: Transcribe and Translate audio.\n - **Type**: `multipart/form-data`\n - **Fields**: `audio` (file), `lang1` (string), `lang2` (string)")
#         gr.Markdown("- **POST `/api/detect_objects`**: Detect objects using YOLO-World.\n - **Type**: `multipart/form-data`\n - **Fields**: `image` (file), `profile` (string), `extra_words` (string, optional, comma-separated or JSON list), `confidence` (float, optional), `iou` (float, optional)")
#     # Keep the "Object Detection" Tab
#     with gr.Tab("Object Detection Demo"):
#         gr.Markdown("## Detect Objects in Image (using YOLO-World)")
#         gr.Markdown("Upload an image and select a detection profile. The model will identify objects belonging to that profile.")
#         with gr.Row():
#             with gr.Column(scale=1):  # Input column slightly smaller
#                 image_input = gr.Image(type="pil", label="Upload Image")
#                 profile_select = gr.Dropdown(
#                     choices=sorted(list(PREDEFINED_CLASSES.keys())),
#                     value="casual",
#                     label="Detection Profile"
#                 )
#                 confidence_slider = gr.Slider(
#                     minimum=0.001, maximum=1.0, value=0.01, step=0.001,
#                     label="Confidence Threshold"
#                 )
#                 iou_slider = gr.Slider(
#                     minimum=0.01, maximum=1.0, value=0.2, step=0.01,
#                     label="IoU Threshold (NMS)"
#                 )
#                 detect_btn = gr.Button("Detect Objects", variant="primary")  # Make button primary
#             with gr.Column(scale=2):  # Output column larger
#                 image_output = gr.Image(label="Detection Result", interactive=False)  # Output not interactive
#                 labels_output = gr.Textbox(label="Detected Objects", lines=10, interactive=False)
#         # Ensure the click event is correctly wired
#         detect_btn.click(
#             fn=detect_objects_ui,
#             inputs=[image_input, profile_select, confidence_slider, iou_slider],
#             outputs=[image_output, labels_output]
#         )
# # Mount both FastAPI and Gradio
# # Ensure the Gradio app uses the FastAPI instance `app`
# app = gr.mount_gradio_app(app, demo, path="/")
# # ... (rest of the file remains the same) ...
# if __name__ == "__main__":
#     import uvicorn
#     # Check if YOLO models initialized before starting server
#     # Update check to use the new model variables
#     if not profile_models or dynamic_model is None:
#         logger.error(f"CRITICAL: One or more YOLO-World models ({MODEL_NAME}) failed to initialize. API endpoint /api/detect_objects might not work correctly.")
#         # Decide if you want to exit or run with degraded functionality
#         # exit(1)  # Optional: exit if model loading fails
#     else:
#         logger.info("All required YOLO models initialized successfully.")
#     print("Starting Uvicorn server on http://0.0.0.0:7860")
#     uvicorn.run(app, host="0.0.0.0", port=7860)

if __name__ == "__main__":
    import uvicorn

    # YOLO-World model loading for /api/detect_objects happens inside detection_utils,
    # which logs its own errors. As a lightweight sanity check before starting the
    # server, verify that detection profiles are configured at all.
    if not PREDEFINED_CLASSES:
        logger.error("CRITICAL: PREDEFINED_CLASSES is empty. YOLO-World models might not be configured in detection_utils.")
    else:
        logger.info("Detection profiles are configured. YOLO-World model loading is handled in detection_utils.")

    print("Starting Uvicorn server on http://0.0.0.0:7860")
    uvicorn.run(app, host="0.0.0.0", port=7860)