|
import numpy as np |
|
from PIL import Image |
|
from typing import List, Dict, Any, Set, Tuple |
|
import os |
|
import tempfile |
|
from googletrans import Translator |
|
import cv2 |
|
from inference.models.yolo_world.yolo_world import YOLOWorld |
|
import onnxruntime as ort |
|
import requests |
|
import random |
|
|
|
|
|
|
|
# Keep a handle on the real constructor so the wrapper below can delegate to it.
original_inference_session = ort.InferenceSession


def patched_inference_session(*args, **kwargs):
    """Create an ONNX Runtime session restricted to the CPU provider.

    Any ``providers`` keyword supplied by the caller is overridden with
    ``["CPUExecutionProvider"]``; every other positional and keyword
    argument is forwarded unchanged to the original ``ort.InferenceSession``.
    """
    cpu_only_kwargs = {**kwargs, "providers": ["CPUExecutionProvider"]}
    return original_inference_session(*args, **cpu_only_kwargs)


# From here on, looking up ``ort.InferenceSession`` yields the wrapper.
ort.InferenceSession = patched_inference_session

import warnings

# Ignore UserWarning messages attributed to the onnxruntime module.
warnings.filterwarnings("ignore", category=UserWarning, module="onnxruntime")
|
|
|
# Class-name prompts available for each user profile. Every list is handed
# to a YOLO-World model via ``set_classes``, so each name should appear only
# once per profile.
PREDEFINED_CLASSES = {
    "tourist": [
        "person", "car", "bus", "train", "truck", "boat", "traffic light",
        "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
        "dog", "backpack", "umbrella", "handbag", "tie", "suitcase", "building",
        "signboard", "taxi", "rickshaw", "camera", "map", "monument", "souvenir",
        "statue", "fountain", "street sign", "tour guide", "hotel", "restaurant",

        "temple", "mosque", "church", "fort", "palace", "museum", "market", "bazaar",
        "auto rickshaw", "cycle rickshaw", "metro", "heritage site",
        "ticket counter", "luggage", "water bottle", "scarf", "hat", "bus stop",
        "information center", "shopping bag", "vendor", "street food", "food stall",
        "hawker", "street performer", "camel", "elephant ride", "tour bus", "minaret",
        "gopuram", "chhatri", "ghat", "river", "lake", "bridge", "park", "garden",
    ],
    "casual": [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
        "truck", "boat", "traffic light", "fire hydrant", "stop sign",
        "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
        "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
        "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
        "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork",
        "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
        "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
        "couch", "potted plant", "bed", "dining table", "toilet", "tv",
        "laptop", "mouse", "tv remote", "remote control", "keyboard", "cell phone",
        "microwave",
        "oven", "toaster", "book", "clock",
        "scissors", "teddy bear", "toothbrush", "tree", "flower", "park",
        "computer", "desk", "window", "door",

        "auto rickshaw", "cycle rickshaw", "scooter", "tempo", "tractor", "e-rickshaw",
        "delivery van", "ambulance", "police car", "roadside stall", "food cart",
        "street vendor", "helmet", "road sign", "speed breaker",
        "divider", "pothole", "bus stop", "petrol pump",
        "water dispenser", "printer", "file cabinet", "whiteboard", "projector",
        "security guard", "id card", "notice board",
        "elevator", "staircase", "canteen", "cafeteria", "tea cup", "tiffin box",
        "lunch box", "stationery", "pen", "notebook", "marker", "mouse pad",
    ],
    "kids": [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
        "truck", "boat", "bird", "cat", "dog", "horse", "sheep", "cow",
        "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
        "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
        "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
        "surfboard", "tennis racket", "bottle", "banana", "apple", "sandwich",
        "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
        "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv",
        "laptop", "mouse", "remote", "keyboard", "cell phone", "book", "clock",
        "scissors", "teddy bear", "hair drier", "toothbrush", "red", "blue",
        # "orange" is intentionally not repeated here: it is already listed
        # above, and a repeated prompt would register the same name twice.
        "green", "yellow", "purple", "pink", "black", "white", "gray",
        "brown", "circle", "square", "triangle", "rectangle", "star", "heart",
        "ball", "block", "toy", "doll", "crayon", "slide", "swing", "duck", "lion",
        "tiger", "monkey", "moon", "sun", "cloud", "rainbow",

        "cylinder", "rectangular prism", "pyramid", "cube", "cone", "sphere",
        "triangular prism",
    ],
}
|
|
|
|
|
|
|
|
|
# Term -> list of alternative names that should be treated as equivalent.
SYNONYM_MAP = {
    "rickshaw": ["tuk-tuk", "auto rickshaw"],
    "motorbike": ["motorcycle"],
    "automobile": ["car"],
}

