|
import logging |
|
from typing import List, Dict |
|
|
|
import torch.multiprocessing as mp |
|
from sentence_transformers import CrossEncoder |
|
|
|
import config |
|
|
|
|
|
|
|
|
|
# Root-logger setup: everything logged anywhere in the process, at DEBUG and
# above, is written through a single stream handler using one line format.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Line layout: timestamp - logger name - level - message.
formatter = logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# StreamHandler() with no argument uses the logging module's default stream.
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.DEBUG)

logger.addHandler(console_handler)
|
|
|
|
|
class EndpointHandler:
    """Rerank documents against a query with a ``CrossEncoder`` model.

    An instance is callable: it receives a request payload, asks the model to
    score every document against the query, and returns the best matches.

    Settings are read from the project ``config`` module at construction time:
    ``top_k`` from ``config.CROSSENCODER_TOP_K``, ``max_length`` from
    ``config.CROSSENCODER_MAX_LENGTH`` and ``max_workers`` from
    ``config.CROSSENCODER_MAX_WORKERS``.
    """

    def __init__(self, path):
        """Load the cross-encoder model.

        Args:
            path: Passed unchanged as the first argument to ``CrossEncoder``.
        """
        self.top_k = config.CROSSENCODER_TOP_K
        self.max_length = config.CROSSENCODER_MAX_LENGTH
        # Stored and reported in the log line below; nothing else in this
        # class reads it.
        self.max_workers = config.CROSSENCODER_MAX_WORKERS

        self.model = CrossEncoder(path, max_length=self.max_length)
        logger.info(
            'Loaded %s:top_k=%s:length=%s:workers=%s',
            path,
            self.top_k,
            self.max_length,
            self.max_workers,
        )

    def __call__(
        self, data: Dict[str, Dict[str, str | List[str]]]
    ) -> List[Dict[str, int | float]]:
        """Rank the request's documents against its query.

        Args:
            data: Request payload shaped like
                ``{'inputs': {'query': ..., 'documents': [...]}}``. The
                payload is only read; the caller's dict is left untouched.

        Returns:
            The result of ``self.model.rank`` for the query and documents,
            limited to ``self.top_k`` entries, or an empty list when the
            request carries no documents.

        Raises:
            KeyError: If ``'inputs'``, ``'query'`` or ``'documents'`` is
                missing from the payload.
        """
        logger.info('Received new docs to rerank: %s', data)

        # Plain lookup instead of pop(): same KeyError on a malformed request,
        # but the caller's payload is not mutated as a side effect.
        inputs = data['inputs']
        query = inputs['query']
        documents = inputs['documents']

        # len() rather than truthiness, so a None/ill-typed value still raises
        # instead of being silently treated as "no documents".
        if len(documents) == 0:
            # Nothing to score; skip the model call entirely.
            # NOTE(review): the library's rank/predict path appears to index
            # the first pair and so fails on an empty batch - confirm against
            # the pinned sentence-transformers version.
            ranks = []
        else:
            ranks = self.model.rank(query, documents, top_k=self.top_k)

        # Logged once; the original emitted the same payload twice in a row.
        logger.info('Computed ranks for each document: %s', ranks)
        return ranks
|
|