Roman Solomatin committed
Commit 5c7e478 · unverified · 1 parent: 2852d45

start implementing

Files changed (2)
  1. config.json +7 -2
  2. listconranker.py +252 -0
config.json CHANGED
@@ -1,14 +1,19 @@
 {
-  "architectures": [
+  "listconranker": [
     "BertModel"
   ],
+  "auto_map": {
+    "AutoConfig": "listconranker.ListConRankerConfig",
+    "AutoModelForSequenceClassification": "listconranker.ListConRankerModel"
+  },
   "attention_probs_dropout_prob": 0.1,
   "classifier_dropout": null,
   "directionality": "bidi",
   "gradient_checkpointing": false,
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
-  "hidden_size": 1024,
+  "hidden_size": 1792,
+  "base_hidden_size": 1024,
   "id2label": {
     "0": "LABEL_0"
   },
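The new auto_map entries register the custom classes with the transformers Auto* loaders, so the checkpoint can be used without importing listconranker.py by hand. A minimal loading sketch, assuming the repository is published on the Hub (the repo id below is a placeholder, not part of this commit):

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

repo_id = "<namespace>/<repo>"  # placeholder: the id of this repository on the Hub
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)  # resolves to listconranker.ListConRankerConfig
model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True)  # resolves to listconranker.ListConRankerModel
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model.eval()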
listconranker.py ADDED
@@ -0,0 +1,252 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
+# and associated documentation files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or
+# substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
+# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# OTHER DEALINGS IN THE SOFTWARE.
+
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+import numpy as np
+from transformers import (
+    AutoTokenizer,
+    is_torch_npu_available,
+    AutoModel,
+    PreTrainedModel,
+    PretrainedConfig,
+    AutoConfig,
+    BertModel,
+    BertConfig
+)
+from transformers.modeling_outputs import SequenceClassifierOutput
+from typing import Union, List, Optional
+
+
+class ListConRankerConfig(PretrainedConfig):
+    """Configuration class for ListConRanker model."""
+
+    model_type = "listconranker"
+
+    def __init__(
+        self,
+        list_transformer_layers: int = 2,
+        num_attention_heads: int = 8,
+        hidden_size: int = 1792,
+        base_hidden_size: int = 1024,
+        num_labels: int = 1,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.list_transformer_layers = list_transformer_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_size = hidden_size
+        self.base_hidden_size = base_hidden_size
+        self.num_labels = num_labels
+
+        self.bert_config = BertConfig(**kwargs)
+        self.bert_config.output_hidden_states = True
+
+class QueryEmbedding(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.query_embedding = nn.Embedding(2, config.hidden_size)
+        self.layerNorm = nn.LayerNorm(config.hidden_size)
+
+    def forward(self, x, tags):
+        query_embeddings = self.query_embedding(tags)
+        x += query_embeddings
+        x = self.layerNorm(x)
+        return x
+
+class ListTransformer(nn.Module):
+    def __init__(self, num_layer, config) -> None:
+        super().__init__()
+        self.config = config
+        self.list_transformer_layer = nn.TransformerEncoderLayer(1792, self.config.num_attention_heads, batch_first=True, activation=F.gelu, norm_first=False)
+        self.list_transformer = nn.TransformerEncoder(self.list_transformer_layer, num_layer)
+        self.relu = nn.ReLU()
+        self.query_embedding = QueryEmbedding(config)
+
+        self.linear_score3 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+        self.linear_score2 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+        self.linear_score1 = nn.Linear(config.hidden_size * 2, 1)
+
+    def forward(self, pair_features, pair_nums):
+        pair_nums = [x + 1 for x in pair_nums]
+        batch_pair_features = pair_features.split(pair_nums)
+
+        pair_feature_query_passage_concat_list = []
+        for i in range(len(batch_pair_features)):
+            pair_feature_query = batch_pair_features[i][0].unsqueeze(0).repeat(pair_nums[i] - 1, 1)
+            pair_feature_passage = batch_pair_features[i][1:]
+            pair_feature_query_passage_concat_list.append(torch.cat([pair_feature_query, pair_feature_passage], dim=1))
+        pair_feature_query_passage_concat = torch.cat(pair_feature_query_passage_concat_list, dim=0)
+
+        batch_pair_features = nn.utils.rnn.pad_sequence(batch_pair_features, batch_first=True)
+
+        query_embedding_tags = torch.zeros(batch_pair_features.size(0), batch_pair_features.size(1), dtype=torch.long, device=self.device)
+        query_embedding_tags[:, 0] = 1
+        batch_pair_features = self.query_embedding(batch_pair_features, query_embedding_tags)
+
+        mask = self.generate_attention_mask(pair_nums)
+        query_mask = self.generate_attention_mask_custom(pair_nums)
+        pair_list_features = self.list_transformer(batch_pair_features, src_key_padding_mask=mask, mask=query_mask)
+
+        output_pair_list_features = []
+        output_query_list_features = []
+        pair_features_after_transformer_list = []
+        for idx, pair_num in enumerate(pair_nums):
+            output_pair_list_features.append(pair_list_features[idx, 1:pair_num, :])
+            output_query_list_features.append(pair_list_features[idx, 0, :])
+            pair_features_after_transformer_list.append(pair_list_features[idx, :pair_num, :])
+
+        pair_features_after_transformer_cat_query_list = []
+        for idx, pair_num in enumerate(pair_nums):
+            query_ft = output_query_list_features[idx].unsqueeze(0).repeat(pair_num - 1, 1)
+            pair_features_after_transformer_cat_query = torch.cat([query_ft, output_pair_list_features[idx]], dim=1)
+            pair_features_after_transformer_cat_query_list.append(pair_features_after_transformer_cat_query)
+        pair_features_after_transformer_cat_query = torch.cat(pair_features_after_transformer_cat_query_list, dim=0)
+
+        pair_feature_query_passage_concat = self.relu(self.linear_score2(pair_feature_query_passage_concat))
+        pair_features_after_transformer_cat_query = self.relu(self.linear_score3(pair_features_after_transformer_cat_query))
+        final_ft = torch.cat([pair_feature_query_passage_concat, pair_features_after_transformer_cat_query], dim=1)
+        logits = self.linear_score1(final_ft).squeeze()
+
+        return logits, torch.cat(pair_features_after_transformer_list, dim=0)
+
+    def generate_attention_mask(self, pair_num):
+        max_len = max(pair_num)
+        batch_size = len(pair_num)
+        mask = torch.zeros(batch_size, max_len, dtype=torch.bool, device=self.device)
+        for i, length in enumerate(pair_num):
+            mask[i, length:] = True
+        return mask
+
+    def generate_attention_mask_custom(self, pair_num):
+        max_len = max(pair_num)
+        mask = torch.zeros(max_len, max_len, dtype=torch.bool, device=self.device)
+        mask[0, 1:] = True
+        return mask
+
+
+class ListConRankerModel(PreTrainedModel):
+    """
+    ListConRanker model for sequence classification that's compatible with AutoModelForSequenceClassification.
+    """
+    config_class = ListConRankerConfig
+    base_model_prefix = "listconranker"
+
+    def __init__(self, config: ListConRankerConfig):
+        super().__init__(config)
+        self.config = config
+        self.num_labels = config.num_labels
+        self.hf_model = BertModel(config)
+
+        self.sigmoid = nn.Sigmoid()
+
+        self.linear_in_embedding = nn.Linear(config.base_hidden_size, config.hidden_size)
+        self.list_transformer = ListTransformer(
+            config.list_transformer_layers,
+            config,
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        pair_num: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> Union[SequenceClassifierOutput, tuple]:
+        # Handle pair_num parameter
+        if pair_num is not None:
+            pair_nums = pair_num.tolist()
+        else:
+            # Default behavior if pair_num is not provided
+            batch_size = input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
+            pair_nums = [1] * batch_size
+
+        # Get device
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        self.list_transformer.device = device
+
+        # Forward through base model
+        if self.training:
+            pass
+        else:
+            split_batch = 400
+            if sum(pair_nums) > split_batch:
+                last_hidden_state_list = []
+                input_ids_list = input_ids.split(split_batch)
+                attention_mask_list = attention_mask.split(split_batch)
+                for i in range(len(input_ids_list)):
+                    last_hidden_state = self.hf_model(
+                        input_ids=input_ids_list[i],
+                        attention_mask=attention_mask_list[i],
+                        return_dict=True).hidden_states[-1]
+                    last_hidden_state_list.append(last_hidden_state)
+                last_hidden_state = torch.cat(last_hidden_state_list, dim=0)
+            else:
+                ranker_out = self.hf_model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    token_type_ids=token_type_ids,
+                    position_ids=position_ids,
+                    head_mask=head_mask,
+                    inputs_embeds=inputs_embeds,
+                    output_attentions=output_attentions,
+                    return_dict=True)
+                last_hidden_state = ranker_out.last_hidden_state
+
+        pair_features = self.average_pooling(last_hidden_state, attention_mask)
+        pair_features = self.linear_in_embedding(pair_features)
+
+        logits, pair_features_after_list_transformer = self.list_transformer(pair_features, pair_nums)
+        logits = self.sigmoid(logits)
+
+        return logits
+
+    def average_pooling(self, hidden_state, attention_mask):
+        extended_attention_mask = attention_mask.unsqueeze(-1).expand(hidden_state.size()).to(dtype=hidden_state.dtype)
+        masked_hidden_state = hidden_state * extended_attention_mask
+        sum_embeddings = torch.sum(masked_hidden_state, dim=1)
+        sum_mask = extended_attention_mask.sum(dim=1)
+        return sum_embeddings / sum_mask
+
+    @classmethod
+    def from_pretrained(cls, model_name_or_path, config: Optional[ListConRankerConfig] = None, **kwargs):
+        model = super().from_pretrained(
+            model_name_or_path, config=config, **kwargs)
+
+        # Load custom weights
+        linear_path = f"{model_name_or_path}/linear_in_embedding.pt"
+        transformer_path = f"{model_name_or_path}/list_transformer.pt"
+
+        try:
+            model.linear_in_embedding.load_state_dict(torch.load(linear_path))
+            model.list_transformer.load_state_dict(torch.load(transformer_path))
+        except FileNotFoundError:
+            print(f"Warning: Could not load custom weights from {model_name_or_path}")
+
+        return model
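ListTransformer.forward treats the first pooled feature of each group as the query and the remaining features as its passages, so the model returns one sigmoid score per passage. A hedged usage sketch of that intended interface (repo id, query, and passages are illustrative placeholders; the commit message marks this as a work in progress, so the sketch targets the interface rather than guaranteeing it runs against this exact revision):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

repo_id = "<namespace>/<repo>"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

query = "how does list-wise reranking work?"
passages = [
    "ListConRanker scores all candidate passages of a query jointly.",
    "An unrelated passage about something else entirely.",
]

# Convention implied by ListTransformer.forward: the query is encoded as the first
# sequence of its group, followed by its passages; pair_num counts passages per query.
features = tokenizer([query] + passages, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**features, pair_num=torch.tensor([len(passages)]))
print(scores)  # one relevance score in [0, 1] per passage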