Roman Solomatin committed
Commit 1d0f0ea · unverified · 1 Parent(s): dc62085

update after review

Files changed (1):
  1. listconranker.py +101 -4
listconranker.py CHANGED
@@ -30,6 +30,9 @@ from transformers import (
 import os
 from transformers.modeling_outputs import SequenceClassifierOutput
 from typing import Union, List, Optional
+from collections import defaultdict
+import numpy as np
+import math
 
 
 class ListConRankerConfig(BertConfig):
@@ -295,14 +298,15 @@ class ListConRankerModel(PreTrainedModel):
             if sep_idxs.numel() == 0:
                 raise ValueError(f"No SEP in sequence {idx}")
             first_sep = sep_idxs[0].item()
+            second_sep = sep_idxs[1].item()
 
             # Extract query and passage
             q_seq = seq[: first_sep + 1]
             q_mask = mask[: first_sep + 1]
             q_tt = torch.zeros_like(q_seq)
 
-            p_seq = seq[first_sep:]
-            p_mask = mask[first_sep:]
+            p_seq = seq[first_sep : second_sep + 1]
+            p_mask = mask[first_sep : second_sep + 1]
             p_seq = p_seq.clone()
             p_seq[0] = self.config.cls_token_id
             p_tt = torch.zeros_like(p_seq)
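For illustration (not part of the commit): the removed lines sliced the passage open-endedly from the first [SEP], so in a padded batch everything after the passage's closing [SEP] was carried along; the bounded slice stops at the newly located second [SEP], and the slice's leading [SEP] is overwritten with [CLS] so the passage re-encodes as a standalone sequence. A minimal sketch of the split with toy token ids (101 = [CLS], 102 = [SEP]):

import torch

seq = torch.tensor([101, 7, 8, 102, 21, 22, 23, 102])  # [CLS] q q [SEP] p p p [SEP]
sep_idxs = (seq == 102).nonzero(as_tuple=True)[0]
first_sep, second_sep = sep_idxs[0].item(), sep_idxs[1].item()

q_seq = seq[: first_sep + 1]                     # [CLS] q q [SEP]
p_seq = seq[first_sep : second_sep + 1].clone()  # [SEP] p p p [SEP]
p_seq[0] = 101                                   # leading [SEP] becomes [CLS]

print(q_seq.tolist())  # [101, 7, 8, 102]
print(p_seq.tolist())  # [101, 21, 22, 23, 102]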
@@ -315,6 +319,16 @@ class ListConRankerModel(PreTrainedModel):
                 ].tolist()
             )
 
+            # truncation
+            q_seq = q_seq[: self.config.max_position_embeddings]
+            q_seq[-1] = self.config.sep_token_id
+            p_seq = p_seq[: self.config.max_position_embeddings]
+            p_seq[-1] = self.config.sep_token_id
+            q_mask = q_mask[: self.config.max_position_embeddings]
+            p_mask = p_mask[: self.config.max_position_embeddings]
+            q_tt = q_tt[: self.config.max_position_embeddings]
+            p_tt = p_tt[: self.config.max_position_embeddings]
+
             if key not in grouped:
                 grouped[key] = {
                     "query": (q_seq, q_mask, q_tt),
@@ -396,7 +410,7 @@ class ListConRankerModel(PreTrainedModel):
     ):
         model = super().from_pretrained(model_name_or_path, config=config, **kwargs)
         model.hf_model = BertModel.from_pretrained(
-            model_name_or_path, config=model.config.bert_config
+            model_name_or_path, config=model.config.bert_config, **kwargs
        )
 
         linear_path = os.path.join(model_name_or_path, "linear_in_embedding.pt")
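For illustration (not part of the commit): the truncation block added above (@@ -315,6 +319,16 @@) clips each segment to max_position_embeddings and overwrites its last token so a clipped segment still ends with [SEP]; when nothing was clipped the overwrite is a no-op, since both segments already end with [SEP]. It pairs with the truncation=False change in the next hunk: the tokenizer no longer truncates the joined pair, so length control happens per segment here. A toy sketch of the pattern (max length 6, 102 = [SEP]):

import torch

max_len, sep_id = 6, 102
q_seq = torch.tensor([101, 5, 6, 7, 8, 9, 10, 102])  # 8 tokens: too long

q_seq = q_seq[:max_len]  # keep the first max_len tokens
q_seq[-1] = sep_id       # restore the trailing [SEP]

print(q_seq.tolist())  # [101, 5, 6, 7, 8, 102]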
@@ -439,11 +453,94 @@ class ListConRankerModel(PreTrainedModel):
             inputs = tokenizer(
                 batch_pairs,
                 padding=True,
-                truncation=True,
+                truncation=False,
                 return_tensors="pt",
             )
+
+            for k, v in inputs.items():
+                inputs[k] = v.to(self.device)
+
             logits = self(**inputs)[0]
             total_logits[batch * batch_size : (batch + 1) * batch_size] = (
                 logits.squeeze(1)
             )
         return total_logits
+
+    def multi_passage_in_iterative_inference(
+        self,
+        sentences: List[str],
+        stop_num: int = 20,
+        decrement_rate: float = 0.2,
+        min_filter_num: int = 10,
+        tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(
+            "ByteDance/ListConRanker"
+        ),
+    ):
+        """
+        Process multiple passages for one query in iterative inference.
+        :param sentences: List contains sentences for a query.
+        :return: Tensor of logits for each passage.
+        """
+        if stop_num < 1:
+            raise ValueError("stop_num must be greater than 0")
+        if decrement_rate <= 0 or decrement_rate >= 1:
+            raise ValueError("decrement_rate must be in (0, 1)")
+        if min_filter_num < 1:
+            raise ValueError("min_filter_num must be greater than 0")
+
+        query = sentences[0]
+        passage = sentences[1:]
+
+        filter_times = 0
+        passage2score = defaultdict(list)
+        while len(passage) > stop_num:
+            batch = [[query] + passage]
+            pred_scores = self.multi_passage(
+                batch, batch_size=len(batch[0]) - 1, tokenizer=tokenizer
+            ).tolist()
+            pred_scores_argsort = np.argsort(
+                pred_scores
+            ).tolist()  # Sort in increasing order
+
+            passage_len = len(passage)
+            to_filter_num = math.ceil(passage_len * decrement_rate)
+            if to_filter_num < min_filter_num:
+                to_filter_num = min_filter_num
+
+            have_filter_num = 0
+            while have_filter_num < to_filter_num:
+                idx = pred_scores_argsort[have_filter_num]
+                passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
+                have_filter_num += 1
+            while (
+                pred_scores[pred_scores_argsort[have_filter_num - 1]]
+                == pred_scores[pred_scores_argsort[have_filter_num]]
+            ):
+                idx = pred_scores_argsort[have_filter_num]
+                passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
+                have_filter_num += 1
+            next_passage = []
+            next_passage_idx = have_filter_num
+            while next_passage_idx < len(passage):
+                idx = pred_scores_argsort[next_passage_idx]
+                next_passage.append(passage[idx])
+                next_passage_idx += 1
+            passage = next_passage
+            filter_times += 1
+
+        batch = [[query] + passage]
+        pred_scores = self.multi_passage(
+            batch, batch_size=len(batch[0]) - 1, tokenizer=tokenizer
+        ).tolist()
+
+        cnt = 0
+        while cnt < len(passage):
+            passage2score[passage[cnt]].append(pred_scores[cnt] + filter_times)
+            cnt += 1
+
+        passage = sentences[1:]
+        final_score = []
+        for i in range(len(passage)):
+            p = passage[i]
+            final_score.append(passage2score[p][0])
+        return final_score
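How the new method works: each round scores the surviving passages, drops the lowest-scoring ceil(decrement_rate * len(passage)) of them (at least min_filter_num, plus any ties at the cut boundary), and records each dropped passage's score plus the current filter_times, so passages surviving more rounds carry a larger offset and outrank earlier drop-outs; the loop stops once at most stop_num passages remain, and the survivors are scored one last time. A hypothetical usage sketch (the example sentences are invented; ListConRankerModel is assumed importable from this file):

from transformers import AutoTokenizer
from listconranker import ListConRankerModel  # hypothetical import path

tokenizer = AutoTokenizer.from_pretrained("ByteDance/ListConRanker")
model = ListConRankerModel.from_pretrained("ByteDance/ListConRanker")

sentences = [
    "what is the capital of france?",   # sentences[0]: the query
    "paris is the capital of france.",  # sentences[1:]: passages to rank
    "berlin is the capital of germany.",
    "the eiffel tower is in paris.",
]
scores = model.multi_passage_in_iterative_inference(sentences, tokenizer=tokenizer)
print(scores)  # one score per passage, in the order of sentences[1:]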