Roman Solomatin committed · update after review

listconranker.py · +101 −4 · CHANGED
@@ -30,6 +30,9 @@ from transformers import (
 import os
 from transformers.modeling_outputs import SequenceClassifierOutput
 from typing import Union, List, Optional
+from collections import defaultdict
+import numpy as np
+import math


 class ListConRankerConfig(BertConfig):
@@ -295,14 +298,15 @@ class ListConRankerModel(PreTrainedModel):
             if sep_idxs.numel() == 0:
                 raise ValueError(f"No SEP in sequence {idx}")
             first_sep = sep_idxs[0].item()
+            second_sep = sep_idxs[1].item()

             # Extract query and passage
             q_seq = seq[: first_sep + 1]
             q_mask = mask[: first_sep + 1]
             q_tt = torch.zeros_like(q_seq)

-            p_seq = seq[first_sep:]
-            p_mask = mask[first_sep:]
+            p_seq = seq[first_sep : second_sep + 1]
+            p_mask = mask[first_sep : second_sep + 1]
             p_seq = p_seq.clone()
             p_seq[0] = self.config.cls_token_id
             p_tt = torch.zeros_like(p_seq)
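Note: each pair is encoded BERT-style as [CLS] query [SEP] passage [SEP] [PAD]…, so the old slice `seq[first_sep:]` dragged the padding tail into the passage segment; cutting at the second SEP keeps only the passage. The leading SEP of that slice is then overwritten with CLS so the passage reads as a standalone sequence. A minimal sketch of the slicing on toy token ids (101/102 standing in for the CLS/SEP ids; all values hypothetical):

    import torch

    seq = torch.tensor([101, 7, 8, 102, 21, 22, 23, 102, 0, 0])  # [CLS] q q [SEP] p p p [SEP] <pad> <pad>
    sep_idxs = (seq == 102).nonzero(as_tuple=True)[0]
    first_sep, second_sep = sep_idxs[0].item(), sep_idxs[1].item()

    q_seq = seq[: first_sep + 1]                    # [CLS] q q [SEP]
    p_seq = seq[first_sep : second_sep + 1].clone()
    p_seq[0] = 101                                  # the SEP slot becomes the passage's CLS
    # p_seq -> [CLS] p p p [SEP]; the padding tail is no longer included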
@@ -315,6 +319,16 @@ class ListConRankerModel(PreTrainedModel):
                 ].tolist()
             )

+            # truncation
+            q_seq = q_seq[: self.config.max_position_embeddings]
+            q_seq[-1] = self.config.sep_token_id
+            p_seq = p_seq[: self.config.max_position_embeddings]
+            p_seq[-1] = self.config.sep_token_id
+            q_mask = q_mask[: self.config.max_position_embeddings]
+            p_mask = p_mask[: self.config.max_position_embeddings]
+            q_tt = q_tt[: self.config.max_position_embeddings]
+            p_tt = p_tt[: self.config.max_position_embeddings]
+
             if key not in grouped:
                 grouped[key] = {
                     "query": (q_seq, q_mask, q_tt),
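Note: with the tokenizer no longer truncating (see `truncation=False` in the last hunk), over-long inputs reach this point at full length, so query and passage are each clipped to `max_position_embeddings` by hand and the final token is forced back to `sep_token_id`, so a clipped segment still ends in SEP (for segments that were already short enough, the last token is already SEP and the write is a no-op). A sketch of the pattern, with a hypothetical maximum length of 6:

    import torch

    max_len, sep_id = 6, 102       # stand-ins for config.max_position_embeddings / config.sep_token_id
    q_seq = torch.tensor([101, 7, 8, 9, 10, 11, 12, 102])  # 8 tokens, too long

    q_seq = q_seq[:max_len]        # tensor([101, 7, 8, 9, 10, 11])
    q_seq[-1] = sep_id             # tensor([101, 7, 8, 9, 10, 102])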
@@ -396,7 +410,7 @@ class ListConRankerModel(PreTrainedModel):
     ):
         model = super().from_pretrained(model_name_or_path, config=config, **kwargs)
         model.hf_model = BertModel.from_pretrained(
-            model_name_or_path, config=model.config.bert_config
+            model_name_or_path, config=model.config.bert_config, **kwargs
         )

         linear_path = os.path.join(model_name_or_path, "linear_in_embedding.pt")
@@ -439,11 +453,94 @@ class ListConRankerModel(PreTrainedModel):
             inputs = tokenizer(
                 batch_pairs,
                 padding=True,
-                truncation=True,
+                truncation=False,
                 return_tensors="pt",
             )
+
+            for k, v in inputs.items():
+                inputs[k] = v.to(self.device)
+
             logits = self(**inputs)[0]
             total_logits[batch * batch_size : (batch + 1) * batch_size] = (
                 logits.squeeze(1)
             )
         return total_logits
+
+    def multi_passage_in_iterative_inference(
+        self,
+        sentences: List[str],
+        stop_num: int = 20,
+        decrement_rate: float = 0.2,
+        min_filter_num: int = 10,
+        tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(
+            "ByteDance/ListConRanker"
+        ),
+    ):
+        """
+        Process multiple passages for one query in iterative inference.
+        :param sentences: List containing the query followed by its passages.
+        :return: List of final scores, one per passage.
+        """
+        if stop_num < 1:
+            raise ValueError("stop_num must be greater than 0")
+        if decrement_rate <= 0 or decrement_rate >= 1:
+            raise ValueError("decrement_rate must be in (0, 1)")
+        if min_filter_num < 1:
+            raise ValueError("min_filter_num must be greater than 0")
+
+        query = sentences[0]
+        passage = sentences[1:]
+
+        filter_times = 0
+        passage2score = defaultdict(list)
+        while len(passage) > stop_num:
+            batch = [[query] + passage]
+            pred_scores = self.multi_passage(
+                batch, batch_size=len(batch[0]) - 1, tokenizer=tokenizer
+            ).tolist()
+            pred_scores_argsort = np.argsort(
+                pred_scores
+            ).tolist()  # Sort in increasing order
+
+            passage_len = len(passage)
+            to_filter_num = math.ceil(passage_len * decrement_rate)
+            if to_filter_num < min_filter_num:
+                to_filter_num = min_filter_num
+
+            have_filter_num = 0
+            while have_filter_num < to_filter_num:
+                idx = pred_scores_argsort[have_filter_num]
+                passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
+                have_filter_num += 1
+            while (
+                pred_scores[pred_scores_argsort[have_filter_num - 1]]
+                == pred_scores[pred_scores_argsort[have_filter_num]]
+            ):
+                idx = pred_scores_argsort[have_filter_num]
+                passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
+                have_filter_num += 1
+            next_passage = []
+            next_passage_idx = have_filter_num
+            while next_passage_idx < len(passage):
+                idx = pred_scores_argsort[next_passage_idx]
+                next_passage.append(passage[idx])
+                next_passage_idx += 1
+            passage = next_passage
+            filter_times += 1
+
+        batch = [[query] + passage]
+        pred_scores = self.multi_passage(
+            batch, batch_size=len(batch[0]) - 1, tokenizer=tokenizer
+        ).tolist()
+
+        cnt = 0
+        while cnt < len(passage):
+            passage2score[passage[cnt]].append(pred_scores[cnt] + filter_times)
+            cnt += 1
+
+        passage = sentences[1:]
+        final_score = []
+        for i in range(len(passage)):
+            p = passage[i]
+            final_score.append(passage2score[p][0])
+        return final_score
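Note: `multi_passage_in_iterative_inference` implements an iterative rerank-and-filter loop. While more than `stop_num` passages remain, it scores every passage against the query, drops the lowest-scoring `ceil(len(passage) * decrement_rate)` of them (at least `min_filter_num`, extended across ties at the cut-off), and records each dropped passage's score plus the current round counter `filter_times`; because later rounds add a larger offset, passages that survive longer always outrank earlier casualties, and the survivors of the final round are scored once more with the largest offset. One caveat: the `tokenizer` default is evaluated when the class is defined, which triggers a download of `ByteDance/ListConRanker` at import time, so passing a tokenizer explicitly avoids that. A hypothetical usage sketch (model loading and the sentences are illustrative, not part of the diff):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("ByteDance/ListConRanker")
    sentences = [
        "what does ListConRanker do?",            # the query comes first
        "ListConRanker is a listwise reranker.",  # followed by candidate passages
        "Unrelated passage about cooking.",
        "It reranks passages with a listwise contrastive objective.",
    ]
    scores = model.multi_passage_in_iterative_inference(sentences, tokenizer=tokenizer)
    # one score per passage, returned in the original passage order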