# Reverse lookup: alternative name -> the SYNONYM_MAP key that lists it.
ORIGINAL_TERM_MAP = {
    alias: canonical
    for canonical, aliases in SYNONYM_MAP.items()
    for alias in aliases
}
|
|
|
def expand_synonyms(class_list: List[str]) -> List[str]:
    """Return the given class names plus any synonyms registered for them.

    Every name that is a key of ``SYNONYM_MAP`` contributes its synonym
    list; names without an entry are kept unchanged. Duplicates collapse.
    Expansion is a single pass: synonyms that are themselves keys are not
    expanded further.

    Args:
        class_list: Class names to expand. Any iterable of strings is
            accepted, including one-shot iterators.

    Returns:
        A new, alphabetically sorted list of the unique names.
    """
    expanded = set(class_list)
    # Iterate a snapshot of the collected names instead of ``class_list``:
    # the input was already consumed by ``set()``, so a generator would be
    # empty on a second pass, and the set itself is mutated inside the loop.
    for term in tuple(expanded):
        expanded.update(SYNONYM_MAP.get(term, ()))
    return sorted(expanded)
|
|
|
|
|
|
|
|
|
# One YOLO-World model per profile, each configured with that profile's
# class list. Built once when the module is imported.
YOLOWORLD_MODELS = {}
for profile, profile_classes in PREDEFINED_CLASSES.items():
    profile_model = YOLOWorld(model_id="yolo_world/l")
    profile_model.set_classes(profile_classes)
    YOLOWORLD_MODELS[profile] = profile_model
|
|
|
|
|
def patch_requests_with_proxy():
    """Point ``requests.Session.proxies`` at a random proxy from ``proxies.txt``.

    ``proxies.txt`` is looked up next to this module and read as one proxy
    per line; blank lines are ignored. An entry without a scheme
    (``host:port``) is prefixed with ``http://``, while an entry that already
    carries one (for example ``socks5://host:port``) is used unchanged.

    The file is optional: if it is missing, unreadable, or contains no
    entries, the function returns without touching ``requests``.

    Returns:
        None.
    """
    proxies_path = os.path.join(os.path.dirname(__file__), "proxies.txt")
    try:
        with open(proxies_path, "r", encoding="utf-8") as proxy_file:
            candidates = [line.strip() for line in proxy_file if line.strip()]
    except (OSError, UnicodeDecodeError):
        # Best effort: running without a usable proxy list is supported.
        return
    if not candidates:
        return

    chosen = random.choice(candidates)
    # Do not double up the scheme when the entry already specifies one.
    proxy_url = chosen if "://" in chosen else f"http://{chosen}"

    # NOTE(review): this assigns a *class* attribute. ``requests.Session``
    # appears to set ``self.proxies`` per instance in ``__init__``, which
    # would shadow this value for newly created sessions -- confirm that the
    # patch actually reaches the HTTP client used for translation.
    requests.Session.proxies = {
        "http": proxy_url,
        "https": proxy_url,
    }
|
|
|
|
|
def translate_text(text, dest_lang):
    """Translate ``text`` into ``dest_lang`` using googletrans.

    ``patch_requests_with_proxy`` is invoked before every attempt. The
    translation is best effort: if anything raises while creating the
    translator, translating, or reading the result, the untranslated
    ``text`` is returned instead.
    """
    patch_requests_with_proxy()
    try:
        client = Translator(service_urls=['translate.googleapis.com'])
        translation = client.translate(text, dest=dest_lang)
        translated = translation.text
    except Exception:
        return text
    return translated
