Integrate with Transformers & SentenceTransformers

#3
by Samoed - opened
Files changed (4)
  1. README.md +50 -7
  2. config.json +9 -2
  3. listconranker.py +546 -0
  4. tokenizer_config.json +1 -1
README.md CHANGED
@@ -1,6 +1,9 @@
 ---
 tags:
 - mteb
+- sentence-transformers
+- transformers
+pipeline_tag: text-ranking
 model-index:
 - name: ListConRanker
   results:
@@ -103,9 +106,10 @@ To reduce the discrepancy between training and inference, we propose iterative i
 
 ## How to use
 ```python
-from modules.listconranker import ListConRanker
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
-reranker = ListConRanker('./', use_fp16=True, list_transformer_layer=2)
+reranker = AutoModelForSequenceClassification.from_pretrained('ByteDance/ListConRanker', trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained('ByteDance/ListConRanker')
 
 # [query, passages_1, passage_2, ..., passage_n]
 batch = [
@@ -125,15 +129,55 @@ batch = [
 ]
 
 # for conventional inference, please manage the batch size by yourself
-scores = reranker.compute_score(batch)
+scores = reranker.multi_passage(batch)
 print(scores)
 # [[0.5126953125, 0.331298828125, 0.3642578125], [0.63671875, 0.71630859375, 0.42822265625, 0.35302734375]]
 
-# for iterative inferfence, only a batch size of 1 is supported
-# the scores do not indicate similarity but are intended only for ranking
-scores = reranker.iterative_inference(batch[0])
+inputs = tokenizer(
+    [
+        [
+            "query 1",
+            "passage_11",
+        ],
+        [
+            "query 1",
+            "passage_12",
+        ],
+        [
+            "query_2",
+            "passage_21",
+        ],
+    ],
+    return_tensors="pt",
+    padding=True,
+)
+probs, logits = reranker(**inputs)
+print(probs)
+# tensor([[0.4359], [0.3840]], grad_fn=<ViewBackward0>)
+```
+or using the `sentence_transformers` library:
+```python
+from sentence_transformers import CrossEncoder
+
+model = CrossEncoder('ByteDance/ListConRanker', trust_remote_code=True)
+
+inputs = [
+    [
+        "query 1",
+        "passage_11",
+    ],
+    [
+        "query 1",
+        "passage_12",
+    ],
+    [
+        "query_2",
+        "passage_21",
+    ],
+]
+scores = model.predict(inputs)
 print(scores)
-# [0.5126953125, 0.331298828125, 0.3642578125]
+# [0.43585014, 0.32305932, 0.38395187]
 ```
 
 To reproduce the results with iterative inference, please run:
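As a possible follow-up to the CrossEncoder example above, the pairwise scores can also be turned directly into a ranking. A minimal sketch, assuming the installed sentence-transformers version provides the `CrossEncoder.rank` helper (names mirror the README example):

```python
from sentence_transformers import CrossEncoder

model = CrossEncoder('ByteDance/ListConRanker', trust_remote_code=True)

# rank() scores each (query, passage) pair via predict() and sorts by score
results = model.rank("query 1", ["passage_11", "passage_12"], return_documents=True)
for hit in results:
    print(hit["corpus_id"], hit["score"], hit["text"])
```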
config.json CHANGED
@@ -1,7 +1,11 @@
 {
   "architectures": [
-    "BertModel"
+    "ListConRanker"
   ],
+  "auto_map": {
+    "AutoConfig": "listconranker.ListConRankerConfig",
+    "AutoModelForSequenceClassification": "listconranker.ListConRankerModel"
+  },
   "attention_probs_dropout_prob": 0.1,
   "classifier_dropout": null,
   "directionality": "bidi",
@@ -9,6 +13,7 @@
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
   "hidden_size": 1024,
+  "list_con_hidden_size": 1792,
   "id2label": {
     "0": "LABEL_0"
   },
@@ -34,5 +39,7 @@
   "transformers_version": "4.45.2",
   "type_vocab_size": 2,
   "use_cache": true,
-  "vocab_size": 21128
+  "vocab_size": 21128,
+  "cls_token_id": 101,
+  "sep_token_id": 102
 }
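The `auto_map` entries added above are what route the generic `Auto*` loaders to the custom classes shipped in `listconranker.py`. A minimal sketch of how this resolves (assumes `trust_remote_code=True` is passed; the printed class names are the expected mapping, not captured output):

```python
from transformers import AutoConfig, AutoModelForSequenceClassification

# trust_remote_code=True lets transformers import listconranker.py from the repo
config = AutoConfig.from_pretrained("ByteDance/ListConRanker", trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    "ByteDance/ListConRanker", trust_remote_code=True
)

print(type(config).__name__)  # ListConRankerConfig (expected)
print(type(model).__name__)   # ListConRankerModel (expected)
```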
listconranker.py ADDED
```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
# and associated documentation files (the "Software"), to deal in the Software without
# restriction, including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
# OTHER DEALINGS IN THE SOFTWARE.
from __future__ import annotations

import torch
from torch import nn
from torch.nn import functional as F
from transformers import (
    PreTrainedModel,
    BertModel,
    BertConfig,
    AutoTokenizer,
)
import os
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Union, List, Optional
from collections import defaultdict
import numpy as np
import math


class ListConRankerConfig(BertConfig):
    """Configuration class for ListConRanker model."""

    model_type = "ListConRanker"

    def __init__(
        self,
        list_transformer_layers: int = 2,
        list_con_hidden_size: int = 1792,
        num_labels: int = 1,
        cls_token_id: int = 101,
        sep_token_id: int = 102,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.list_transformer_layers = list_transformer_layers
        self.list_con_hidden_size = list_con_hidden_size
        self.num_labels = num_labels
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id

        self.bert_config = BertConfig(**kwargs)
        self.bert_config.output_hidden_states = True


class QueryEmbedding(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.query_embedding = nn.Embedding(2, config.list_con_hidden_size)
        self.layerNorm = nn.LayerNorm(config.list_con_hidden_size)

    def forward(self, x, tags):
        query_embeddings = self.query_embedding(tags)
        x += query_embeddings
        x = self.layerNorm(x)
        return x


class ListTransformer(nn.Module):
    def __init__(self, num_layer, config) -> None:
        super().__init__()
        self.config = config
        self.list_transformer_layer = nn.TransformerEncoderLayer(
            config.list_con_hidden_size,
            self.config.num_attention_heads,
            batch_first=True,
            activation=F.gelu,
            norm_first=False,
        )
        self.list_transformer = nn.TransformerEncoder(
            self.list_transformer_layer, num_layer
        )
        self.relu = nn.ReLU()
        self.query_embedding = QueryEmbedding(config)

        self.linear_score3 = nn.Linear(
            config.list_con_hidden_size * 2, config.list_con_hidden_size
        )
        self.linear_score2 = nn.Linear(
            config.list_con_hidden_size * 2, config.list_con_hidden_size
        )
        self.linear_score1 = nn.Linear(config.list_con_hidden_size * 2, 1)

    def forward(
        self, pair_features: torch.Tensor, pair_nums: List[int]
    ) -> torch.Tensor:
        batch_pair_features = pair_features.split(pair_nums)

        pair_feature_query_passage_concat_list = []
        for i in range(len(batch_pair_features)):
            pair_feature_query = (
                batch_pair_features[i][0].unsqueeze(0).repeat(pair_nums[i] - 1, 1)
            )
            pair_feature_passage = batch_pair_features[i][1:]
            pair_feature_query_passage_concat_list.append(
                torch.cat([pair_feature_query, pair_feature_passage], dim=1)
            )
        pair_feature_query_passage_concat = torch.cat(
            pair_feature_query_passage_concat_list, dim=0
        )

        batch_pair_features = nn.utils.rnn.pad_sequence(
            batch_pair_features, batch_first=True
        )

        query_embedding_tags = torch.zeros(
            batch_pair_features.size(0),
            batch_pair_features.size(1),
            dtype=torch.long,
            device=self.device,
        )
        query_embedding_tags[:, 0] = 1
        batch_pair_features = self.query_embedding(
            batch_pair_features, query_embedding_tags
        )

        mask = self.generate_attention_mask(pair_nums)
        query_mask = self.generate_attention_mask_custom(pair_nums)
        pair_list_features = self.list_transformer(
            batch_pair_features, src_key_padding_mask=mask, mask=query_mask
        )

        output_pair_list_features = []
        output_query_list_features = []
        pair_features_after_transformer_list = []
        for idx, pair_num in enumerate(pair_nums):
            output_pair_list_features.append(pair_list_features[idx, 1:pair_num, :])
            output_query_list_features.append(pair_list_features[idx, 0, :])
            pair_features_after_transformer_list.append(
                pair_list_features[idx, :pair_num, :]
            )

        pair_features_after_transformer_cat_query_list = []
        for idx, pair_num in enumerate(pair_nums):
            query_ft = (
                output_query_list_features[idx].unsqueeze(0).repeat(pair_num - 1, 1)
            )
            pair_features_after_transformer_cat_query = torch.cat(
                [query_ft, output_pair_list_features[idx]], dim=1
            )
            pair_features_after_transformer_cat_query_list.append(
                pair_features_after_transformer_cat_query
            )
        pair_features_after_transformer_cat_query = torch.cat(
            pair_features_after_transformer_cat_query_list, dim=0
        )

        pair_feature_query_passage_concat = self.relu(
            self.linear_score2(pair_feature_query_passage_concat)
        )
        pair_features_after_transformer_cat_query = self.relu(
            self.linear_score3(pair_features_after_transformer_cat_query)
        )
        final_ft = torch.cat(
            [
                pair_feature_query_passage_concat,
                pair_features_after_transformer_cat_query,
            ],
            dim=1,
        )
        logits = self.linear_score1(final_ft).squeeze()
        return logits, torch.cat(pair_features_after_transformer_list, dim=0)

    def generate_attention_mask(self, pair_num):
        max_len = max(pair_num)
        batch_size = len(pair_num)
        mask = torch.zeros(batch_size, max_len, dtype=torch.bool, device=self.device)
        for i, length in enumerate(pair_num):
            mask[i, length:] = True
        return mask

    def generate_attention_mask_custom(self, pair_num):
        max_len = max(pair_num)
        mask = torch.zeros(max_len, max_len, dtype=torch.bool, device=self.device)
        mask[0, 1:] = True
        return mask


class ListConRankerModel(PreTrainedModel):
    """
    ListConRanker model for sequence classification that's compatible with AutoModelForSequenceClassification.
    """

    config_class = ListConRankerConfig
    base_model_prefix = "listconranker"

    def __init__(self, config: ListConRankerConfig):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.hf_model = BertModel(config.bert_config)

        self.sigmoid = nn.Sigmoid()

        self.linear_in_embedding = nn.Linear(
            config.hidden_size, config.list_con_hidden_size
        )
        self.list_transformer = ListTransformer(
            config.list_transformer_layers,
            config,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        if self.training:
            raise NotImplementedError("Training not supported; use eval mode.")
        device = input_ids.device
        self.list_transformer.device = device
        # Reorganize by unique queries and their passages
        (
            reorganized_input_ids,
            reorganized_attention_mask,
            reorganized_token_type_ids,
            pair_nums,
            group_indices,
        ) = self._reorganize_inputs(input_ids, attention_mask, token_type_ids)

        out = self.hf_model(
            input_ids=reorganized_input_ids,
            attention_mask=reorganized_attention_mask,
            token_type_ids=reorganized_token_type_ids,
            return_dict=True,
        )
        feats = out.last_hidden_state
        pooled = self.average_pooling(feats, reorganized_attention_mask)
        embedded = self.linear_in_embedding(pooled)
        logits, _ = self.list_transformer(embedded, pair_nums)
        probs = self.sigmoid(logits)

        # Restore original order
        sorted_probs = self._restore_original_order(probs, group_indices)
        sorted_logits = self._restore_original_order(logits, group_indices)
        if not return_dict:
            return (sorted_probs, sorted_logits)

        return SequenceClassifierOutput(
            loss=None,
            logits=sorted_logits,
            hidden_states=out.hidden_states,
            attentions=out.attentions,
        )

    def _reorganize_inputs(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ) -> tuple[
        torch.Tensor, torch.Tensor, Optional[torch.Tensor], List[int], List[List[int]]
    ]:
        """
        Group inputs by unique queries: for each query, produce [query] + its passages,
        then flatten, pad, and return pair sizes and original indices mapping.
        """
        batch_size = input_ids.size(0)
        # Structure: query_key -> {
        #     'query': (seq, mask, tt),
        #     'passages': [(seq, mask, tt), ...],
        #     'indices': [original_index, ...]
        # }
        grouped = {}

        for idx in range(batch_size):
            seq = input_ids[idx]
            mask = attention_mask[idx]

            sep_idxs = (seq == self.config.sep_token_id).nonzero(as_tuple=True)[0]
            if sep_idxs.numel() == 0:
                raise ValueError(f"No SEP in sequence {idx}")
            first_sep = sep_idxs[0].item()
            second_sep = sep_idxs[1].item()

            # Extract query and passage
            q_seq = seq[: first_sep + 1]
            q_mask = mask[: first_sep + 1]
            q_tt = torch.zeros_like(q_seq)

            p_seq = seq[first_sep : second_sep + 1]
            p_mask = mask[first_sep : second_sep + 1]
            p_seq = p_seq.clone()
            p_seq[0] = self.config.cls_token_id
            p_tt = torch.zeros_like(p_seq)

            # Build key excluding CLS/SEP
            key = tuple(
                q_seq[
                    (q_seq != self.config.cls_token_id)
                    & (q_seq != self.config.sep_token_id)
                ].tolist()
            )

            # truncation
            q_seq = q_seq[: self.config.max_position_embeddings]
            q_seq[-1] = self.config.sep_token_id
            p_seq = p_seq[: self.config.max_position_embeddings]
            p_seq[-1] = self.config.sep_token_id
            q_mask = q_mask[: self.config.max_position_embeddings]
            p_mask = p_mask[: self.config.max_position_embeddings]
            q_tt = q_tt[: self.config.max_position_embeddings]
            p_tt = p_tt[: self.config.max_position_embeddings]

            if key not in grouped:
                grouped[key] = {
                    "query": (q_seq, q_mask, q_tt),
                    "passages": [],
                    "indices": [],
                }
            grouped[key]["passages"].append((p_seq, p_mask, p_tt))
            grouped[key]["indices"].append(idx)

        # Flatten according to group insertion order
        seqs, masks, tts, pair_nums, group_indices = [], [], [], [], []
        for key, data in grouped.items():
            q_seq, q_mask, q_tt = data["query"]
            passages = data["passages"]
            indices = data["indices"]
            # record sizes and original positions
            pair_nums.append(len(passages) + 1)  # +1 for the query
            group_indices.append(indices)

            # append query then its passages
            seqs.append(q_seq)
            masks.append(q_mask)
            tts.append(q_tt)
            for p_seq, p_mask, p_tt in passages:
                seqs.append(p_seq)
                masks.append(p_mask)
                tts.append(p_tt)

        # Pad to uniform length
        max_len = max(s.size(0) for s in seqs)
        padded_seqs, padded_masks, padded_tts = [], [], []
        for s, m, t in zip(seqs, masks, tts):
            ps = torch.zeros(max_len, dtype=s.dtype, device=s.device)
            pm = torch.zeros(max_len, dtype=m.dtype, device=m.device)
            pt = torch.zeros(max_len, dtype=t.dtype, device=t.device)
            ps[: s.size(0)] = s
            pm[: m.size(0)] = m
            pt[: t.size(0)] = t
            padded_seqs.append(ps)
            padded_masks.append(pm)
            padded_tts.append(pt)

        rid = torch.stack(padded_seqs)
        ram = torch.stack(padded_masks)
        rtt = torch.stack(padded_tts) if token_type_ids is not None else None

        return rid, ram, rtt, pair_nums, group_indices

    def _restore_original_order(
        self,
        logits: torch.Tensor,
        group_indices: List[List[int]],
    ) -> torch.Tensor:
        """
        Map flattened logits back so each original index gets its passage score.
        """
        out = torch.zeros(logits.size(0), dtype=logits.dtype, device=logits.device)
        i = 0
        for indices in group_indices:
            for idx in indices:
                out[idx] = logits[i]
                i += 1
        return out.reshape(-1, 1)

    def average_pooling(self, hidden_state, attention_mask):
        extended_attention_mask = (
            attention_mask.unsqueeze(-1)
            .expand(hidden_state.size())
            .to(dtype=hidden_state.dtype)
        )
        masked_hidden_state = hidden_state * extended_attention_mask
        sum_embeddings = torch.sum(masked_hidden_state, dim=1)
        sum_mask = extended_attention_mask.sum(dim=1)
        return sum_embeddings / sum_mask

    @classmethod
    def from_pretrained(
        cls, model_name_or_path, config: Optional[ListConRankerConfig] = None, **kwargs
    ):
        model = super().from_pretrained(model_name_or_path, config=config, **kwargs)
        model.hf_model = BertModel.from_pretrained(
            model_name_or_path, config=model.config.bert_config, **kwargs
        )

        linear_path = os.path.join(model_name_or_path, "linear_in_embedding.pt")
        transformer_path = os.path.join(model_name_or_path, "list_transformer.pt")

        try:
            model.linear_in_embedding.load_state_dict(torch.load(linear_path))
            model.list_transformer.load_state_dict(torch.load(transformer_path))
        except FileNotFoundError as e:
            raise e

        return model

    def multi_passage(
        self,
        sentences: List[List[str]],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(
            "ByteDance/ListConRanker"
        ),
    ):
        """
        Process multiple passages for each query.
        :param sentences: List of lists, where each inner list contains sentences for a query.
        :return: Tensor of logits for each passage.
        """
        pairs = []
        for batch in sentences:
            if len(batch) < 2:
                raise ValueError("Each query must have at least one passage.")
            query = batch[0]
            passages = batch[1:]
            for passage in passages:
                pairs.append((query, passage))

        total_batches = (len(pairs) + batch_size - 1) // batch_size
        total_logits = torch.zeros(len(pairs), dtype=torch.float, device=self.device)
        for batch in range(total_batches):
            batch_pairs = pairs[batch * batch_size : (batch + 1) * batch_size]
            inputs = tokenizer(
                batch_pairs,
                padding=True,
                truncation=False,
                return_tensors="pt",
            )

            for k, v in inputs.items():
                inputs[k] = v.to(self.device)

            logits = self(**inputs)[0]
            total_logits[batch * batch_size : (batch + 1) * batch_size] = (
                logits.squeeze(1)
            )
        return total_logits

    def multi_passage_in_iterative_inference(
        self,
        sentences: List[str],
        stop_num: int = 20,
        decrement_rate: float = 0.2,
        min_filter_num: int = 10,
        tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(
            "ByteDance/ListConRanker"
        ),
    ):
        """
        Process multiple passages for one query in iterative inference.
        :param sentences: List contains sentences for a query.
        :return: Tensor of logits for each passage.
        """
        if stop_num < 1:
            raise ValueError("stop_num must be greater than 0")
        if decrement_rate <= 0 or decrement_rate >= 1:
            raise ValueError("decrement_rate must be in (0, 1)")
        if min_filter_num < 1:
            raise ValueError("min_filter_num must be greater than 0")

        query = sentences[0]
        passage = sentences[1:]

        filter_times = 0
        passage2score = defaultdict(list)
        while len(passage) > stop_num:
            batch = [[query] + passage]
            pred_scores = self.multi_passage(
                batch, batch_size=len(batch[0]) - 1, tokenizer=tokenizer
            ).tolist()
            pred_scores_argsort = np.argsort(
                pred_scores
            ).tolist()  # Sort in increasing order

            passage_len = len(passage)
            to_filter_num = math.ceil(passage_len * decrement_rate)
            if to_filter_num < min_filter_num:
                to_filter_num = min_filter_num

            have_filter_num = 0
            while have_filter_num < to_filter_num:
                idx = pred_scores_argsort[have_filter_num]
                passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
                have_filter_num += 1
            while (
                pred_scores[pred_scores_argsort[have_filter_num - 1]]
                == pred_scores[pred_scores_argsort[have_filter_num]]
            ):
                idx = pred_scores_argsort[have_filter_num]
                passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
                have_filter_num += 1
            next_passage = []
            next_passage_idx = have_filter_num
            while next_passage_idx < len(passage):
                idx = pred_scores_argsort[next_passage_idx]
                next_passage.append(passage[idx])
                next_passage_idx += 1
            passage = next_passage
            filter_times += 1

        batch = [[query] + passage]
        pred_scores = self.multi_passage(
            batch, batch_size=len(batch[0]) - 1, tokenizer=tokenizer
        ).tolist()

        cnt = 0
        while cnt < len(passage):
            passage2score[passage[cnt]].append(pred_scores[cnt] + filter_times)
            cnt += 1

        passage = sentences[1:]
        final_score = []
        for i in range(len(passage)):
            p = passage[i]
            final_score.append(passage2score[p][0])
        return final_score
```
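The iterative-inference path from the original README maps onto `multi_passage_in_iterative_inference` above. A minimal usage sketch (query plus passages follow the README's format; the parameter values shown are simply the signature's defaults):

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained(
    "ByteDance/ListConRanker", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained("ByteDance/ListConRanker")

# one query followed by its candidate passages
sentences = ["query 1", "passage_11", "passage_12", "passage_21"]

# repeatedly drops the lowest-scoring passages until at most stop_num remain;
# the returned values are ranking scores, not calibrated similarities
scores = model.multi_passage_in_iterative_inference(
    sentences,
    stop_num=20,
    decrement_rate=0.2,
    min_filter_num=10,
    tokenizer=tokenizer,
)
print(scores)  # one score per passage, in the original passage order
```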
tokenizer_config.json CHANGED
@@ -47,7 +47,7 @@
   "do_lower_case": true,
   "mask_token": "[MASK]",
   "max_length": 512,
-  "model_max_length": 1000000000000000019884624838656,
+  "model_max_length": 512,
   "never_split": null,
   "pad_to_multiple_of": null,
   "pad_token": "[PAD]",