Fixed some minor bugs in eval_mteb.py
Browse files1. `task_cls` has no field `description`, change it to `metadata_dict`
2. Misspelling for `mteb-fr`, `tas_names` -> `task_names`
3. No assignment of the `task_names` variable for `mteb-pl`
4. Removed unused imports
4. Used ruff for code formatting
- scripts/eval_mteb.py +321 -216
scripts/eval_mteb.py
CHANGED
|
@@ -1,21 +1,18 @@
|
|
|
|
|
|
|
|
| 1 |
import argparse
|
| 2 |
-
from collections import defaultdict
|
| 3 |
-
import json
|
| 4 |
import logging
|
| 5 |
import math
|
| 6 |
-
import os
|
| 7 |
-
import sys
|
| 8 |
import queue
|
| 9 |
from typing import Dict, List, Optional, Union
|
| 10 |
|
| 11 |
-
from tqdm.autonotebook import trange
|
| 12 |
-
import datasets
|
| 13 |
import numpy as np
|
| 14 |
import torch
|
| 15 |
import torch.multiprocessing as mp
|
|
|
|
| 16 |
from transformers import AutoModel, AutoTokenizer
|
| 17 |
-
|
| 18 |
-
from mteb import MTEB
|
| 19 |
|
| 20 |
TASK_LIST_CLASSIFICATION = [
|
| 21 |
"AmazonCounterfactualClassification",
|
|
@@ -112,99 +109,179 @@ MTEB_TASK_LIST = (
|
|
| 112 |
)
|
| 113 |
|
| 114 |
|
| 115 |
-
CMTEB_TASK_LIST = [
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
MTEB_PL = [
|
| 123 |
-
"CBD",
|
| 124 |
-
"
|
| 125 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
]
|
| 127 |
|
| 128 |
MTEB_FR = [
|
| 129 |
-
"AmazonReviewsClassification",
|
| 130 |
-
"
|
| 131 |
-
"
|
| 132 |
-
"
|
| 133 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
]
|
| 135 |
|
| 136 |
logging.basicConfig(
|
| 137 |
-
level=logging.INFO,
|
| 138 |
-
format='%(asctime)s - %(levelname)s - %(name)s : %(message)s'
|
| 139 |
)
|
| 140 |
|
| 141 |
-
logger = logging.getLogger(
|
|
|
|
| 142 |
|
| 143 |
def get_detailed_instruct(task_description: str) -> str:
|
| 144 |
if not task_description:
|
| 145 |
-
return
|
| 146 |
|
| 147 |
-
return
|
| 148 |
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
return "Retrieve semantically similar text"
|
| 152 |
|
| 153 |
-
if task_type in [
|
| 154 |
return "Given a news summary, retrieve other semantically similar summaries"
|
| 155 |
|
| 156 |
-
if task_type in [
|
| 157 |
return "Retrieve parallel sentences"
|
| 158 |
|
| 159 |
-
if task_type in [
|
| 160 |
task_name_to_instruct: Dict[str, str] = {
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
# C-MTEB eval instructions
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
# MTEB-pl eval instructions
|
| 181 |
-
"CBD":"Classify the sentiment of polish tweet reviews",
|
| 182 |
"PolEmo2.0-IN": "Classify the sentiment of in-domain (medicine and hotels) online reviews",
|
| 183 |
-
"PolEmo2.0-OUT":"Classify the sentiment of out-of-domain (products and school) online reviews",
|
| 184 |
"AllegroReviews": "Classify the sentiment of reviews from e-commerce marketplace Allegro",
|
| 185 |
-
"PAC":
|
| 186 |
-
|
| 187 |
}
|
| 188 |
return task_name_to_instruct[task_name]
|
| 189 |
|
| 190 |
-
if task_type in [
|
| 191 |
task_name_to_instruct: Dict[str, str] = {
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
# C-MTEB eval instructions
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
# MTEB-fr eval instructions
|
| 209 |
"AlloProfClusteringP2P": "Identify the main category of Allo Prof document based on the titles and descriptions",
|
| 210 |
"AlloProfClusteringS2S": "Identify the main category of Allo Prof document based on the titles",
|
|
@@ -212,32 +289,32 @@ def get_task_def_by_task_name_and_type(task_name: str, task_type: str, default_i
|
|
| 212 |
"MasakhaNEWSClusteringP2P": "Identify the topic or theme of the given news articles based on the titles and contents",
|
| 213 |
"MasakhaNEWSClusteringS2S": "Identify the topic or theme of the given news articles based on the titles",
|
| 214 |
"MLSUMClusteringP2P": "Identify the topic or theme of the given articles based on the titles and contents",
|
| 215 |
-
"MLSUMClusteringS2S":
|
| 216 |
# MTEB-pl eval instructions
|
| 217 |
"8TagsClustering": "Identify of headlines from social media posts in Polish into 8 categories: film, history, food, medicine, motorization, work, sport and technology",
|
| 218 |
}
|
| 219 |
return task_name_to_instruct[task_name]
|
| 220 |
|
| 221 |
-
if task_type in [
|
| 222 |
task_name_to_instruct: Dict[str, str] = {
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
# C-MTEB eval instructions
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
# MTEB-fr eval instructions
|
| 238 |
"AlloprofReranking": "Given a question, retrieve passages that answer the question",
|
| 239 |
-
"OpusparcusPC":"Retrieve semantically similar text",
|
| 240 |
-
"PawsX":"Retrieve semantically similar text",
|
| 241 |
"SyntecReranking": "Given a question, retrieve passages that answer the question",
|
| 242 |
# MTEB-pl eval instructions
|
| 243 |
"SICK-E-PL": "Retrieve semantically similar text",
|
|
@@ -247,41 +324,41 @@ def get_task_def_by_task_name_and_type(task_name: str, task_type: str, default_i
|
|
| 247 |
}
|
| 248 |
return task_name_to_instruct[task_name]
|
| 249 |
|
| 250 |
-
if task_type in [
|
| 251 |
-
if task_name.lower().startswith(
|
| 252 |
-
return
|
| 253 |
|
| 254 |
task_name_to_instruct: Dict[str, str] = {
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
# C-MTEB eval instructions
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
# MTEB-fr eval instructions
|
| 279 |
"AlloprofRetrieval": "Given a question, retrieve passages that answer the question",
|
| 280 |
"BSARDRetrieval": "Given a question, retrieve passages that answer the question",
|
| 281 |
"SyntecRetrieval": "Given a question, retrieve passages that answer the question",
|
| 282 |
"XPQARetrieval": "Given a question, retrieve passages that answer the question",
|
| 283 |
"MintakaRetrieval": "Given a question, retrieve passages that answer the question",
|
| 284 |
-
# MTEB-pl eval instructions
|
| 285 |
"ArguAna-PL": "Given a claim, find documents that refute the claim",
|
| 286 |
"DBPedia-PL": "Given a query, retrieve relevant entity descriptions from DBPedia",
|
| 287 |
"FiQA-PL": "Given a financial question, retrieve user replies that best answer the question",
|
|
@@ -292,45 +369,47 @@ def get_task_def_by_task_name_and_type(task_name: str, task_type: str, default_i
|
|
| 292 |
"Quora-PL": "Given a question, retrieve questions that are semantically equivalent to the given question",
|
| 293 |
"SCIDOCS-PL": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper",
|
| 294 |
"SciFact-PL": "Given a scientific claim, retrieve documents that support or refute the claim",
|
| 295 |
-
"TRECCOVID-PL": "Given a query on COVID-19, retrieve documents that answer the query"
|
| 296 |
}
|
| 297 |
|
| 298 |
# add lower case keys to match some beir names
|
| 299 |
task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()})
|
| 300 |
# other cases where lower case match still doesn't work
|
| 301 |
-
task_name_to_instruct[
|
| 302 |
-
task_name_to_instruct[
|
| 303 |
-
task_name_to_instruct[
|
| 304 |
-
task_name_to_instruct[
|
| 305 |
-
task_name_to_instruct[
|
| 306 |
-
task_name_to_instruct[
|
| 307 |
|
| 308 |
# for miracl evaluation
|
| 309 |
-
task_name_to_instruct[
|
|
|
|
|
|
|
| 310 |
|
| 311 |
return task_name_to_instruct[task_name]
|
| 312 |
-
logging.warning(
|
| 313 |
-
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
class Encoder(torch.nn.Module):
|
| 316 |
-
def __init__(self, name_or_path:str, pooling: str):
|
| 317 |
super().__init__()
|
| 318 |
self.model = AutoModel.from_pretrained(name_or_path, trust_remote_code=True)
|
| 319 |
self.model = self.model.half()
|
| 320 |
-
self.model.eval()
|
| 321 |
self.pooling = pooling
|
| 322 |
|
| 323 |
def forward(self, **features) -> torch.Tensor:
|
| 324 |
output = self.model(**features, output_hidden_states=True, return_dict=True)
|
| 325 |
-
hidden_state = output.hidden_states[-1]
|
| 326 |
embeddings = self.pooler(hidden_state, **features)
|
| 327 |
return embeddings
|
| 328 |
|
| 329 |
def pooler(
|
| 330 |
-
self,
|
| 331 |
-
hidden_state: torch.Tensor,
|
| 332 |
-
attention_mask: torch.Tensor,
|
| 333 |
-
**kwargs
|
| 334 |
) -> torch.Tensor:
|
| 335 |
if attention_mask.ndim == 2:
|
| 336 |
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size())
|
|
@@ -341,32 +420,35 @@ class Encoder(torch.nn.Module):
|
|
| 341 |
|
| 342 |
hidden_state = hidden_state * mask_expanded
|
| 343 |
|
| 344 |
-
if self.pooling ==
|
| 345 |
pooled_output = hidden_state[:, 0]
|
| 346 |
|
| 347 |
-
elif self.pooling ==
|
| 348 |
-
left_padding =
|
| 349 |
if left_padding:
|
| 350 |
return hidden_state[:, -1]
|
| 351 |
else:
|
| 352 |
sequence_lengths = attention_mask.sum(dim=1) - 1
|
| 353 |
batch_size = hidden_state.shape[0]
|
| 354 |
-
return hidden_state[
|
| 355 |
-
|
|
|
|
|
|
|
| 356 |
# TODO: weight
|
| 357 |
lengths = mask_expanded.sum(1).clamp(min=1e-9)
|
| 358 |
pooled_output = hidden_state.sum(dim=1) / lengths
|
| 359 |
|
| 360 |
-
elif self.pooling ==
|
| 361 |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
|
| 362 |
# hidden_state shape: bs, seq, hidden_dim
|
| 363 |
weights = (
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
)
|
|
|
|
| 370 |
assert weights.shape == hidden_state.shape == input_mask_expanded.shape
|
| 371 |
input_mask_expanded = input_mask_expanded * weights
|
| 372 |
|
|
@@ -392,28 +474,29 @@ class Wrapper:
|
|
| 392 |
force_default: bool = False,
|
| 393 |
sep: str = " ",
|
| 394 |
mp_tensor_to_cuda: bool = False,
|
| 395 |
-
instruction: str = None,
|
| 396 |
-
attn_type: str = None
|
| 397 |
):
|
| 398 |
self.tokenizer = tokenizer
|
| 399 |
self.model = encoder
|
| 400 |
self.batch_size = batch_size
|
| 401 |
self.max_seq_len = max_seq_len
|
| 402 |
-
self.pool: dict = None
|
| 403 |
self.normalize_embeddings = normalize_embeddings
|
| 404 |
self.mp_tensor_to_cuda = mp_tensor_to_cuda
|
| 405 |
self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 406 |
self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
|
| 407 |
self.instruction = instruction
|
| 408 |
-
self.default_query = default_query
|
| 409 |
self.sep = sep
|
| 410 |
self.force_default = force_default
|
| 411 |
-
if self.tokenizer.padding_side !=
|
| 412 |
-
logger.warning(
|
| 413 |
-
|
|
|
|
|
|
|
| 414 |
if self.tokenizer.pad_token is None:
|
| 415 |
logger.warning(f"Set tokenizer.pad_token as eos_token {self.tokenizer.eos_token}")
|
| 416 |
-
self.tokenizer.pad_token=
|
| 417 |
|
| 418 |
def start(self, target_devices: Optional[List[str]] = None):
|
| 419 |
"""
|
|
@@ -426,14 +509,16 @@ class Wrapper:
|
|
| 426 |
"""
|
| 427 |
if target_devices is None:
|
| 428 |
if torch.cuda.is_available():
|
| 429 |
-
target_devices = [
|
| 430 |
else:
|
| 431 |
logger.info("CUDA is not available. Start 4 CPU worker")
|
| 432 |
-
target_devices = [
|
| 433 |
|
| 434 |
-
logger.info(
|
| 435 |
-
|
| 436 |
-
|
|
|
|
|
|
|
| 437 |
input_queue = ctx.Queue()
|
| 438 |
output_queue = ctx.Queue()
|
| 439 |
processes = []
|
|
@@ -442,26 +527,26 @@ class Wrapper:
|
|
| 442 |
p = ctx.Process(
|
| 443 |
target=self._encode_multi_process_worker,
|
| 444 |
args=(cuda_id, self, input_queue, output_queue),
|
| 445 |
-
daemon=True
|
| 446 |
)
|
| 447 |
p.start()
|
| 448 |
processes.append(p)
|
| 449 |
|
| 450 |
-
self.pool = {
|
| 451 |
|
| 452 |
def stop(self):
|
| 453 |
"""
|
| 454 |
Stops all processes started with start_multi_process_pool
|
| 455 |
"""
|
| 456 |
-
for p in self.pool[
|
| 457 |
p.terminate()
|
| 458 |
|
| 459 |
-
for p in self.pool[
|
| 460 |
p.join()
|
| 461 |
p.close()
|
| 462 |
|
| 463 |
-
self.pool[
|
| 464 |
-
self.pool[
|
| 465 |
|
| 466 |
@staticmethod
|
| 467 |
def _encode_multi_process_worker(target_device: str, model, input_queue, results_queue):
|
|
@@ -477,11 +562,7 @@ class Wrapper:
|
|
| 477 |
except queue.Empty:
|
| 478 |
break
|
| 479 |
|
| 480 |
-
def encode_multi_process(
|
| 481 |
-
self,
|
| 482 |
-
sentences: List[str],
|
| 483 |
-
**kwargs
|
| 484 |
-
):
|
| 485 |
"""
|
| 486 |
This method allows to run encode() on multiple GPUs. The sentences are chunked into smaller packages
|
| 487 |
and sent to individual processes, which encode these on the different GPUs. This method is only suitable
|
|
@@ -496,9 +577,11 @@ class Wrapper:
|
|
| 496 |
part_size = math.ceil(len(sentences) / len(self.pool["processes"]))
|
| 497 |
chunk_size = part_size if part_size < 3200 else 3200 # for retrieval chunk 50000
|
| 498 |
|
| 499 |
-
logger.debug(
|
|
|
|
|
|
|
| 500 |
|
| 501 |
-
input_queue = self.pool[
|
| 502 |
last_chunk_id = 0
|
| 503 |
chunk = []
|
| 504 |
|
|
@@ -513,8 +596,10 @@ class Wrapper:
|
|
| 513 |
input_queue.put([last_chunk_id, chunk, kwargs])
|
| 514 |
last_chunk_id += 1
|
| 515 |
|
| 516 |
-
output_queue = self.pool[
|
| 517 |
-
results_list = sorted(
|
|
|
|
|
|
|
| 518 |
embeddings = np.concatenate([result[1] for result in results_list])
|
| 519 |
return embeddings
|
| 520 |
|
|
@@ -535,33 +620,41 @@ class Wrapper:
|
|
| 535 |
(representing several text inputs to the model).
|
| 536 |
"""
|
| 537 |
|
| 538 |
-
if isinstance(text, dict):
|
| 539 |
return len(next(iter(text.values())))
|
| 540 |
-
elif not hasattr(text,
|
| 541 |
return 1
|
| 542 |
-
elif len(text) == 0 or isinstance(text[0], int):
|
| 543 |
return len(text)
|
| 544 |
else:
|
| 545 |
-
return sum([len(t) for t in text])
|
| 546 |
|
| 547 |
def _tokenize(self, sentences: List[str], is_query: bool):
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
return batch_dict
|
| 554 |
|
| 555 |
-
|
| 556 |
def _encode(
|
| 557 |
self,
|
| 558 |
sentences: List[str],
|
| 559 |
is_query: bool,
|
| 560 |
convert_to_numpy: bool = True,
|
| 561 |
convert_to_tensor: bool = False,
|
| 562 |
-
device: str = None,
|
| 563 |
show_progress_bar: bool = True,
|
| 564 |
-
**kwargs
|
| 565 |
):
|
| 566 |
"""
|
| 567 |
Computes sentence embeddings
|
|
@@ -584,7 +677,9 @@ class Wrapper:
|
|
| 584 |
convert_to_numpy = False
|
| 585 |
|
| 586 |
input_was_string = False
|
| 587 |
-
if isinstance(sentences, str) or not hasattr(
|
|
|
|
|
|
|
| 588 |
sentences = [sentences]
|
| 589 |
input_was_string = True
|
| 590 |
|
|
@@ -597,8 +692,10 @@ class Wrapper:
|
|
| 597 |
length_sorted_idx = np.argsort([-self._text_length(s) for s in sentences])
|
| 598 |
sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
|
| 599 |
|
| 600 |
-
for start_index in trange(
|
| 601 |
-
|
|
|
|
|
|
|
| 602 |
features = self._tokenize(sentences_batch, is_query)
|
| 603 |
features = self.batch_to_device(features, device)
|
| 604 |
|
|
@@ -619,7 +716,7 @@ class Wrapper:
|
|
| 619 |
if convert_to_tensor:
|
| 620 |
all_embeddings = torch.stack(all_embeddings)
|
| 621 |
elif convert_to_numpy:
|
| 622 |
-
#all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
|
| 623 |
all_embeddings = np.asarray([emb.to(torch.float).numpy() for emb in all_embeddings])
|
| 624 |
if input_was_string:
|
| 625 |
all_embeddings = all_embeddings[0]
|
|
@@ -631,11 +728,11 @@ class Wrapper:
|
|
| 631 |
sentences: List[str],
|
| 632 |
is_query: Optional[bool] = None,
|
| 633 |
convert_to_tensor: bool = False,
|
| 634 |
-
**kwargs
|
| 635 |
):
|
| 636 |
is_query = self.default_query if is_query is None else is_query
|
| 637 |
if is_query and self.instruction:
|
| 638 |
-
|
| 639 |
kwargs.update(is_query=is_query)
|
| 640 |
if self.pool is not None:
|
| 641 |
kwargs.update(show_progress_bar=False)
|
|
@@ -643,7 +740,7 @@ class Wrapper:
|
|
| 643 |
if convert_to_tensor:
|
| 644 |
embeddings = torch.from_numpy(embeddings)
|
| 645 |
if self.mp_tensor_to_cuda and torch.cuda.is_available():
|
| 646 |
-
embeddings = embeddings.to(torch.device(
|
| 647 |
return embeddings
|
| 648 |
|
| 649 |
return self._encode(sentences, convert_to_tensor=convert_to_tensor, **kwargs)
|
|
@@ -663,7 +760,9 @@ class Wrapper:
|
|
| 663 |
]
|
| 664 |
elif isinstance(corpus[0], dict):
|
| 665 |
sentences = [
|
| 666 |
-
(doc["title"] + self.sep + doc["text"]).strip()
|
|
|
|
|
|
|
| 667 |
for doc in corpus
|
| 668 |
]
|
| 669 |
else:
|
|
@@ -671,43 +770,46 @@ class Wrapper:
|
|
| 671 |
is_query = self.default_query if self.force_default else False
|
| 672 |
return self.encode(sentences, is_query=is_query, **kwargs)
|
| 673 |
|
|
|
|
| 674 |
def main(args):
|
| 675 |
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
|
| 676 |
encoder = Encoder(args.model, args.pooling)
|
| 677 |
-
default_query = args.default_type ==
|
| 678 |
model = Wrapper(
|
| 679 |
-
tokenizer,
|
|
|
|
| 680 |
batch_size=args.batch_size,
|
| 681 |
max_seq_len=args.max_seq_len,
|
| 682 |
normalize_embeddings=args.norm,
|
| 683 |
-
default_query=default_query
|
| 684 |
)
|
| 685 |
-
sym_retrievals = [
|
| 686 |
-
if args.task ==
|
| 687 |
task_names = MTEB_TASK_LIST
|
| 688 |
-
lang = [
|
| 689 |
-
elif args.task ==
|
| 690 |
task_names = CMTEB_TASK_LIST
|
| 691 |
-
lang = [
|
| 692 |
-
elif args.task ==
|
| 693 |
-
|
| 694 |
-
lang = [
|
| 695 |
-
elif args.task ==
|
| 696 |
-
|
|
|
|
| 697 |
else:
|
| 698 |
task_names = [args.task]
|
| 699 |
-
lang = [
|
| 700 |
for task in task_names:
|
| 701 |
evaluation = MTEB(tasks=[task], task_langs=lang)
|
| 702 |
task_cls = evaluation.tasks[0]
|
| 703 |
-
task_name: str = task_cls.
|
| 704 |
-
task_type: str = task_cls.
|
| 705 |
instruction = get_task_def_by_task_name_and_type(task_name, task_type)
|
| 706 |
model.instruction = get_detailed_instruct(instruction)
|
| 707 |
-
if task ==
|
| 708 |
eval_splits = ["dev"]
|
| 709 |
elif task in CMTEB_TASK_LIST:
|
| 710 |
-
eval_splits = task_cls.
|
| 711 |
else:
|
| 712 |
eval_splits = ["test"]
|
| 713 |
sym = False
|
|
@@ -718,28 +820,31 @@ def main(args):
|
|
| 718 |
else:
|
| 719 |
sym = False
|
| 720 |
if sym:
|
| 721 |
-
logger.info(
|
|
|
|
|
|
|
| 722 |
model.force_default = True
|
| 723 |
evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
|
| 724 |
|
| 725 |
if sym:
|
| 726 |
logger.info(f"Switch back.")
|
| 727 |
model.force_default = force_default_ori
|
| 728 |
-
print(
|
| 729 |
|
| 730 |
|
| 731 |
if __name__ == "__main__":
|
| 732 |
_PARSER = argparse.ArgumentParser()
|
| 733 |
-
_PARSER.add_argument(
|
| 734 |
-
|
| 735 |
-
)
|
| 736 |
-
_PARSER.add_argument("--pooling", type=str, default='last')
|
| 737 |
_PARSER.add_argument("--output_dir", type=str, default=None)
|
| 738 |
-
_PARSER.add_argument("--default_type", type=str, default=
|
| 739 |
_PARSER.add_argument("--max_seq_len", type=int, default=512)
|
| 740 |
_PARSER.add_argument("-b", "--batch_size", type=int, default=32)
|
| 741 |
_PARSER.add_argument(
|
| 742 |
-
"-t",
|
|
|
|
|
|
|
|
|
|
| 743 |
)
|
| 744 |
_PARSER.add_argument("--norm", action="store_true")
|
| 745 |
_ARGS = _PARSER.parse_args()
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import argparse
|
|
|
|
|
|
|
| 4 |
import logging
|
| 5 |
import math
|
|
|
|
|
|
|
| 6 |
import queue
|
| 7 |
from typing import Dict, List, Optional, Union
|
| 8 |
|
|
|
|
|
|
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
import torch.multiprocessing as mp
|
| 12 |
+
from tqdm.autonotebook import trange
|
| 13 |
from transformers import AutoModel, AutoTokenizer
|
| 14 |
+
|
| 15 |
+
from mteb import MTEB
|
| 16 |
|
| 17 |
TASK_LIST_CLASSIFICATION = [
|
| 18 |
"AmazonCounterfactualClassification",
|
|
|
|
| 109 |
)
|
| 110 |
|
| 111 |
|
| 112 |
+
CMTEB_TASK_LIST = [
|
| 113 |
+
"TNews",
|
| 114 |
+
"IFlyTek",
|
| 115 |
+
"MultilingualSentiment",
|
| 116 |
+
"JDReview",
|
| 117 |
+
"OnlineShopping",
|
| 118 |
+
"Waimai",
|
| 119 |
+
"AmazonReviewsClassification",
|
| 120 |
+
"MassiveIntentClassification",
|
| 121 |
+
"MassiveScenarioClassification",
|
| 122 |
+
"MultilingualSentiment",
|
| 123 |
+
"CLSClusteringS2S",
|
| 124 |
+
"CLSClusteringP2P",
|
| 125 |
+
"ThuNewsClusteringS2S",
|
| 126 |
+
"ThuNewsClusteringP2P",
|
| 127 |
+
"Ocnli",
|
| 128 |
+
"Cmnli",
|
| 129 |
+
"T2Reranking",
|
| 130 |
+
"MmarcoReranking",
|
| 131 |
+
"CMedQAv1",
|
| 132 |
+
"CMedQAv2",
|
| 133 |
+
"T2Retrieval",
|
| 134 |
+
"MMarcoRetrieval",
|
| 135 |
+
"DuRetrieval",
|
| 136 |
+
"CovidRetrieval",
|
| 137 |
+
"CmedqaRetrieval",
|
| 138 |
+
"EcomRetrieval",
|
| 139 |
+
"MedicalRetrieval",
|
| 140 |
+
"VideoRetrieval",
|
| 141 |
+
"ATEC",
|
| 142 |
+
"BQ",
|
| 143 |
+
"LCQMC",
|
| 144 |
+
"PAWSX",
|
| 145 |
+
"STSB",
|
| 146 |
+
"AFQMC",
|
| 147 |
+
"QBQTC",
|
| 148 |
+
"STS22",
|
| 149 |
+
]
|
| 150 |
|
| 151 |
MTEB_PL = [
|
| 152 |
+
"CBD",
|
| 153 |
+
"PolEmo2.0-IN",
|
| 154 |
+
"PolEmo2.0-OUT",
|
| 155 |
+
"AllegroReviews",
|
| 156 |
+
"PAC",
|
| 157 |
+
"MassiveIntentClassification",
|
| 158 |
+
"MassiveScenarioClassification",
|
| 159 |
+
"SICK-E-PL",
|
| 160 |
+
"PPC",
|
| 161 |
+
"CDSC-E",
|
| 162 |
+
"PSC",
|
| 163 |
+
"8TagsClustering",
|
| 164 |
+
"SICK-R-PL",
|
| 165 |
+
"CDSC-R",
|
| 166 |
+
"STS22",
|
| 167 |
+
"ArguAna-PL",
|
| 168 |
+
"DBPedia-PL",
|
| 169 |
+
"FiQA-PL",
|
| 170 |
+
"HotpotQA-PL",
|
| 171 |
+
"MSMARCO-PL",
|
| 172 |
+
"NFCorpus-PL",
|
| 173 |
+
"NQ-PL",
|
| 174 |
+
"Quora-PL",
|
| 175 |
+
"SCIDOCS-PL",
|
| 176 |
+
"SciFact-PL",
|
| 177 |
+
"TRECCOVID-PL",
|
| 178 |
]
|
| 179 |
|
| 180 |
MTEB_FR = [
|
| 181 |
+
"AmazonReviewsClassification",
|
| 182 |
+
"MasakhaNEWSClassification",
|
| 183 |
+
"MassiveIntentClassification",
|
| 184 |
+
"MassiveScenarioClassification",
|
| 185 |
+
"MTOPDomainClassification",
|
| 186 |
+
"MTOPIntentClassification",
|
| 187 |
+
"OpusparcusPC",
|
| 188 |
+
"PawsX",
|
| 189 |
+
"AlloProfClusteringP2P",
|
| 190 |
+
"AlloProfClusteringS2S",
|
| 191 |
+
"HALClusteringS2S",
|
| 192 |
+
"MasakhaNEWSClusteringP2P",
|
| 193 |
+
"MasakhaNEWSClusteringS2S",
|
| 194 |
+
"MLSUMClusteringP2P",
|
| 195 |
+
"MLSUMClusteringS2S",
|
| 196 |
+
"SyntecReranking",
|
| 197 |
+
"AlloprofReranking",
|
| 198 |
+
"AlloprofRetrieval",
|
| 199 |
+
"BSARDRetrieval",
|
| 200 |
+
"SyntecRetrieval",
|
| 201 |
+
"XPQARetrieval",
|
| 202 |
+
"MintakaRetrieval",
|
| 203 |
+
"SummEvalFr",
|
| 204 |
+
"STSBenchmarkMultilingualSTS",
|
| 205 |
+
"STS22",
|
| 206 |
+
"SICKFr",
|
| 207 |
]
|
| 208 |
|
| 209 |
logging.basicConfig(
|
| 210 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s : %(message)s"
|
|
|
|
| 211 |
)
|
| 212 |
|
| 213 |
+
logger = logging.getLogger("eval_mteb_qwen.py")
|
| 214 |
+
|
| 215 |
|
| 216 |
def get_detailed_instruct(task_description: str) -> str:
|
| 217 |
if not task_description:
|
| 218 |
+
return ""
|
| 219 |
|
| 220 |
+
return "Instruct: {}\nQuery: ".format(task_description)
|
| 221 |
|
| 222 |
+
|
| 223 |
+
def get_task_def_by_task_name_and_type(
|
| 224 |
+
task_name: str,
|
| 225 |
+
task_type: str,
|
| 226 |
+
default_instruct="Given a web search query, retrieve relevant passages that answer the query",
|
| 227 |
+
) -> str:
|
| 228 |
+
if task_type in ["STS"]:
|
| 229 |
return "Retrieve semantically similar text"
|
| 230 |
|
| 231 |
+
if task_type in ["Summarization"]:
|
| 232 |
return "Given a news summary, retrieve other semantically similar summaries"
|
| 233 |
|
| 234 |
+
if task_type in ["BitextMining"]:
|
| 235 |
return "Retrieve parallel sentences"
|
| 236 |
|
| 237 |
+
if task_type in ["Classification"]:
|
| 238 |
task_name_to_instruct: Dict[str, str] = {
|
| 239 |
+
"AmazonCounterfactualClassification": "Classify a given Amazon customer review text as either counterfactual or not-counterfactual",
|
| 240 |
+
"AmazonPolarityClassification": "Classify Amazon reviews into positive or negative sentiment",
|
| 241 |
+
"AmazonReviewsClassification": "Classify the given Amazon review into its appropriate rating category",
|
| 242 |
+
"Banking77Classification": "Given a online banking query, find the corresponding intents",
|
| 243 |
+
"EmotionClassification": "Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise",
|
| 244 |
+
"ImdbClassification": "Classify the sentiment expressed in the given movie review text from the IMDB dataset",
|
| 245 |
+
"MassiveIntentClassification": "Given a user utterance as query, find the user intents",
|
| 246 |
+
"MassiveScenarioClassification": "Given a user utterance as query, find the user scenarios",
|
| 247 |
+
"MTOPDomainClassification": "Classify the intent domain of the given utterance in task-oriented conversation",
|
| 248 |
+
"MTOPIntentClassification": "Classify the intent of the given utterance in task-oriented conversation",
|
| 249 |
+
"ToxicConversationsClassification": "Classify the given comments as either toxic or not toxic",
|
| 250 |
+
"TweetSentimentExtractionClassification": "Classify the sentiment of a given tweet as either positive, negative, or neutral",
|
| 251 |
# C-MTEB eval instructions
|
| 252 |
+
"TNews": "Classify the fine-grained category of the given news title",
|
| 253 |
+
"IFlyTek": "Given an App description text, find the appropriate fine-grained category",
|
| 254 |
+
"MultilingualSentiment": "Classify sentiment of the customer review into positive, neutral, or negative",
|
| 255 |
+
"JDReview": "Classify the customer review for iPhone on e-commerce platform into positive or negative",
|
| 256 |
+
"OnlineShopping": "Classify the customer review for online shopping into positive or negative",
|
| 257 |
+
"Waimai": "Classify the customer review from a food takeaway platform into positive or negative",
|
| 258 |
# MTEB-pl eval instructions
|
| 259 |
+
"CBD": "Classify the sentiment of polish tweet reviews",
|
| 260 |
"PolEmo2.0-IN": "Classify the sentiment of in-domain (medicine and hotels) online reviews",
|
| 261 |
+
"PolEmo2.0-OUT": "Classify the sentiment of out-of-domain (products and school) online reviews",
|
| 262 |
"AllegroReviews": "Classify the sentiment of reviews from e-commerce marketplace Allegro",
|
| 263 |
+
"PAC": 'Classify the sentence into one of the two types: "BEZPIECZNE_POSTANOWIENIE_UMOWNE" and "KLAUZULA_ABUZYWNA"',
|
|
|
|
| 264 |
}
|
| 265 |
return task_name_to_instruct[task_name]
|
| 266 |
|
| 267 |
+
if task_type in ["Clustering"]:
|
| 268 |
task_name_to_instruct: Dict[str, str] = {
|
| 269 |
+
"ArxivClusteringP2P": "Identify the main and secondary category of Arxiv papers based on the titles and abstracts",
|
| 270 |
+
"ArxivClusteringS2S": "Identify the main and secondary category of Arxiv papers based on the titles",
|
| 271 |
+
"BiorxivClusteringP2P": "Identify the main category of Biorxiv papers based on the titles and abstracts",
|
| 272 |
+
"BiorxivClusteringS2S": "Identify the main category of Biorxiv papers based on the titles",
|
| 273 |
+
"MedrxivClusteringP2P": "Identify the main category of Medrxiv papers based on the titles and abstracts",
|
| 274 |
+
"MedrxivClusteringS2S": "Identify the main category of Medrxiv papers based on the titles",
|
| 275 |
+
"RedditClustering": "Identify the topic or theme of Reddit posts based on the titles",
|
| 276 |
+
"RedditClusteringP2P": "Identify the topic or theme of Reddit posts based on the titles and posts",
|
| 277 |
+
"StackExchangeClustering": "Identify the topic or theme of StackExchange posts based on the titles",
|
| 278 |
+
"StackExchangeClusteringP2P": "Identify the topic or theme of StackExchange posts based on the given paragraphs",
|
| 279 |
+
"TwentyNewsgroupsClustering": "Identify the topic or theme of the given news articles",
|
| 280 |
# C-MTEB eval instructions
|
| 281 |
+
"CLSClusteringS2S": "Identify the main category of scholar papers based on the titles",
|
| 282 |
+
"CLSClusteringP2P": "Identify the main category of scholar papers based on the titles and abstracts",
|
| 283 |
+
"ThuNewsClusteringS2S": "Identify the topic or theme of the given news articles based on the titles",
|
| 284 |
+
"ThuNewsClusteringP2P": "Identify the topic or theme of the given news articles based on the titles and contents",
|
| 285 |
# MTEB-fr eval instructions
|
| 286 |
"AlloProfClusteringP2P": "Identify the main category of Allo Prof document based on the titles and descriptions",
|
| 287 |
"AlloProfClusteringS2S": "Identify the main category of Allo Prof document based on the titles",
|
|
|
|
| 289 |
"MasakhaNEWSClusteringP2P": "Identify the topic or theme of the given news articles based on the titles and contents",
|
| 290 |
"MasakhaNEWSClusteringS2S": "Identify the topic or theme of the given news articles based on the titles",
|
| 291 |
"MLSUMClusteringP2P": "Identify the topic or theme of the given articles based on the titles and contents",
|
| 292 |
+
"MLSUMClusteringS2S": "Identify the topic or theme of the given articles based on the titles",
|
| 293 |
# MTEB-pl eval instructions
|
| 294 |
"8TagsClustering": "Identify of headlines from social media posts in Polish into 8 categories: film, history, food, medicine, motorization, work, sport and technology",
|
| 295 |
}
|
| 296 |
return task_name_to_instruct[task_name]
|
| 297 |
|
| 298 |
+
if task_type in ["Reranking", "PairClassification"]:
|
| 299 |
task_name_to_instruct: Dict[str, str] = {
|
| 300 |
+
"AskUbuntuDupQuestions": "Retrieve duplicate questions from AskUbuntu forum",
|
| 301 |
+
"MindSmallReranking": "Retrieve relevant news articles based on user browsing history",
|
| 302 |
+
"SciDocsRR": "Given a title of a scientific paper, retrieve the titles of other relevant papers",
|
| 303 |
+
"StackOverflowDupQuestions": "Retrieve duplicate questions from StackOverflow forum",
|
| 304 |
+
"SprintDuplicateQuestions": "Retrieve duplicate questions from Sprint forum",
|
| 305 |
+
"TwitterSemEval2015": "Retrieve tweets that are semantically similar to the given tweet",
|
| 306 |
+
"TwitterURLCorpus": "Retrieve tweets that are semantically similar to the given tweet",
|
| 307 |
# C-MTEB eval instructions
|
| 308 |
+
"T2Reranking": "Given a Chinese search query, retrieve web passages that answer the question",
|
| 309 |
+
"MmarcoReranking": "Given a Chinese search query, retrieve web passages that answer the question",
|
| 310 |
+
"CMedQAv1": "Given a Chinese community medical question, retrieve replies that best answer the question",
|
| 311 |
+
"CMedQAv2": "Given a Chinese community medical question, retrieve replies that best answer the question",
|
| 312 |
+
"Ocnli": "Retrieve semantically similar text.",
|
| 313 |
+
"Cmnli": "Retrieve semantically similar text.",
|
| 314 |
# MTEB-fr eval instructions
|
| 315 |
"AlloprofReranking": "Given a question, retrieve passages that answer the question",
|
| 316 |
+
"OpusparcusPC": "Retrieve semantically similar text",
|
| 317 |
+
"PawsX": "Retrieve semantically similar text",
|
| 318 |
"SyntecReranking": "Given a question, retrieve passages that answer the question",
|
| 319 |
# MTEB-pl eval instructions
|
| 320 |
"SICK-E-PL": "Retrieve semantically similar text",
|
|
|
|
| 324 |
}
|
| 325 |
return task_name_to_instruct[task_name]
|
| 326 |
|
| 327 |
+
if task_type in ["Retrieval"]:
|
| 328 |
+
if task_name.lower().startswith("cqadupstack"):
|
| 329 |
+
return "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question"
|
| 330 |
|
| 331 |
task_name_to_instruct: Dict[str, str] = {
|
| 332 |
+
"ArguAna": "Given a claim, find documents that refute the claim",
|
| 333 |
+
"ClimateFEVER": "Given a claim about climate change, retrieve documents that support or refute the claim",
|
| 334 |
+
"DBPedia": "Given a query, retrieve relevant entity descriptions from DBPedia",
|
| 335 |
+
"FEVER": "Given a claim, retrieve documents that support or refute the claim",
|
| 336 |
+
"FiQA2018": "Given a financial question, retrieve user replies that best answer the question",
|
| 337 |
+
"HotpotQA": "Given a multi-hop question, retrieve documents that can help answer the question",
|
| 338 |
+
"MSMARCO": "Given a web search query, retrieve relevant passages that answer the query",
|
| 339 |
+
"NFCorpus": "Given a question, retrieve relevant documents that best answer the question",
|
| 340 |
+
"NQ": "Given a question, retrieve Wikipedia passages that answer the question",
|
| 341 |
+
"QuoraRetrieval": "Given a question, retrieve questions that are semantically equivalent to the given question",
|
| 342 |
+
"SCIDOCS": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper",
|
| 343 |
+
"SciFact": "Given a scientific claim, retrieve documents that support or refute the claim",
|
| 344 |
+
"Touche2020": "Given a question, retrieve detailed and persuasive arguments that answer the question",
|
| 345 |
+
"TRECCOVID": "Given a query on COVID-19, retrieve documents that answer the query",
|
| 346 |
# C-MTEB eval instructions
|
| 347 |
+
"T2Retrieval": "Given a Chinese search query, retrieve web passages that answer the question",
|
| 348 |
+
"MMarcoRetrieval": "Given a web search query, retrieve relevant passages that answer the query",
|
| 349 |
+
"DuRetrieval": "Given a Chinese search query, retrieve web passages that answer the question",
|
| 350 |
+
"CovidRetrieval": "Given a question on COVID-19, retrieve news articles that answer the question",
|
| 351 |
+
"CmedqaRetrieval": "Given a Chinese community medical question, retrieve replies that best answer the question",
|
| 352 |
+
"EcomRetrieval": "Given a user query from an e-commerce website, retrieve description sentences of relevant products",
|
| 353 |
+
"MedicalRetrieval": "Given a medical question, retrieve user replies that best answer the question",
|
| 354 |
+
"VideoRetrieval": "Given a video search query, retrieve the titles of relevant videos",
|
| 355 |
# MTEB-fr eval instructions
|
| 356 |
"AlloprofRetrieval": "Given a question, retrieve passages that answer the question",
|
| 357 |
"BSARDRetrieval": "Given a question, retrieve passages that answer the question",
|
| 358 |
"SyntecRetrieval": "Given a question, retrieve passages that answer the question",
|
| 359 |
"XPQARetrieval": "Given a question, retrieve passages that answer the question",
|
| 360 |
"MintakaRetrieval": "Given a question, retrieve passages that answer the question",
|
| 361 |
+
# MTEB-pl eval instructions
|
| 362 |
"ArguAna-PL": "Given a claim, find documents that refute the claim",
|
| 363 |
"DBPedia-PL": "Given a query, retrieve relevant entity descriptions from DBPedia",
|
| 364 |
"FiQA-PL": "Given a financial question, retrieve user replies that best answer the question",
|
|
|
|
| 369 |
"Quora-PL": "Given a question, retrieve questions that are semantically equivalent to the given question",
|
| 370 |
"SCIDOCS-PL": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper",
|
| 371 |
"SciFact-PL": "Given a scientific claim, retrieve documents that support or refute the claim",
|
| 372 |
+
"TRECCOVID-PL": "Given a query on COVID-19, retrieve documents that answer the query",
|
| 373 |
}
|
| 374 |
|
| 375 |
# add lower case keys to match some beir names
|
| 376 |
task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()})
|
| 377 |
# other cases where lower case match still doesn't work
|
| 378 |
+
task_name_to_instruct["trec-covid"] = task_name_to_instruct["TRECCOVID"]
|
| 379 |
+
task_name_to_instruct["climate-fever"] = task_name_to_instruct["ClimateFEVER"]
|
| 380 |
+
task_name_to_instruct["dbpedia-entity"] = task_name_to_instruct["DBPedia"]
|
| 381 |
+
task_name_to_instruct["webis-touche2020"] = task_name_to_instruct["Touche2020"]
|
| 382 |
+
task_name_to_instruct["fiqa"] = task_name_to_instruct["FiQA2018"]
|
| 383 |
+
task_name_to_instruct["quora"] = task_name_to_instruct["QuoraRetrieval"]
|
| 384 |
|
| 385 |
# for miracl evaluation
|
| 386 |
+
task_name_to_instruct["miracl"] = (
|
| 387 |
+
"Given a question, retrieve Wikipedia passages that answer the question"
|
| 388 |
+
)
|
| 389 |
|
| 390 |
return task_name_to_instruct[task_name]
|
| 391 |
+
logging.warning(
|
| 392 |
+
f"No instruction config for task {task_name} with type {task_type}, use default instruction."
|
| 393 |
+
)
|
| 394 |
+
return default_instruct
|
| 395 |
+
|
| 396 |
|
| 397 |
class Encoder(torch.nn.Module):
|
| 398 |
+
def __init__(self, name_or_path: str, pooling: str):
|
| 399 |
super().__init__()
|
| 400 |
self.model = AutoModel.from_pretrained(name_or_path, trust_remote_code=True)
|
| 401 |
self.model = self.model.half()
|
| 402 |
+
self.model.eval()
|
| 403 |
self.pooling = pooling
|
| 404 |
|
| 405 |
def forward(self, **features) -> torch.Tensor:
|
| 406 |
output = self.model(**features, output_hidden_states=True, return_dict=True)
|
| 407 |
+
hidden_state = output.hidden_states[-1]
|
| 408 |
embeddings = self.pooler(hidden_state, **features)
|
| 409 |
return embeddings
|
| 410 |
|
| 411 |
def pooler(
|
| 412 |
+
self, hidden_state: torch.Tensor, attention_mask: torch.Tensor, **kwargs
|
|
|
|
|
|
|
|
|
|
| 413 |
) -> torch.Tensor:
|
| 414 |
if attention_mask.ndim == 2:
|
| 415 |
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size())
|
|
|
|
| 420 |
|
| 421 |
hidden_state = hidden_state * mask_expanded
|
| 422 |
|
| 423 |
+
if self.pooling == "first":
|
| 424 |
pooled_output = hidden_state[:, 0]
|
| 425 |
|
| 426 |
+
elif self.pooling == "last":
|
| 427 |
+
left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
|
| 428 |
if left_padding:
|
| 429 |
return hidden_state[:, -1]
|
| 430 |
else:
|
| 431 |
sequence_lengths = attention_mask.sum(dim=1) - 1
|
| 432 |
batch_size = hidden_state.shape[0]
|
| 433 |
+
return hidden_state[
|
| 434 |
+
torch.arange(batch_size, device=hidden_state.device), sequence_lengths
|
| 435 |
+
]
|
| 436 |
+
elif self.pooling == "mean":
|
| 437 |
# TODO: weight
|
| 438 |
lengths = mask_expanded.sum(1).clamp(min=1e-9)
|
| 439 |
pooled_output = hidden_state.sum(dim=1) / lengths
|
| 440 |
|
| 441 |
+
elif self.pooling == "weightedmean":
|
| 442 |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
|
| 443 |
# hidden_state shape: bs, seq, hidden_dim
|
| 444 |
weights = (
|
| 445 |
+
torch.arange(start=1, end=hidden_state.shape[1] + 1)
|
| 446 |
+
.unsqueeze(0)
|
| 447 |
+
.unsqueeze(-1)
|
| 448 |
+
.expand(hidden_state.size())
|
| 449 |
+
.float()
|
| 450 |
+
.to(hidden_state.device)
|
| 451 |
+
)
|
| 452 |
assert weights.shape == hidden_state.shape == input_mask_expanded.shape
|
| 453 |
input_mask_expanded = input_mask_expanded * weights
|
| 454 |
|
|
|
|
| 474 |
force_default: bool = False,
|
| 475 |
sep: str = " ",
|
| 476 |
mp_tensor_to_cuda: bool = False,
|
| 477 |
+
instruction: Optional[str] = None,
|
|
|
|
| 478 |
):
|
| 479 |
self.tokenizer = tokenizer
|
| 480 |
self.model = encoder
|
| 481 |
self.batch_size = batch_size
|
| 482 |
self.max_seq_len = max_seq_len
|
| 483 |
+
self.pool: Optional[dict] = None
|
| 484 |
self.normalize_embeddings = normalize_embeddings
|
| 485 |
self.mp_tensor_to_cuda = mp_tensor_to_cuda
|
| 486 |
self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 487 |
self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
|
| 488 |
self.instruction = instruction
|
| 489 |
+
self.default_query = default_query
|
| 490 |
self.sep = sep
|
| 491 |
self.force_default = force_default
|
| 492 |
+
if self.tokenizer.padding_side != "right":
|
| 493 |
+
logger.warning(
|
| 494 |
+
f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right"
|
| 495 |
+
)
|
| 496 |
+
self.tokenizer.padding_side = "right"
|
| 497 |
if self.tokenizer.pad_token is None:
|
| 498 |
logger.warning(f"Set tokenizer.pad_token as eos_token {self.tokenizer.eos_token}")
|
| 499 |
+
self.tokenizer.pad_token = "<|endoftext|>"
|
| 500 |
|
| 501 |
def start(self, target_devices: Optional[List[str]] = None):
|
| 502 |
"""
|
|
|
|
| 509 |
"""
|
| 510 |
if target_devices is None:
|
| 511 |
if torch.cuda.is_available():
|
| 512 |
+
target_devices = ["cuda:{}".format(i) for i in range(torch.cuda.device_count())]
|
| 513 |
else:
|
| 514 |
logger.info("CUDA is not available. Start 4 CPU worker")
|
| 515 |
+
target_devices = ["cpu"] * 4
|
| 516 |
|
| 517 |
+
logger.info(
|
| 518 |
+
"Start multi-process pool on devices: {}".format(", ".join(map(str, target_devices)))
|
| 519 |
+
)
|
| 520 |
+
print("multi instruction", self.instruction)
|
| 521 |
+
ctx = mp.get_context("spawn")
|
| 522 |
input_queue = ctx.Queue()
|
| 523 |
output_queue = ctx.Queue()
|
| 524 |
processes = []
|
|
|
|
| 527 |
p = ctx.Process(
|
| 528 |
target=self._encode_multi_process_worker,
|
| 529 |
args=(cuda_id, self, input_queue, output_queue),
|
| 530 |
+
daemon=True,
|
| 531 |
)
|
| 532 |
p.start()
|
| 533 |
processes.append(p)
|
| 534 |
|
| 535 |
+
self.pool = {"input": input_queue, "output": output_queue, "processes": processes}
|
| 536 |
|
| 537 |
def stop(self):
|
| 538 |
"""
|
| 539 |
Stops all processes started with start_multi_process_pool
|
| 540 |
"""
|
| 541 |
+
for p in self.pool["processes"]:
|
| 542 |
p.terminate()
|
| 543 |
|
| 544 |
+
for p in self.pool["processes"]:
|
| 545 |
p.join()
|
| 546 |
p.close()
|
| 547 |
|
| 548 |
+
self.pool["input"].close()
|
| 549 |
+
self.pool["output"].close()
|
| 550 |
|
| 551 |
@staticmethod
|
| 552 |
def _encode_multi_process_worker(target_device: str, model, input_queue, results_queue):
|
|
|
|
| 562 |
except queue.Empty:
|
| 563 |
break
|
| 564 |
|
| 565 |
+
def encode_multi_process(self, sentences: List[str], **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
"""
|
| 567 |
This method allows to run encode() on multiple GPUs. The sentences are chunked into smaller packages
|
| 568 |
and sent to individual processes, which encode these on the different GPUs. This method is only suitable
|
|
|
|
| 577 |
part_size = math.ceil(len(sentences) / len(self.pool["processes"]))
|
| 578 |
chunk_size = part_size if part_size < 3200 else 3200 # for retrieval chunk 50000
|
| 579 |
|
| 580 |
+
logger.debug(
|
| 581 |
+
f"Chunk data into {math.ceil(len(sentences) / chunk_size)} packages of size {chunk_size}"
|
| 582 |
+
)
|
| 583 |
|
| 584 |
+
input_queue = self.pool["input"]
|
| 585 |
last_chunk_id = 0
|
| 586 |
chunk = []
|
| 587 |
|
|
|
|
| 596 |
input_queue.put([last_chunk_id, chunk, kwargs])
|
| 597 |
last_chunk_id += 1
|
| 598 |
|
| 599 |
+
output_queue = self.pool["output"]
|
| 600 |
+
results_list = sorted(
|
| 601 |
+
[output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0]
|
| 602 |
+
)
|
| 603 |
embeddings = np.concatenate([result[1] for result in results_list])
|
| 604 |
return embeddings
|
| 605 |
|
|
|
|
| 620 |
(representing several text inputs to the model).
|
| 621 |
"""
|
| 622 |
|
| 623 |
+
if isinstance(text, dict): # {key: value} case
|
| 624 |
return len(next(iter(text.values())))
|
| 625 |
+
elif not hasattr(text, "__len__"): # Object has no len() method
|
| 626 |
return 1
|
| 627 |
+
elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints
|
| 628 |
return len(text)
|
| 629 |
else:
|
| 630 |
+
return sum([len(t) for t in text]) # Sum of length of individual strings
|
| 631 |
|
| 632 |
def _tokenize(self, sentences: List[str], is_query: bool):
|
| 633 |
+
batch_dict = self.tokenizer(
|
| 634 |
+
sentences,
|
| 635 |
+
max_length=self.max_seq_len - 1,
|
| 636 |
+
return_attention_mask=False,
|
| 637 |
+
padding=False,
|
| 638 |
+
truncation=True,
|
| 639 |
+
)
|
| 640 |
+
batch_dict["input_ids"] = [
|
| 641 |
+
input_ids + [self.tokenizer.eos_token_id] for input_ids in batch_dict["input_ids"]
|
| 642 |
+
]
|
| 643 |
+
batch_dict = self.tokenizer.pad(
|
| 644 |
+
batch_dict, padding=True, return_attention_mask=True, return_tensors="pt"
|
| 645 |
+
)
|
| 646 |
+
batch_dict["is_causal"] = False
|
| 647 |
return batch_dict
|
| 648 |
|
|
|
|
| 649 |
def _encode(
|
| 650 |
self,
|
| 651 |
sentences: List[str],
|
| 652 |
is_query: bool,
|
| 653 |
convert_to_numpy: bool = True,
|
| 654 |
convert_to_tensor: bool = False,
|
| 655 |
+
device: Optional[str] = None,
|
| 656 |
show_progress_bar: bool = True,
|
| 657 |
+
**kwargs,
|
| 658 |
):
|
| 659 |
"""
|
| 660 |
Computes sentence embeddings
|
|
|
|
| 677 |
convert_to_numpy = False
|
| 678 |
|
| 679 |
input_was_string = False
|
| 680 |
+
if isinstance(sentences, str) or not hasattr(
|
| 681 |
+
sentences, "__len__"
|
| 682 |
+
): # Cast an individual sentence to a list with length 1
|
| 683 |
sentences = [sentences]
|
| 684 |
input_was_string = True
|
| 685 |
|
|
|
|
| 692 |
length_sorted_idx = np.argsort([-self._text_length(s) for s in sentences])
|
| 693 |
sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
|
| 694 |
|
| 695 |
+
for start_index in trange(
|
| 696 |
+
0, len(sentences), self.batch_size, desc="Batches", disable=not show_progress_bar
|
| 697 |
+
):
|
| 698 |
+
sentences_batch = sentences_sorted[start_index : start_index + self.batch_size]
|
| 699 |
features = self._tokenize(sentences_batch, is_query)
|
| 700 |
features = self.batch_to_device(features, device)
|
| 701 |
|
|
|
|
| 716 |
if convert_to_tensor:
|
| 717 |
all_embeddings = torch.stack(all_embeddings)
|
| 718 |
elif convert_to_numpy:
|
| 719 |
+
# all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
|
| 720 |
all_embeddings = np.asarray([emb.to(torch.float).numpy() for emb in all_embeddings])
|
| 721 |
if input_was_string:
|
| 722 |
all_embeddings = all_embeddings[0]
|
|
|
|
| 728 |
sentences: List[str],
|
| 729 |
is_query: Optional[bool] = None,
|
| 730 |
convert_to_tensor: bool = False,
|
| 731 |
+
**kwargs,
|
| 732 |
):
|
| 733 |
is_query = self.default_query if is_query is None else is_query
|
| 734 |
if is_query and self.instruction:
|
| 735 |
+
sentences = [self.instruction + sent for sent in sentences]
|
| 736 |
kwargs.update(is_query=is_query)
|
| 737 |
if self.pool is not None:
|
| 738 |
kwargs.update(show_progress_bar=False)
|
|
|
|
| 740 |
if convert_to_tensor:
|
| 741 |
embeddings = torch.from_numpy(embeddings)
|
| 742 |
if self.mp_tensor_to_cuda and torch.cuda.is_available():
|
| 743 |
+
embeddings = embeddings.to(torch.device("cuda")) # default 0-th gpu
|
| 744 |
return embeddings
|
| 745 |
|
| 746 |
return self._encode(sentences, convert_to_tensor=convert_to_tensor, **kwargs)
|
|
|
|
| 760 |
]
|
| 761 |
elif isinstance(corpus[0], dict):
|
| 762 |
sentences = [
|
| 763 |
+
(doc["title"] + self.sep + doc["text"]).strip()
|
| 764 |
+
if "title" in doc
|
| 765 |
+
else doc["text"].strip()
|
| 766 |
for doc in corpus
|
| 767 |
]
|
| 768 |
else:
|
|
|
|
| 770 |
is_query = self.default_query if self.force_default else False
|
| 771 |
return self.encode(sentences, is_query=is_query, **kwargs)
|
| 772 |
|
| 773 |
+
|
| 774 |
def main(args):
|
| 775 |
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
|
| 776 |
encoder = Encoder(args.model, args.pooling)
|
| 777 |
+
default_query = args.default_type == "query"
|
| 778 |
model = Wrapper(
|
| 779 |
+
tokenizer,
|
| 780 |
+
encoder,
|
| 781 |
batch_size=args.batch_size,
|
| 782 |
max_seq_len=args.max_seq_len,
|
| 783 |
normalize_embeddings=args.norm,
|
| 784 |
+
default_query=default_query,
|
| 785 |
)
|
| 786 |
+
sym_retrievals = ["QuoraRetrieval", "ArguAna", "CQADupstack"]
|
| 787 |
+
if args.task == "mteb":
|
| 788 |
task_names = MTEB_TASK_LIST
|
| 789 |
+
lang = ["en"]
|
| 790 |
+
elif args.task == "cmteb":
|
| 791 |
task_names = CMTEB_TASK_LIST
|
| 792 |
+
lang = ["zh", "zh-CN"]
|
| 793 |
+
elif args.task == "mteb-fr":
|
| 794 |
+
task_names = MTEB_FR
|
| 795 |
+
lang = ["fr"]
|
| 796 |
+
elif args.task == "mteb-pl":
|
| 797 |
+
task_names = MTEB_PL
|
| 798 |
+
lang = ["pl"]
|
| 799 |
else:
|
| 800 |
task_names = [args.task]
|
| 801 |
+
lang = ["en", "zh", "zh-CN", "pl", "fr"]
|
| 802 |
for task in task_names:
|
| 803 |
evaluation = MTEB(tasks=[task], task_langs=lang)
|
| 804 |
task_cls = evaluation.tasks[0]
|
| 805 |
+
task_name: str = task_cls.metadata_dict["name"]
|
| 806 |
+
task_type: str = task_cls.metadata_dict["type"]
|
| 807 |
instruction = get_task_def_by_task_name_and_type(task_name, task_type)
|
| 808 |
model.instruction = get_detailed_instruct(instruction)
|
| 809 |
+
if task == "MSMARCO":
|
| 810 |
eval_splits = ["dev"]
|
| 811 |
elif task in CMTEB_TASK_LIST:
|
| 812 |
+
eval_splits = task_cls.metadata_dict["eval_splits"]
|
| 813 |
else:
|
| 814 |
eval_splits = ["test"]
|
| 815 |
sym = False
|
|
|
|
| 820 |
else:
|
| 821 |
sym = False
|
| 822 |
if sym:
|
| 823 |
+
logger.info(
|
| 824 |
+
f"Switch to symmetric mode for {task}, all as {'query' if default_query else 'doc'}."
|
| 825 |
+
)
|
| 826 |
model.force_default = True
|
| 827 |
evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
|
| 828 |
|
| 829 |
if sym:
|
| 830 |
logger.info(f"Switch back.")
|
| 831 |
model.force_default = force_default_ori
|
| 832 |
+
print("\n")
|
| 833 |
|
| 834 |
|
| 835 |
if __name__ == "__main__":
|
| 836 |
_PARSER = argparse.ArgumentParser()
|
| 837 |
+
_PARSER.add_argument("-m", "--model", type=str, default=None)
|
| 838 |
+
_PARSER.add_argument("--pooling", type=str, default="last")
|
|
|
|
|
|
|
| 839 |
_PARSER.add_argument("--output_dir", type=str, default=None)
|
| 840 |
+
_PARSER.add_argument("--default_type", type=str, default="query")
|
| 841 |
_PARSER.add_argument("--max_seq_len", type=int, default=512)
|
| 842 |
_PARSER.add_argument("-b", "--batch_size", type=int, default=32)
|
| 843 |
_PARSER.add_argument(
|
| 844 |
+
"-t",
|
| 845 |
+
"--task",
|
| 846 |
+
type=str,
|
| 847 |
+
default=None, # None for running default tasks
|
| 848 |
)
|
| 849 |
_PARSER.add_argument("--norm", action="store_true")
|
| 850 |
_ARGS = _PARSER.parse_args()
|