Roman Solomatin committed
Commit c3e4b3b · unverified · 1 Parent(s): f7a361f

finish integration

Files changed (4):
  1. README.md +43 -7
  2. config.json +3 -1
  3. listconranker.py +201 -41
  4. tokenizer_config.json +1 -1
README.md CHANGED
@@ -1,6 +1,9 @@
 ---
 tags:
 - mteb
+- sentence-transformers
+- transformers
+pipeline_tag: text-ranking
 model-index:
 - name: ListConRanker
   results:
@@ -103,9 +106,9 @@ To reduce the discrepancy between training and inference, we propose iterative inference
 
 ## How to use
 ```python
-from modules.listconranker import ListConRanker
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
-reranker = ListConRanker('./', use_fp16=True, list_transformer_layer=2)
+reranker = AutoModelForSequenceClassification.from_pretrained('ByteDance/ListConRanker', trust_remote_code=True)
 
 # [query, passages_1, passage_2, ..., passage_n]
 batch = [
@@ -125,15 +128,48 @@ batch = [
 ]
 
 # for conventional inference, please manage the batch size by yourself
-scores = reranker.compute_score(batch)
+scores = reranker.multi_passage(batch)
 print(scores)
 # [[0.5126953125, 0.331298828125, 0.3642578125], [0.63671875, 0.71630859375, 0.42822265625, 0.35302734375]]
 
-# for iterative inferfence, only a batch size of 1 is supported
-# the scores do not indicate similarity but are intended only for ranking
-scores = reranker.iterative_inference(batch[0])
+# or score (query, passage) pairs directly through the forward pass
+tokenizer = AutoTokenizer.from_pretrained('ByteDance/ListConRanker')
+inputs = tokenizer(
+    [
+        ["query 1", "passage_11"],
+        ["query_2", "passage_21"],
+    ],
+    return_tensors="pt",
+    padding=True,
+)
+probs, logits = reranker(**inputs)
+print(probs)
+# tensor([[0.4359], [0.3840]], grad_fn=<ViewBackward0>)
+```
+or using the `sentence_transformers` library:
+```python
+from sentence_transformers import CrossEncoder
+
+model = CrossEncoder('ByteDance/ListConRanker', trust_remote_code=True)
+
+inputs = [
+    ["query 1", "passage_11"],
+    ["query_2", "passage_21"],
+]
+scores = model.predict(inputs)
 print(scores)
-# [0.5126953125, 0.331298828125, 0.3642578125]
+# [0.4359, 0.3840]
 ```
 
 To reproduce the results with iterative inference, please run:
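Note on usage: both paths above produce one score per (query, passage) pair, so ranking a query's candidates is just a sort over its pair scores. A minimal sketch, assuming the `CrossEncoder` path shown in the README; the `rank_passages` helper and the sample strings are illustrative, not part of this repo:

```python
from sentence_transformers import CrossEncoder

# Illustrative helper: rank one query's passages by their pair scores.
def rank_passages(model, query, passages):
    scores = model.predict([[query, p] for p in passages])  # one score per pair
    order = sorted(range(len(passages)), key=lambda i: scores[i], reverse=True)
    return [(passages[i], float(scores[i])) for i in order]

model = CrossEncoder("ByteDance/ListConRanker", trust_remote_code=True)
ranked = rank_passages(model, "query 1", ["passage_11", "passage_12", "passage_13"])
for passage, score in ranked:
    print(passage, score)
```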
config.json CHANGED
@@ -39,5 +39,7 @@
   "transformers_version": "4.45.2",
   "type_vocab_size": 2,
   "use_cache": true,
-  "vocab_size": 21128
+  "vocab_size": 21128,
+  "cls_token_id": 101,
+  "sep_token_id": 102
 }
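The two new entries back the remote-code model: `forward()` splits each encoded pair at the first `sep_token_id` and re-prefixes the passage with `cls_token_id`, so both IDs now live in the config instead of being hard-coded. A quick sanity check (illustrative; assumes the checkpoint and tokenizer load from this repo) that they match the tokenizer's special tokens:

```python
from transformers import AutoConfig, AutoTokenizer

# The IDs added to config.json should agree with the tokenizer's [CLS]/[SEP] IDs.
config = AutoConfig.from_pretrained("ByteDance/ListConRanker", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ByteDance/ListConRanker")

assert config.cls_token_id == tokenizer.cls_token_id == 101
assert config.sep_token_id == tokenizer.sep_token_id == 102
```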
listconranker.py CHANGED
@@ -16,22 +16,18 @@
 # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
 # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
 # OTHER DEALINGS IN THE SOFTWARE.
+from __future__ import annotations
 
-import math
 import torch
 from torch import nn
 from torch.nn import functional as F
-import numpy as np
 from transformers import (
-    AutoTokenizer,
-    is_torch_npu_available,
-    AutoModel,
     PreTrainedModel,
-    PretrainedConfig,
-    AutoConfig,
     BertModel,
     BertConfig,
+    AutoTokenizer,
 )
+import os
 from transformers.modeling_outputs import SequenceClassifierOutput
 from typing import Union, List, Optional
 
@@ -46,12 +42,16 @@ class ListConRankerConfig(BertConfig):
         list_transformer_layers: int = 2,
         list_con_hidden_size: int = 1792,
         num_labels: int = 1,
+        cls_token_id: int = 101,
+        sep_token_id: int = 102,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.list_transformer_layers = list_transformer_layers
         self.list_con_hidden_size = list_con_hidden_size
         self.num_labels = num_labels
+        self.cls_token_id = cls_token_id
+        self.sep_token_id = sep_token_id
 
         self.bert_config = BertConfig(**kwargs)
         self.bert_config.output_hidden_states = True
@@ -75,7 +75,7 @@ class ListTransformer(nn.Module):
         super().__init__()
         self.config = config
         self.list_transformer_layer = nn.TransformerEncoderLayer(
-            1792,
+            config.list_con_hidden_size,
             self.config.num_attention_heads,
             batch_first=True,
             activation=F.gelu,
@@ -213,11 +213,10 @@ class ListConRankerModel(PreTrainedModel):
             config.list_transformer_layers,
             config,
         )
-        self.sep_token_id = 102  # [SEP] token ID
 
     def forward(
         self,
-        input_ids: Optional[torch.Tensor] = None,
+        input_ids: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
@@ -228,36 +227,157 @@ class ListConRankerModel(PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         **kwargs,
-    ) -> Union[SequenceClassifierOutput, tuple]:
-        # Get device
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        self.list_transformer.device = device
-
-        # Forward through base model
+    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
         if self.training:
-            pass
-        else:
-            ranker_out = self.hf_model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                token_type_ids=token_type_ids,
-                position_ids=position_ids,
-                head_mask=head_mask,
-                inputs_embeds=inputs_embeds,
-                output_attentions=output_attentions,
-                return_dict=True,
-            )
-            last_hidden_state = ranker_out.last_hidden_state
-
-            pair_features = self.average_pooling(last_hidden_state, attention_mask)
-            pair_features = self.linear_in_embedding(pair_features)
+            raise NotImplementedError("Training not supported; use eval mode.")
+        device = input_ids.device
+        self.list_transformer.device = device
+        # Reorganize by unique queries and their passages
+        (
+            reorganized_input_ids,
+            reorganized_attention_mask,
+            reorganized_token_type_ids,
+            pair_nums,
+            group_indices,
+        ) = self._reorganize_inputs(input_ids, attention_mask, token_type_ids)
+
+        out = self.hf_model(
+            input_ids=reorganized_input_ids,
+            attention_mask=reorganized_attention_mask,
+            token_type_ids=reorganized_token_type_ids,
+            return_dict=True,
+        )
+        feats = out.last_hidden_state
+        pooled = self.average_pooling(feats, reorganized_attention_mask)
+        embedded = self.linear_in_embedding(pooled)
+        logits, _ = self.list_transformer(embedded, pair_nums)
+        probs = self.sigmoid(logits)
+
+        # Restore original order
+        sorted_probs = self._restore_original_order(probs, group_indices)
+        sorted_logits = self._restore_original_order(logits, group_indices)
+        if not return_dict:
+            return (sorted_probs, sorted_logits)
+
+        return SequenceClassifierOutput(
+            loss=None,
+            logits=sorted_logits,
+            hidden_states=out.hidden_states,
+            attentions=out.attentions,
+        )
 
-            logits, pair_features_after_list_transformer = self.list_transformer(
-                pair_features
+    def _reorganize_inputs(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor],
+    ) -> tuple[
+        torch.Tensor, torch.Tensor, Optional[torch.Tensor], List[int], List[List[int]]
+    ]:
+        """
+        Group inputs by unique queries: for each query, produce [query] + its passages,
+        then flatten, pad, and return pair sizes and original indices mapping.
+        """
+        batch_size = input_ids.size(0)
+        # Structure: query_key -> {
+        #     'query': (seq, mask, tt),
+        #     'passages': [(seq, mask, tt), ...],
+        #     'indices': [original_index, ...]
+        # }
+        grouped = {}
+
+        for idx in range(batch_size):
+            seq = input_ids[idx]
+            mask = attention_mask[idx]
+            token_type_ids[idx] if token_type_ids is not None else torch.zeros_like(seq)
+
+            sep_idxs = (seq == self.config.sep_token_id).nonzero(as_tuple=True)[0]
+            if sep_idxs.numel() == 0:
+                raise ValueError(f"No SEP in sequence {idx}")
+            first_sep = sep_idxs[0].item()
+
+            # Extract query and passage
+            q_seq = seq[: first_sep + 1]
+            q_mask = mask[: first_sep + 1]
+            q_tt = torch.zeros_like(q_seq)
+
+            p_seq = seq[first_sep:]
+            p_mask = mask[first_sep:]
+            p_seq = p_seq.clone()
+            p_seq[0] = self.config.cls_token_id
+            p_tt = torch.zeros_like(p_seq)
+
+            # Build key excluding CLS/SEP
+            key = tuple(
+                q_seq[
+                    (q_seq != self.config.cls_token_id)
+                    & (q_seq != self.config.sep_token_id)
+                ].tolist()
             )
-            logits = self.sigmoid(logits)
 
-            return logits
+            if key not in grouped:
+                grouped[key] = {
+                    "query": (q_seq, q_mask, q_tt),
+                    "passages": [],
+                    "indices": [],
+                }
+            grouped[key]["passages"].append((p_seq, p_mask, p_tt))
+            grouped[key]["indices"].append(idx)
+
+        # Flatten according to group insertion order
+        seqs, masks, tts, pair_nums, group_indices = [], [], [], [], []
+        for key, data in grouped.items():
+            q_seq, q_mask, q_tt = data["query"]
+            passages = data["passages"]
+            indices = data["indices"]
+            # record sizes and original positions
+            pair_nums.append(len(passages) + 1)  # +1 for the query
+            group_indices.append(indices)
+
+            # append query then its passages
+            seqs.append(q_seq)
+            masks.append(q_mask)
+            tts.append(q_tt)
+            for p_seq, p_mask, p_tt in passages:
+                seqs.append(p_seq)
+                masks.append(p_mask)
+                tts.append(p_tt)
+
+        # Pad to uniform length
+        max_len = max(s.size(0) for s in seqs)
+        padded_seqs, padded_masks, padded_tts = [], [], []
+        for s, m, t in zip(seqs, masks, tts):
+            ps = torch.zeros(max_len, dtype=s.dtype, device=s.device)
+            pm = torch.zeros(max_len, dtype=m.dtype, device=m.device)
+            pt = torch.zeros(max_len, dtype=t.dtype, device=t.device)
+            ps[: s.size(0)] = s
+            pm[: m.size(0)] = m
+            pt[: t.size(0)] = t
+            padded_seqs.append(ps)
+            padded_masks.append(pm)
+            padded_tts.append(pt)
+
+        rid = torch.stack(padded_seqs)
+        ram = torch.stack(padded_masks)
+        rtt = torch.stack(padded_tts) if token_type_ids is not None else None
+
+        return rid, ram, rtt, pair_nums, group_indices
+
+    def _restore_original_order(
+        self,
+        logits: torch.Tensor,
+        group_indices: List[List[int]],
+    ) -> torch.Tensor:
+        """
+        Map flattened logits back so each original index gets its passage score.
+        """
+        out = torch.zeros(logits.size(0), dtype=logits.dtype, device=logits.device)
+        i = 0
+        for indices in group_indices:
+            for idx in indices:
+                out[idx] = logits[i]
+                i += 1
+        return out.reshape(-1, 1)
 
     def average_pooling(self, hidden_state, attention_mask):
         extended_attention_mask = (
@@ -275,15 +395,55 @@ class ListConRankerModel(PreTrainedModel):
         cls, model_name_or_path, config: Optional[ListConRankerConfig] = None, **kwargs
     ):
         model = super().from_pretrained(model_name_or_path, config=config, **kwargs)
+        model.hf_model = BertModel.from_pretrained(
+            model_name_or_path, config=model.config.bert_config
+        )
 
-        # Load custom weights
-        linear_path = f"{model_name_or_path}/linear_in_embedding.pt"
-        transformer_path = f"{model_name_or_path}/list_transformer.pt"
+        linear_path = os.path.join(model_name_or_path, "linear_in_embedding.pt")
+        transformer_path = os.path.join(model_name_or_path, "list_transformer.pt")
 
         try:
             model.linear_in_embedding.load_state_dict(torch.load(linear_path))
             model.list_transformer.load_state_dict(torch.load(transformer_path))
-        except FileNotFoundError:
-            print(f"Warning: Could not load custom weights from {model_name_or_path}")
+        except FileNotFoundError as e:
+            raise e
 
         return model
+
+    def multi_passage(
+        self,
+        sentences: List[List[str]],
+        batch_size: int = 32,
+        tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(
+            "ByteDance/ListConRanker"
+        ),
+    ):
+        """
+        Process multiple passages for each query.
+        :param sentences: List of lists, where each inner list contains sentences for a query.
+        :return: Tensor of logits for each passage.
+        """
+        pairs = []
+        for batch in sentences:
+            if len(batch) < 2:
+                raise ValueError("Each query must have at least one passage.")
+            query = batch[0]
+            passages = batch[1:]
+            for passage in passages:
+                pairs.append((query, passage))
+
+        total_batches = (len(pairs) + batch_size - 1) // batch_size
+        total_logits = torch.zeros(len(pairs), dtype=torch.float, device=self.device)
+        for batch in range(total_batches):
+            batch_pairs = pairs[batch * batch_size : (batch + 1) * batch_size]
+            inputs = tokenizer(
+                batch_pairs,
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+            )
+            logits = self(**inputs)[0]
+            total_logits[batch * batch_size : (batch + 1) * batch_size] = (
+                logits.squeeze(1)
+            )
+        return total_logits
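The heart of the new forward pass is the `_reorganize_inputs` / `_restore_original_order` pair: pairs that share a query are grouped so the list transformer sees all of that query's passages together, and the grouped scores are then scattered back to the caller's original pair order. A simplified illustration of that bookkeeping with plain strings instead of token tensors (`group_pairs` and `restore_order` are made-up names for this sketch):

```python
# Toy version of the grouping logic: collect (query, passage) pairs by query,
# remember each pair's original position, then scatter scores back afterwards.
def group_pairs(pairs):
    grouped, group_indices = {}, []
    for idx, (query, passage) in enumerate(pairs):
        grouped.setdefault(query, {"passages": [], "indices": []})
        grouped[query]["passages"].append(passage)
        grouped[query]["indices"].append(idx)
    flat, pair_nums = [], []
    for query, data in grouped.items():              # insertion order, like the model
        pair_nums.append(len(data["passages"]) + 1)  # +1 for the query itself
        group_indices.append(data["indices"])
        flat.append(query)
        flat.extend(data["passages"])
    return flat, pair_nums, group_indices

def restore_order(scores, group_indices):
    out = [None] * sum(len(ix) for ix in group_indices)
    i = 0
    for indices in group_indices:
        for idx in indices:
            out[idx] = scores[i]
            i += 1
    return out

pairs = [("q1", "p11"), ("q2", "p21"), ("q1", "p12")]
flat, pair_nums, group_indices = group_pairs(pairs)
print(flat)       # ['q1', 'p11', 'p12', 'q2', 'p21']
print(pair_nums)  # [3, 2]
# one score per passage, produced group by group: p11, p12, p21
print(restore_order([0.9, 0.4, 0.7], group_indices))  # [0.9, 0.7, 0.4]
```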
tokenizer_config.json CHANGED
@@ -47,7 +47,7 @@
   "do_lower_case": true,
   "mask_token": "[MASK]",
   "max_length": 512,
-  "model_max_length": 1000000000000000019884624838656,
+  "model_max_length": 512,
   "never_split": null,
   "pad_to_multiple_of": null,
   "pad_token": "[PAD]",