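"""FastAPI application for the Lingual project: exposes speech transcription/translation
(/api/speech) and YOLO-World object detection (/api/detect_objects) endpoints, plus a
currently disabled Gradio demo UI."""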
import os
import io
import json
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont  # Drawing utilities (used by the disabled Gradio UI below)
import requests
from fastapi import FastAPI, Form, UploadFile, HTTPException, File, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
import traceback  # Detailed error logging
import logging
import tempfile  # Safer temporary file handling
import math  # Distance calculation for proximity filtering
# Import from utility files
from detection_utils import PREDEFINED_CLASSES, run_yoloworld_detection, expand_synonyms
# Import the new speech processing function
from speech_utils import process_audio
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
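# Legacy globals from an earlier in-process model setup; they are referenced only by the
# commented-out Gradio UI below. Model loading for the API now lives in detection_utils.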
profile_models = {}
profile_class_maps = {} # Store the name mapping for each profile model
dynamic_model = None
# Create FastAPI app
app = FastAPI()
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# --- Proximity Filtering Configuration ---
PROXIMITY_THRESHOLD = 50.0 # Pixels. Adjust as needed.
PREFERRED_LABEL_FOR_FILTERING = "auto rickshaw"
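# Label favoured when choosing which detection to keep from a proximity cluster (see filter_close_detections).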
# --- Helper for Proximity Filtering ---
def _calculate_distance(center1: List[float], center2: List[float]) -> float:
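    """Return the Euclidean distance between two [x, y] centre points (inf if either is malformed)."""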
if len(center1) != 2 or len(center2) != 2:
# Should not happen with valid detection data, but good to be safe
return float('inf')
return math.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)
def filter_close_detections(
detections: List[Dict[str, Any]],
distance_threshold: float,
preferred_label: str = "auto rickshaw"
) -> List[Dict[str, Any]]:
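    """
    Cluster detections whose centres lie within distance_threshold of each other (single-linkage,
    explored breadth-first), then keep one detection per cluster: the highest-confidence detection
    carrying the preferred label if any, otherwise the highest-confidence detection overall.
    """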
if not detections:
return []
num_detections = len(detections)
processed_mask = [False] * num_detections
final_filtered_list = []
for i in range(num_detections):
if processed_mask[i]:
continue
# Start a new cluster
current_cluster_indices = []
# Use a list as a queue for BFS
queue = []
# Seed the queue with the current unprocessed detection
queue.append(i)
processed_mask[i] = True
head = 0
while head < len(queue):
current_idx = queue[head]
head += 1
current_cluster_indices.append(current_idx) # Add to current cluster
# Check against all other detections
for j in range(num_detections):
if not processed_mask[j]:
# Ensure 'centre' key exists and is valid before calculating distance
center1 = detections[current_idx].get('centre')
center2 = detections[j].get('centre')
if not (isinstance(center1, list) and len(center1) == 2 and
isinstance(center2, list) and len(center2) == 2):
# Log warning or skip if center data is missing/malformed
# logger.warning(f"Skipping proximity check due to missing/malformed center data for detection indices {current_idx}, {j}")
continue
dist = _calculate_distance(center1, center2)
if dist < distance_threshold:
processed_mask[j] = True
queue.append(j) # Add to queue to explore its neighbors
# We have a cluster (all indices in current_cluster_indices).
# Now select the best one from it based on preference and confidence.
if not current_cluster_indices: # Should not happen if loop started
continue
cluster_detections = [detections[k] for k in current_cluster_indices]
preferred_detections_in_cluster = [
d for d in cluster_detections if d.get('label_en', '').lower() == preferred_label.lower()
]
chosen_detection = None
if preferred_detections_in_cluster:
# If preferred label is present, pick the one with highest confidence among them
chosen_detection = max(preferred_detections_in_cluster, key=lambda d: d.get('confidence', 0.0))
else:
# Otherwise, pick the one with highest confidence from the whole cluster
if cluster_detections:
chosen_detection = max(cluster_detections, key=lambda d: d.get('confidence', 0.0))
if chosen_detection:
final_filtered_list.append(chosen_detection)
return final_filtered_list
# --- Pydantic Models ---
class DetectionResponse(BaseModel):
objects: List[Dict[str, Any]]
count: int
profile_used: str
classes_used: List[str] # Return the actual list used for detection
status: str = "success"
message: Optional[str] = None
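# Illustrative DetectionResponse payload (field values are hypothetical):
# {
#   "objects": [{"label_en": "auto rickshaw", "confidence": 0.71, "centre": [412.0, 385.5], ...}],
#   "count": 1,
#   "profile_used": "casual",
#   "classes_used": ["auto rickshaw", "bus", ...],
#   "status": "success",
#   "message": null
# }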
# --- API Endpoints ---
# Add the new consolidated speech endpoint
@app.post("/api/speech")
async def handle_speech(
audio: UploadFile = File(...),
lang1: str = Form(...),
lang2: str = Form(...)
):
"""
Receives an audio file, transcribes it using Whisper (detecting between lang1 and lang2),
and translates the result to the other language using googletrans.
"""
try:
logger.info(f"Received speech processing request. Lang1: {lang1}, Lang2: {lang2}, File: {audio.filename}")
# Read audio file content
audio_bytes = await audio.read()
if not audio_bytes:
raise HTTPException(status_code=400, detail="Received empty audio file.")
# Process using the utility function
result = await process_audio(audio_bytes, lang1, lang2)
if result is None:
raise HTTPException(status_code=500, detail="Failed to process audio.")
if "error" in result: # Handle specific errors returned by process_audio
raise HTTPException(status_code=400, detail=result["error"])
logger.info(f"Speech processing successful. Detected: {result.get('detected_language')}")
return result
except HTTPException as http_exc:
# Re-raise HTTPExceptions directly
raise http_exc
except Exception as e:
logger.error(f"Error in /api/speech endpoint: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
finally:
# Ensure the uploaded file stream is closed
if audio:
await audio.close()
# Keep /api/detect_objects endpoint
@app.post("/api/detect_objects", response_model=DetectionResponse)
async def detect_objects_yolo_world(
    request: Request,  # Kept for request-level logging if needed
    image: UploadFile = File(...),
    profile: str = Form("casual"),
    extra_words: Optional[str] = Form(None),
    confidence: float = Form(0.2, ge=0.0, le=1.0),  # Detection confidence threshold
    iou: float = Form(0.50, ge=0.0, le=1.0)  # IoU (NMS) threshold
):
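    """
    Detect objects in an uploaded image with YOLO-World. The class list is built from the selected
    profile plus optional extra_words, expanded with synonyms, and spatially close detections are
    de-duplicated by proximity filtering before the response is returned.
    """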
profile_lower = profile.lower()
if profile_lower not in PREDEFINED_CLASSES:
logger.warning(f"Invalid profile '{profile}' received, defaulting to 'casual'.")
profile_lower = "casual" # Default to casual
# --- Image Loading ---
try:
logger.info("Reading image bytes...")
image_bytes = await image.read()
if not image_bytes: raise ValueError("Received empty image file.")
img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
logger.info(f"Image loaded: {img.width}x{img.height}")
except Exception as e:
logger.error(f"Image reading/loading error: {e}", exc_info=True)
raise HTTPException(status_code=400, detail=f"Invalid or unreadable image file: {e}")
finally:
if image: await image.close()
detections_raw = []
final_class_list_for_response = []
try:
# Build class list from profile and extra_words
initial_classes = set(PREDEFINED_CLASSES[profile_lower])
if extra_words and extra_words.strip():
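            # extra_words may arrive as a JSON array string (e.g. '["cow", "bus"]') or as comma-separated text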
try:
words = json.loads(extra_words) if extra_words.strip().startswith('[') else extra_words.split(',')
initial_classes |= set(str(w).lower().strip() for w in words if w and str(w).strip())
except Exception as e:
logger.warning(f"Could not parse extra_words '{extra_words}', ignoring. Error: {e}")
# Expand with synonyms
expanded_classes = set(expand_synonyms(list(initial_classes)))
final_class_list_for_response = sorted(list(expanded_classes))
# Run Roboflow detection
detections_raw = run_yoloworld_detection(
img,
expanded_classes,
confidence_threshold=confidence,
iou_threshold=iou,
profile=profile_lower
)
# Filter detections by proximity
filtered_detections = filter_close_detections(
detections_raw,
PROXIMITY_THRESHOLD,
preferred_label=PREFERRED_LABEL_FOR_FILTERING
)
logger.info(f"Detection complete. Found {len(detections_raw)} raw objects, {len(filtered_detections)} after proximity filtering.")
return DetectionResponse(
objects=filtered_detections,
count=len(filtered_detections),
profile_used=profile_lower,
classes_used=final_class_list_for_response,
status="success"
)
except Exception as e:
logger.error(f"Error during detection: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Internal server error during detection: {e}")
# --- Gradio UI Functions ---
# # Keep detect_objects_ui function
# def detect_objects_ui(image_pil: Image.Image, profile: str, confidence: float, iou: float): # Add iou parameter
# """Gradio function for YOLO-World object detection."""
# # Create a placeholder image for errors or no input
# placeholder_img = Image.new('RGB', (640, 480), color = (150, 150, 150))
# draw_placeholder = ImageDraw.Draw(placeholder_img)
# if image_pil is None:
# draw_placeholder.text((10, 10), "Please upload an image.", fill=(255,255,255))
# return placeholder_img, "Please upload an image."
# # Check if the correct model structure is available
# profile_lower = profile.lower()
# if profile_lower not in profile_models:
# error_msg = f"Error: Model for profile '{profile_lower}' not loaded."
# logger.error(f"UI requested profile '{profile_lower}' but model not loaded.")
# # Return original image with error drawn on it
# try:
# error_img_out = image_pil.copy()
# draw_error = ImageDraw.Draw(error_img_out)
# draw_error.text((10, 10), error_msg, fill="red", font=ImageFont.load_default())
# return error_img_out, error_msg
# except Exception: # Fallback if drawing on input fails
# draw_placeholder.text((10, 10), error_msg, fill="red")
# return placeholder_img, error_msg
# model_to_use = profile_models[profile_lower]
# name_map_to_use = profile_class_maps[profile_lower]
# try:
# # Ensure image is PIL Image and in RGB
# if not isinstance(image_pil, Image.Image):
# if isinstance(image_pil, np.ndarray):
# image_pil = Image.fromarray(image_pil).convert("RGB")
# else:
# error_msg = "Error: Invalid image input type."
# draw_placeholder.text((10, 10), error_msg, fill="red")
# return placeholder_img, error_msg
# else:
# image_pil = image_pil.convert("RGB")
# # Run detection using the pre-configured model
# logger.info(f"Running YOLO-World detection (UI) with profile: {profile_lower}, confidence: {confidence}, iou: {iou}")
# results = model_to_use.predict(image_pil, conf=confidence, iou=iou, verbose=False)
# # Process results using the helper and the stored map
# original_w, original_h = image_pil.width, image_pil.height
# if results and results[0] and results[0].orig_shape:
# original_h, original_w = results[0].orig_shape[:2]
# detections = process_prediction_results(
# results, original_w, original_h, name_map_to_use
# )
# # Draw boxes on a copy of the image for Gradio output
# output_image = image_pil.copy()
# draw = ImageDraw.Draw(output_image)
# try:
# font = ImageFont.truetype("arial.ttf", 15)
# except IOError:
# font = ImageFont.load_default()
# labels = []
# if not detections:
# labels.append("No objects detected.")
# else:
# for det in detections:
# box = det['box']
# label = f"{det['class_name']}: {det['confidence']:.2f}"
# labels.append(label)
# color = "red"
# draw.rectangle(
# [(box['x1'], box['y1']), (box['x2'], box['y2'])],
# outline=color, width=3
# )
# text_position = (box['x1'], box['y1'] - 15 if box['y1'] > 15 else box['y1'])
# # Use textbbox for better background calculation
# try:
# text_bbox = draw.textbbox(text_position, label, font=font)
# # Adjust background size slightly
# bg_coords = (text_bbox[0]-1, text_bbox[1]-1, text_bbox[2]+1, text_bbox[3]+1)
# draw.rectangle(bg_coords, fill=color)
# draw.text(text_position, label, fill="white", font=font)
# except AttributeError: # Fallback for older Pillow versions without textbbox
# draw.text(text_position, label, fill=color, font=font)
# logger.info(f"UI Detection Results: {labels}")
# return output_image, "\n".join(labels)
# except Exception as e:
# error_msg = f"Error: {str(e)}"
# logger.error(f"Error in detect_objects_ui: {e}", exc_info=True)
# # Return original image with error message drawn
# try:
# error_img_out = image_pil.copy()
# draw_error = ImageDraw.Draw(error_img_out)
# draw_error.text((10, 10), error_msg, fill="red", font=ImageFont.load_default())
# return error_img_out, error_msg
# except Exception: # Fallback if drawing on input fails
# draw_placeholder.text((10, 10), error_msg, fill="red")
# return placeholder_img, error_msg
# # --- Create Gradio Interface ---
# # Add theme and descriptions
# theme = gr.themes.Soft() # Example theme
# with gr.Blocks(title="IPD-Lingual API", theme=theme) as demo:
# gr.Markdown("# IPD-Lingual: Speech & Vision API")
# gr.Markdown("An API providing speech transcription/translation and object detection capabilities.")
# with gr.Tab("Home / About"):
# gr.Markdown("## Welcome!")
# gr.Markdown(
# """
# This application provides two main functionalities accessible via API endpoints and a demonstration UI:
# 1. **Speech Processing (`/api/speech`):**
# * Accepts an audio file and two language codes (e.g., 'en', 'es').
# * Uses **OpenAI Whisper (base model)** to transcribe the audio, automatically detecting which of the two provided languages is spoken.
# * Uses the **googletrans library** (unofficial Google Translate API) to translate the transcribed text into the *other* provided language.
# * Returns the detected language, original transcription, and translation.
# 2. **Object Detection (`/api/detect_objects`):**
# * Accepts an image file, a detection profile (e.g., 'casual', 'vehicles'), optional extra object names, confidence threshold, and IoU threshold.
# * Uses **YOLO-World (yolov8l-worldv2.pt)**, a powerful zero-shot object detection model from Ultralytics.
# * It can detect objects based on predefined profiles or dynamically based on user-provided text prompts (extra words).
# * Returns a list of detected objects with their bounding boxes, class names, and confidence scores.
# Use the tabs above to try out the object detection functionality or see the API endpoint details below.
# *(Note: The speech processing functionality is currently only available via the API endpoint).*
# """
# )
# gr.Markdown("---")
# gr.Markdown("### API Endpoint Summary")
# gr.Markdown("- **POST `/api/speech`**: Transcribe and Translate audio.\n - **Type**: `multipart/form-data`\n - **Fields**: `audio` (file), `lang1` (string), `lang2` (string)")
# gr.Markdown("- **POST `/api/detect_objects`**: Detect objects using YOLO-World.\n - **Type**: `multipart/form-data`\n - **Fields**: `image` (file), `profile` (string), `extra_words` (string, optional, comma-separated or JSON list), `confidence` (float, optional), `iou` (float, optional)")
# # Keep the "Object Detection" Tab
# with gr.Tab("Object Detection Demo"):
# gr.Markdown("## Detect Objects in Image (using YOLO-World)")
# gr.Markdown("Upload an image and select a detection profile. The model will identify objects belonging to that profile.")
# with gr.Row():
# with gr.Column(scale=1): # Input column slightly smaller
# image_input = gr.Image(type="pil", label="Upload Image")
# profile_select = gr.Dropdown(
# choices=sorted(list(PREDEFINED_CLASSES.keys())),
# value="casual",
# label="Detection Profile"
# )
# confidence_slider = gr.Slider(
# minimum=0.001, maximum=1.0, value=0.01, step=0.001,
# label="Confidence Threshold"
# )
# iou_slider = gr.Slider(
# minimum=0.01, maximum=1.0, value=0.2, step=0.01,
# label="IoU Threshold (NMS)"
# )
# detect_btn = gr.Button("Detect Objects", variant="primary") # Make button primary
# with gr.Column(scale=2): # Output column larger
# image_output = gr.Image(label="Detection Result", interactive=False) # Output not interactive
# labels_output = gr.Textbox(label="Detected Objects", lines=10, interactive=False)
# # Ensure the click event is correctly wired
# detect_btn.click(
# fn=detect_objects_ui,
# inputs=[image_input, profile_select, confidence_slider, iou_slider],
# outputs=[image_output, labels_output]
# )
# # Mount both FastAPI and Gradio
# # Ensure the Gradio app uses the FastAPI instance `app`
# app = gr.mount_gradio_app(app, demo, path="/")
# # ... (rest of the file remains the same) ...
# if __name__ == "__main__":
# import uvicorn
# # Check if YOLO models initialized before starting server
# # Update check to use the new model variables
# if not profile_models or dynamic_model is None:
# logger.error(f"CRITICAL: One or more YOLO-World models ({MODEL_NAME}) failed to initialize. API endpoint /api/detect_objects might not work correctly.")
# # Decide if you want to exit or run with degraded functionality
# # exit(1) # Optional: exit if model loading fails
# else:
# logger.info("All required YOLO models initialized successfully.")
# print("Starting Uvicorn server on http://0.0.0.0:7860")
# uvicorn.run(app, host="0.0.0.0", port=7860)
if __name__ == "__main__":
    import uvicorn
    # YOLO-World model loading happens inside detection_utils, which logs its own failures.
    # As a lightweight readiness check, verify that detection profiles are configured.
    if not PREDEFINED_CLASSES:
        logger.error("CRITICAL: PREDEFINED_CLASSES is empty. YOLO-World models might not be configured in detection_utils.")
    else:
        logger.info("Detection profiles are configured. YOLO-World model loading is handled in detection_utils.")
    print("Starting Uvicorn server on http://0.0.0.0:7860")
    uvicorn.run(app, host="0.0.0.0", port=7860)