import os import io import json # Add json import import gradio as gr import numpy as np from PIL import Image, ImageDraw, ImageFont # Add imports for drawing import requests from fastapi import FastAPI, Form, UploadFile, HTTPException, File # Import Form, UploadFile, HTTPException, File from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing import Optional, List, Dict, Any import traceback # For detailed error logging import logging # Add logging import from fastapi import Request # Import Request import tempfile # Use tempfile for safer temporary file handling import math # Added for distance calculation # Import from utility files from detection_utils import PREDEFINED_CLASSES, run_yoloworld_detection, expand_synonyms # Import the new speech processing function from speech_utils import process_audio # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) profile_models = {} profile_class_maps = {} # Store the name mapping for each profile model dynamic_model = None # Create FastAPI app app = FastAPI() # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # --- Proximity Filtering Configuration --- PROXIMITY_THRESHOLD = 50.0 # Pixels. Adjust as needed. PREFERRED_LABEL_FOR_FILTERING = "auto rickshaw" # --- Helper for Proximity Filtering --- def _calculate_distance(center1: List[float], center2: List[float]) -> float: if len(center1) != 2 or len(center2) != 2: # Should not happen with valid detection data, but good to be safe return float('inf') return math.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2) def filter_close_detections( detections: List[Dict[str, Any]], distance_threshold: float, preferred_label: str = "auto rickshaw" ) -> List[Dict[str, Any]]: if not detections: return [] num_detections = len(detections) processed_mask = [False] * num_detections final_filtered_list = [] for i in range(num_detections): if processed_mask[i]: continue # Start a new cluster current_cluster_indices = [] # Use a list as a queue for BFS queue = [] # Seed the queue with the current unprocessed detection queue.append(i) processed_mask[i] = True head = 0 while head < len(queue): current_idx = queue[head] head += 1 current_cluster_indices.append(current_idx) # Add to current cluster # Check against all other detections for j in range(num_detections): if not processed_mask[j]: # Ensure 'centre' key exists and is valid before calculating distance center1 = detections[current_idx].get('centre') center2 = detections[j].get('centre') if not (isinstance(center1, list) and len(center1) == 2 and isinstance(center2, list) and len(center2) == 2): # Log warning or skip if center data is missing/malformed # logger.warning(f"Skipping proximity check due to missing/malformed center data for detection indices {current_idx}, {j}") continue dist = _calculate_distance(center1, center2) if dist < distance_threshold: processed_mask[j] = True queue.append(j) # Add to queue to explore its neighbors # We have a cluster (all indices in current_cluster_indices). # Now select the best one from it based on preference and confidence. if not current_cluster_indices: # Should not happen if loop started continue cluster_detections = [detections[k] for k in current_cluster_indices] preferred_detections_in_cluster = [ d for d in cluster_detections if d.get('label_en', '').lower() == preferred_label.lower() ] chosen_detection = None if preferred_detections_in_cluster: # If preferred label is present, pick the one with highest confidence among them chosen_detection = max(preferred_detections_in_cluster, key=lambda d: d.get('confidence', 0.0)) else: # Otherwise, pick the one with highest confidence from the whole cluster if cluster_detections: chosen_detection = max(cluster_detections, key=lambda d: d.get('confidence', 0.0)) if chosen_detection: final_filtered_list.append(chosen_detection) return final_filtered_list # --- Pydantic Models --- class DetectionResponse(BaseModel): objects: List[Dict[str, Any]] count: int profile_used: str classes_used: List[str] # Return the actual list used for detection status: str = "success" message: Optional[str] = None # --- API Endpoints --- # Add the new consolidated speech endpoint @app.post("/api/speech") async def handle_speech( audio: UploadFile = File(...), lang1: str = Form(...), lang2: str = Form(...) ): """ Receives an audio file, transcribes it using Whisper (detecting between lang1 and lang2), and translates the result to the other language using googletrans. """ try: logger.info(f"Received speech processing request. Lang1: {lang1}, Lang2: {lang2}, File: {audio.filename}") # Read audio file content audio_bytes = await audio.read() if not audio_bytes: raise HTTPException(status_code=400, detail="Received empty audio file.") # Process using the utility function result = await process_audio(audio_bytes, lang1, lang2) if result is None: raise HTTPException(status_code=500, detail="Failed to process audio.") if "error" in result: # Handle specific errors returned by process_audio raise HTTPException(status_code=400, detail=result["error"]) logger.info(f"Speech processing successful. Detected: {result.get('detected_language')}") return result except HTTPException as http_exc: # Re-raise HTTPExceptions directly raise http_exc except Exception as e: logger.error(f"Error in /api/speech endpoint: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") finally: # Ensure the uploaded file stream is closed if audio: await audio.close() # Keep /api/detect_objects endpoint @app.post("/api/detect_objects", response_model=DetectionResponse) async def detect_objects_yolo_world( request: Request, # Keep for logging if needed image: UploadFile = File(...), # Revert back to File(...) profile: str = Form("casual"), extra_words: Optional[str] = Form(None), confidence: float = Form(0.2, ge=0.0, le=1.0), # Lowered default confidence iou: float = Form(0.50, ge=0.0, le=1.0) # Lowered default IoU (NMS threshold) ): profile_lower = profile.lower() if profile_lower not in PREDEFINED_CLASSES: logger.warning(f"Invalid profile '{profile}' received, defaulting to 'casual'.") profile_lower = "casual" # Default to casual # --- Image Loading --- try: logger.info("Reading image bytes...") image_bytes = await image.read() if not image_bytes: raise ValueError("Received empty image file.") img = Image.open(io.BytesIO(image_bytes)).convert("RGB") logger.info(f"Image loaded: {img.width}x{img.height}") except Exception as e: logger.error(f"Image reading/loading error: {e}", exc_info=True) raise HTTPException(status_code=400, detail=f"Invalid or unreadable image file: {e}") finally: if image: await image.close() detections_raw = [] final_class_list_for_response = [] try: # Build class list from profile and extra_words initial_classes = set(PREDEFINED_CLASSES[profile_lower]) if extra_words and extra_words.strip(): try: words = json.loads(extra_words) if extra_words.strip().startswith('[') else extra_words.split(',') initial_classes |= set(str(w).lower().strip() for w in words if w and str(w).strip()) except Exception as e: logger.warning(f"Could not parse extra_words '{extra_words}', ignoring. Error: {e}") # Expand with synonyms expanded_classes = set(expand_synonyms(list(initial_classes))) final_class_list_for_response = sorted(list(expanded_classes)) # Run Roboflow detection detections_raw = run_yoloworld_detection( img, expanded_classes, confidence_threshold=confidence, iou_threshold=iou, profile=profile_lower ) # Filter detections by proximity filtered_detections = filter_close_detections( detections_raw, PROXIMITY_THRESHOLD, preferred_label=PREFERRED_LABEL_FOR_FILTERING ) logger.info(f"Detection complete. Found {len(detections_raw)} raw objects, {len(filtered_detections)} after proximity filtering.") return DetectionResponse( objects=filtered_detections, count=len(filtered_detections), profile_used=profile_lower, classes_used=final_class_list_for_response, status="success" ) except Exception as e: logger.error(f"Error during detection: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Internal server error during detection: {e}") # --- Gradio UI Functions --- # # Keep detect_objects_ui function # def detect_objects_ui(image_pil: Image.Image, profile: str, confidence: float, iou: float): # Add iou parameter # """Gradio function for YOLO-World object detection.""" # # Create a placeholder image for errors or no input # placeholder_img = Image.new('RGB', (640, 480), color = (150, 150, 150)) # draw_placeholder = ImageDraw.Draw(placeholder_img) # if image_pil is None: # draw_placeholder.text((10, 10), "Please upload an image.", fill=(255,255,255)) # return placeholder_img, "Please upload an image." # # Check if the correct model structure is available # profile_lower = profile.lower() # if profile_lower not in profile_models: # error_msg = f"Error: Model for profile '{profile_lower}' not loaded." # logger.error(f"UI requested profile '{profile_lower}' but model not loaded.") # # Return original image with error drawn on it # try: # error_img_out = image_pil.copy() # draw_error = ImageDraw.Draw(error_img_out) # draw_error.text((10, 10), error_msg, fill="red", font=ImageFont.load_default()) # return error_img_out, error_msg # except Exception: # Fallback if drawing on input fails # draw_placeholder.text((10, 10), error_msg, fill="red") # return placeholder_img, error_msg # model_to_use = profile_models[profile_lower] # name_map_to_use = profile_class_maps[profile_lower] # try: # # Ensure image is PIL Image and in RGB # if not isinstance(image_pil, Image.Image): # if isinstance(image_pil, np.ndarray): # image_pil = Image.fromarray(image_pil).convert("RGB") # else: # error_msg = "Error: Invalid image input type." # draw_placeholder.text((10, 10), error_msg, fill="red") # return placeholder_img, error_msg # else: # image_pil = image_pil.convert("RGB") # # Run detection using the pre-configured model # logger.info(f"Running YOLO-World detection (UI) with profile: {profile_lower}, confidence: {confidence}, iou: {iou}") # results = model_to_use.predict(image_pil, conf=confidence, iou=iou, verbose=False) # # Process results using the helper and the stored map # original_w, original_h = image_pil.width, image_pil.height # if results and results[0] and results[0].orig_shape: # original_h, original_w = results[0].orig_shape[:2] # detections = process_prediction_results( # results, original_w, original_h, name_map_to_use # ) # # Draw boxes on a copy of the image for Gradio output # output_image = image_pil.copy() # draw = ImageDraw.Draw(output_image) # try: # font = ImageFont.truetype("arial.ttf", 15) # except IOError: # font = ImageFont.load_default() # labels = [] # if not detections: # labels.append("No objects detected.") # else: # for det in detections: # box = det['box'] # label = f"{det['class_name']}: {det['confidence']:.2f}" # labels.append(label) # color = "red" # draw.rectangle( # [(box['x1'], box['y1']), (box['x2'], box['y2'])], # outline=color, width=3 # ) # text_position = (box['x1'], box['y1'] - 15 if box['y1'] > 15 else box['y1']) # # Use textbbox for better background calculation # try: # text_bbox = draw.textbbox(text_position, label, font=font) # # Adjust background size slightly # bg_coords = (text_bbox[0]-1, text_bbox[1]-1, text_bbox[2]+1, text_bbox[3]+1) # draw.rectangle(bg_coords, fill=color) # draw.text(text_position, label, fill="white", font=font) # except AttributeError: # Fallback for older Pillow versions without textbbox # draw.text(text_position, label, fill=color, font=font) # logger.info(f"UI Detection Results: {labels}") # return output_image, "\n".join(labels) # except Exception as e: # error_msg = f"Error: {str(e)}" # logger.error(f"Error in detect_objects_ui: {e}", exc_info=True) # # Return original image with error message drawn # try: # error_img_out = image_pil.copy() # draw_error = ImageDraw.Draw(error_img_out) # draw_error.text((10, 10), error_msg, fill="red", font=ImageFont.load_default()) # return error_img_out, error_msg # except Exception: # Fallback if drawing on input fails # draw_placeholder.text((10, 10), error_msg, fill="red") # return placeholder_img, error_msg # # --- Create Gradio Interface --- # # Add theme and descriptions # theme = gr.themes.Soft() # Example theme # with gr.Blocks(title="IPD-Lingual API", theme=theme) as demo: # gr.Markdown("# IPD-Lingual: Speech & Vision API") # gr.Markdown("An API providing speech transcription/translation and object detection capabilities.") # with gr.Tab("Home / About"): # gr.Markdown("## Welcome!") # gr.Markdown( # """ # This application provides two main functionalities accessible via API endpoints and a demonstration UI: # 1. **Speech Processing (`/api/speech`):** # * Accepts an audio file and two language codes (e.g., 'en', 'es'). # * Uses **OpenAI Whisper (base model)** to transcribe the audio, automatically detecting which of the two provided languages is spoken. # * Uses the **googletrans library** (unofficial Google Translate API) to translate the transcribed text into the *other* provided language. # * Returns the detected language, original transcription, and translation. # 2. **Object Detection (`/api/detect_objects`):** # * Accepts an image file, a detection profile (e.g., 'casual', 'vehicles'), optional extra object names, confidence threshold, and IoU threshold. # * Uses **YOLO-World (yolov8l-worldv2.pt)**, a powerful zero-shot object detection model from Ultralytics. # * It can detect objects based on predefined profiles or dynamically based on user-provided text prompts (extra words). # * Returns a list of detected objects with their bounding boxes, class names, and confidence scores. # Use the tabs above to try out the object detection functionality or see the API endpoint details below. # *(Note: The speech processing functionality is currently only available via the API endpoint).* # """ # ) # gr.Markdown("---") # gr.Markdown("### API Endpoint Summary") # gr.Markdown("- **POST `/api/speech`**: Transcribe and Translate audio.\n - **Type**: `multipart/form-data`\n - **Fields**: `audio` (file), `lang1` (string), `lang2` (string)") # gr.Markdown("- **POST `/api/detect_objects`**: Detect objects using YOLO-World.\n - **Type**: `multipart/form-data`\n - **Fields**: `image` (file), `profile` (string), `extra_words` (string, optional, comma-separated or JSON list), `confidence` (float, optional), `iou` (float, optional)") # # Keep the "Object Detection" Tab # with gr.Tab("Object Detection Demo"): # gr.Markdown("## Detect Objects in Image (using YOLO-World)") # gr.Markdown("Upload an image and select a detection profile. The model will identify objects belonging to that profile.") # with gr.Row(): # with gr.Column(scale=1): # Input column slightly smaller # image_input = gr.Image(type="pil", label="Upload Image") # profile_select = gr.Dropdown( # choices=sorted(list(PREDEFINED_CLASSES.keys())), # value="casual", # label="Detection Profile" # ) # confidence_slider = gr.Slider( # minimum=0.001, maximum=1.0, value=0.01, step=0.001, # label="Confidence Threshold" # ) # iou_slider = gr.Slider( # minimum=0.01, maximum=1.0, value=0.2, step=0.01, # label="IoU Threshold (NMS)" # ) # detect_btn = gr.Button("Detect Objects", variant="primary") # Make button primary # with gr.Column(scale=2): # Output column larger # image_output = gr.Image(label="Detection Result", interactive=False) # Output not interactive # labels_output = gr.Textbox(label="Detected Objects", lines=10, interactive=False) # # Ensure the click event is correctly wired # detect_btn.click( # fn=detect_objects_ui, # inputs=[image_input, profile_select, confidence_slider, iou_slider], # outputs=[image_output, labels_output] # ) # # Mount both FastAPI and Gradio # # Ensure the Gradio app uses the FastAPI instance `app` # app = gr.mount_gradio_app(app, demo, path="/") # # ... (rest of the file remains the same) ... # if __name__ == "__main__": # import uvicorn # # Check if YOLO models initialized before starting server # # Update check to use the new model variables # if not profile_models or dynamic_model is None: # logger.error(f"CRITICAL: One or more YOLO-World models ({MODEL_NAME}) failed to initialize. API endpoint /api/detect_objects might not work correctly.") # # Decide if you want to exit or run with degraded functionality # # exit(1) # Optional: exit if model loading fails # else: # logger.info("All required YOLO models initialized successfully.") # print("Starting Uvicorn server on http://0.0.0.0:7860") # uvicorn.run(app, host="0.0.0.0", port=7860) if __name__ == "__main__": import uvicorn # Ensure YOLO-World models from detection_utils are loaded (conceptual check) # A more robust check would involve calling a function in detection_utils # or checking the YOLOWORLD_MODELS dictionary directly if it were accessible here. # For now, we rely on detection_utils to log errors if models fail to load. # if not detection_utils.YOLOWORLD_MODELS: # This line would cause an error if detection_utils is not imported # logger.error("CRITICAL: YOLO-World models from detection_utils may not have initialized.") # else: # logger.info("YOLO-World models in detection_utils assumed to be loading/loaded.") # A simple check based on previous logic (profile_models and dynamic_model were for a different setup) # We can infer model readiness by checking if the PREDEFINED_CLASSES (used by detection_utils) has keys. if not PREDEFINED_CLASSES: # A basic check, actual model loading is in detection_utils logger.error(f"CRITICAL: PREDEFINED_CLASSES is empty. YOLO-World models might not be configured in detection_utils.") else: logger.info("Detection profiles are configured. YOLO-World model loading is handled in detection_utils.") print("Starting Uvicorn server on http://0.0.0.0:7860") uvicorn.run(app, host="0.0.0.0", port=7860)