|
|
|
|
|
def process_yoloworld_results(predictions, original_w: int, original_h: int, scale: float, pad_top: int, pad_left: int, class_filter=None, target_language="en"):
    """Convert YOLO-World predictions on a letterboxed image into detections.

    Each prediction's centre/size box is mapped from the padded, resized
    frame back to the original image: the padding offset is removed, the
    result is divided by ``scale``, and the four box edges are clipped to
    the image bounds.

    Args:
        predictions: Iterable of objects exposing ``class_name``, ``x``,
            ``y``, ``width``, ``height`` and ``confidence``; ``x``/``y`` are
            read as the box centre.
        original_w: Width of the original image, used as the right bound.
        original_h: Height of the original image, used as the bottom bound.
        scale: Factor that was applied to the original image before padding.
        pad_top: Padding added above the resized image.
        pad_left: Padding added to the left of the resized image.
        class_filter: Optional container of class names to keep. ``None``
            or an empty container keeps every prediction.
        target_language: Language code for ``label``. When it is empty or
            ``"en"`` (case-insensitive) the label is the class name itself;
            otherwise it is obtained from ``translate_text``.

    Returns:
        A list of dicts with the keys ``box`` (``[x, y, width, height]``
        with ``x``/``y`` the top-left corner), ``confidence``, ``label``,
        ``label_en`` and ``centre`` (``[x, y]``). Coordinates are in the
        original image space and truncated to ``int``.
    """
    needs_translation = bool(target_language) and target_language.lower() != "en"
    # Translate each distinct class name once per call: every lookup is a
    # network round trip and many detections usually share the same label.
    translated_labels = {}

    detections = []
    for pred in predictions:
        class_name = pred.class_name
        if class_filter and class_name not in class_filter:
            continue

        # Undo the letterbox: drop the padding offset, then undo the resize.
        center_x = (float(pred.x) - pad_left) / scale
        center_y = (float(pred.y) - pad_top) / scale
        half_width = float(pred.width) / scale / 2
        half_height = float(pred.height) / scale / 2

        # Clip every edge independently. Deriving the right/bottom edge from
        # the already-clipped left/top edge plus the full size would stretch
        # boxes that overhang the left or top border of the image.
        x1 = max(0, min(center_x - half_width, original_w))
        y1 = max(0, min(center_y - half_height, original_h))
        x2 = max(0, min(center_x + half_width, original_w))
        y2 = max(0, min(center_y + half_height, original_h))

        final_width = x2 - x1
        final_height = y2 - y1
        final_center_x = x1 + final_width / 2
        final_center_y = y1 + final_height / 2

        if needs_translation:
            if class_name not in translated_labels:
                translated_labels[class_name] = translate_text(class_name, target_language)
            label_translated = translated_labels[class_name]
        else:
            label_translated = class_name

        detections.append({
            "box": [int(x1), int(y1), int(final_width), int(final_height)],
            "confidence": float(pred.confidence),
            "label": label_translated,
            "label_en": class_name,
            "centre": [int(final_center_x), int(final_center_y)],
        })
    return detections
|
|
|
|
|
def _letterbox_to_square(image_cv, target_size):
    """Resize ``image_cv`` to fit a ``target_size`` square and pad the rest.

    The longer side is scaled to ``target_size`` and the remaining space is
    filled with a constant (114, 114, 114) border split between both sides.

    Returns:
        ``(padded_image, scale, pad_top, pad_left)``.
    """
    h, w = image_cv.shape[:2]
    scale = target_size / max(h, w)
    # Keep at least one pixel per side: for extreme aspect ratios int() would
    # otherwise yield 0, which cv2.resize rejects.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))

    resized = cv2.resize(image_cv, (new_w, new_h), interpolation=cv2.INTER_AREA)

    delta_w = target_size - new_w
    delta_h = target_size - new_h
    top, left = delta_h // 2, delta_w // 2
    bottom, right = delta_h - top, delta_w - left

    padded = cv2.copyMakeBorder(resized, top, bottom, left, right,
                                cv2.BORDER_CONSTANT, value=[114, 114, 114])
    return padded, scale, top, left


def run_yoloworld_detection(image: Image.Image, target_classes: set, confidence_threshold: float = 0.1, iou_threshold: float = 0.4, profile: str = "casual", target_language: str = "en"):
    """Run YOLO-World detection for a profile and filter by target classes.

    The image is converted to BGR, letterboxed to a 640x640 square, passed
    to the profile's model, and the predictions are mapped back to the
    original image space by ``process_yoloworld_results``.

    Args:
        image: Input image. PIL images in a mode other than RGB are
            converted to RGB first.
        target_classes: Class names to keep; forwarded as ``class_filter``.
        confidence_threshold: Passed to ``model.infer`` as ``confidence``.
        iou_threshold: Passed to ``model.infer`` as ``iou``.
        profile: Key into ``YOLOWORLD_MODELS``. An unknown profile falls
            back to the first configured model.
        target_language: Language code forwarded for label translation.

    Returns:
        The list of detection dicts built by ``process_yoloworld_results``.
    """
    model = YOLOWORLD_MODELS.get(profile)
    if model is None:
        # Unknown profile: use the first configured model.
        model = next(iter(YOLOWORLD_MODELS.values()))

    # COLOR_RGB2BGR needs a 3-channel RGB array; RGBA, greyscale or palette
    # images would fail or be misread, so normalise the mode first.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    target_size = 640
    h, w = image_cv.shape[:2]
    padded_image_cv, scale, top, left = _letterbox_to_square(image_cv, target_size)

    results = model.infer(padded_image_cv, confidence=confidence_threshold, iou=iou_threshold, providers=["CPUExecutionProvider"])

    detections = process_yoloworld_results(results.predictions, w, h, scale, top, left, class_filter=target_classes, target_language=target_language)
    return detections