# reranker / handler.py
# radu.mutilica
# removed mp
# 3e7aea0
import logging
from typing import List, Dict
import torch.multiprocessing as mp
from sentence_transformers import CrossEncoder
import config
# The 'spawn' start method was used for CUDA multiprocessing; the call is
# currently disabled:
# mp.set_start_method('spawn', force=True)
# Root-logger configuration: the root logger accepts DEBUG and above and
# forwards records to a console stream handler that renders them as
# "timestamp - logger name - level - message".
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
class EndpointHandler:
def __init__(self, path):
self.top_k = config.CROSSENCODER_TOP_K
self.max_length = config.CROSSENCODER_MAX_LENGTH
self.max_workers = config.CROSSENCODER_MAX_WORKERS
self.model = CrossEncoder(
path,
max_length=self.max_length
)
logger.info(
f'Loaded {path}:top_k={self.top_k}:length={self.max_length}:workers={self.max_workers}')
def __call__(self, data: Dict[str, str]) -> List[Dict[str, int | float]]:
logger.info(f'Received new docs to rerank: {data}')
inputs = data.pop('inputs')
ranks = self.model.rank(
inputs['query'],
inputs['documents'],
top_k=self.top_k,
# num_workers=self.max_workers
)
logger.info(f'New ranks: {ranks}')
logger.info(f'Computed ranks for each document: {ranks}')
return ranks