Roman Solomatin committed on
Commit 8c3030b · unverified · 1 Parent(s): ea396ea

fix shapes

Files changed (1)
  1. listconranker.py +9 -31
listconranker.py CHANGED
@@ -39,12 +39,11 @@ from typing import Union, List, Optional
 class ListConRankerConfig(PretrainedConfig):
     """Configuration class for ListConRanker model."""
 
-    model_type = "listconranker"
+    model_type = "ListConRanker"
 
     def __init__(
         self,
         list_transformer_layers: int = 2,
-        num_attention_heads: int = 8,
         hidden_size: int = 1792,
         base_hidden_size: int = 1024,
         num_labels: int = 1,
@@ -52,12 +51,12 @@ class ListConRankerConfig(PretrainedConfig):
     ):
         super().__init__(**kwargs)
         self.list_transformer_layers = list_transformer_layers
-        self.num_attention_heads = num_attention_heads
         self.hidden_size = hidden_size
         self.base_hidden_size = base_hidden_size
         self.num_labels = num_labels
 
         self.bert_config = BertConfig(**kwargs)
+        self.bert_config.hidden_size = self.base_hidden_size
         self.bert_config.output_hidden_states = True
 
 class QueryEmbedding(nn.Module):
@@ -85,7 +84,8 @@ class ListTransformer(nn.Module):
         self.linear_score2 = nn.Linear(config.hidden_size * 2, config.hidden_size)
         self.linear_score1 = nn.Linear(config.hidden_size * 2, 1)
 
-    def forward(self, pair_features, pair_nums):
+    def forward(self, pair_features: torch.Tensor):
+        pair_nums = pair_features.size(0)
         pair_nums = [x + 1 for x in pair_nums]
         batch_pair_features = pair_features.split(pair_nums)
 
@@ -154,7 +154,7 @@ class ListConRankerModel(PreTrainedModel):
         super().__init__(config)
         self.config = config
         self.num_labels = config.num_labels
-        self.hf_model = BertModel(config)
+        self.hf_model = BertModel(config.bert_config)
 
         self.sigmoid = nn.Sigmoid()
 
@@ -176,17 +176,8 @@ class ListConRankerModel(PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        pair_num: Optional[torch.Tensor] = None,
         **kwargs
     ) -> Union[SequenceClassifierOutput, tuple]:
-        # Handle pair_num parameter
-        if pair_num is not None:
-            pair_nums = pair_num.tolist()
-        else:
-            # Default behavior if pair_num is not provided
-            batch_size = input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
-            pair_nums = [1] * batch_size
-
         # Get device
         device = input_ids.device if input_ids is not None else inputs_embeds.device
         self.list_transformer.device = device
@@ -195,20 +186,7 @@ class ListConRankerModel(PreTrainedModel):
         if self.training:
             pass
         else:
-            split_batch = 400
-            if sum(pair_nums) > split_batch:
-                last_hidden_state_list = []
-                input_ids_list = input_ids.split(split_batch)
-                attention_mask_list = attention_mask.split(split_batch)
-                for i in range(len(input_ids_list)):
-                    last_hidden_state = self.hf_model(
-                        input_ids=input_ids_list[i],
-                        attention_mask=attention_mask_list[i],
-                        return_dict=True).hidden_states[-1]
-                    last_hidden_state_list.append(last_hidden_state)
-                last_hidden_state = torch.cat(last_hidden_state_list, dim=0)
-            else:
-                ranker_out = self.hf_model(
+            ranker_out = self.hf_model(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids,
@@ -217,12 +195,12 @@ class ListConRankerModel(PreTrainedModel):
                 inputs_embeds=inputs_embeds,
                 output_attentions=output_attentions,
                 return_dict=True)
-                last_hidden_state = ranker_out.last_hidden_state
+            last_hidden_state = ranker_out.last_hidden_state
 
         pair_features = self.average_pooling(last_hidden_state, attention_mask)
         pair_features = self.linear_in_embedding(pair_features)
 
-        logits, pair_features_after_list_transformer = self.list_transformer(pair_features, pair_nums)
+        logits, pair_features_after_list_transformer = self.list_transformer(pair_features)
         logits = self.sigmoid(logits)
 
         return logits
@@ -249,4 +227,4 @@ class ListConRankerModel(PreTrainedModel):
         except FileNotFoundError:
             print(f"Warning: Could not load custom weights from {model_name_or_path}")
 
-        return model
+        return model
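
The shape relationship this commit establishes can be sketched as follows. This is a minimal, standalone illustration using only the defaults visible in the diff (base_hidden_size=1024, hidden_size=1792); the role of linear_in_embedding as the projection between the two widths is an assumption, since that layer's definition is not part of these hunks.

    from transformers import BertConfig

    # Defaults as shown in the diff above.
    base_hidden_size = 1024   # width of the wrapped BertModel
    hidden_size = 1792        # width of the list transformer and scoring heads

    bert_config = BertConfig()
    bert_config.hidden_size = base_hidden_size   # mirrors the line this commit adds
    bert_config.output_hidden_states = True

    # BertModel(bert_config) now emits hidden states of shape
    # [num_pairs, seq_len, base_hidden_size]; after average pooling, the
    # (assumed) linear_in_embedding projection lifts the pooled features from
    # base_hidden_size to hidden_size before linear_score2 / linear_score1,
    # which operate on hidden_size.
    assert bert_config.hidden_size == base_hidden_size  # 1024, not 1792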