import logging
from typing import Dict, List

import torch.multiprocessing as mp
from sentence_transformers import CrossEncoder

import config

# Required for CUDA multiprocessing; enable together with worker-based ranking below.
# mp.set_start_method('spawn', force=True)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(formatter)

logger = logging.getLogger()
logger.addHandler(console_handler)
logger.setLevel(logging.DEBUG)


class EndpointHandler:
    def __init__(self, path: str):
        # Reranking settings come from the project-level config module.
        self.top_k = config.CROSSENCODER_TOP_K
        self.max_length = config.CROSSENCODER_MAX_LENGTH
        self.max_workers = config.CROSSENCODER_MAX_WORKERS
        self.model = CrossEncoder(path, max_length=self.max_length)
        logger.info(
            f'Loaded {path}:top_k={self.top_k}:length={self.max_length}:workers={self.max_workers}')

    def __call__(self, data: Dict[str, Dict]) -> List[Dict[str, int | float]]:
        logger.info(f'Received new docs to rerank: {data}')
        inputs = data.pop('inputs')
        # Score each document against the query and keep only the top_k results.
        ranks = self.model.rank(
            inputs['query'],
            inputs['documents'],
            top_k=self.top_k,
            # num_workers=self.max_workers
        )
        logger.info(f'Computed ranks for each document: {ranks}')
        return ranks
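

# --- Usage sketch (illustrative only, not part of the handler) ---
# A minimal local invocation, assuming the config module defines the
# CROSSENCODER_* values and using 'cross-encoder/ms-marco-MiniLM-L-6-v2'
# as a stand-in checkpoint; the payload mirrors the shape __call__ expects.
if __name__ == '__main__':
    handler = EndpointHandler('cross-encoder/ms-marco-MiniLM-L-6-v2')
    payload = {
        'inputs': {
            'query': 'How do cross-encoders rerank documents?',
            'documents': [
                'Cross-encoders score each query-document pair jointly.',
                'Bi-encoders embed queries and documents separately.',
            ],
        }
    }
    # Each result is a dict with 'corpus_id' and 'score', sorted by score.
    for rank in handler(payload):
        print(rank)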