Roman Solomatin
committed on
fix shapes
listconranker.py  CHANGED  (+9 -31)
@@ -39,12 +39,11 @@ from typing import Union, List, Optional
 class ListConRankerConfig(PretrainedConfig):
     """Configuration class for ListConRanker model."""
 
-    model_type = "
+    model_type = "ListConRanker"
 
     def __init__(
         self,
         list_transformer_layers: int = 2,
-        num_attention_heads: int = 8,
         hidden_size: int = 1792,
         base_hidden_size: int = 1024,
         num_labels: int = 1,
@@ -52,12 +51,12 @@ class ListConRankerConfig(PretrainedConfig):
     ):
         super().__init__(**kwargs)
         self.list_transformer_layers = list_transformer_layers
-        self.num_attention_heads = num_attention_heads
         self.hidden_size = hidden_size
         self.base_hidden_size = base_hidden_size
         self.num_labels = num_labels
 
         self.bert_config = BertConfig(**kwargs)
+        self.bert_config.hidden_size = self.base_hidden_size
         self.bert_config.output_hidden_states = True
 
 class QueryEmbedding(nn.Module):
@@ -85,7 +84,8 @@ class ListTransformer(nn.Module):
         self.linear_score2 = nn.Linear(config.hidden_size * 2, config.hidden_size)
         self.linear_score1 = nn.Linear(config.hidden_size * 2, 1)
 
-    def forward(self, pair_features, pair_nums):
+    def forward(self, pair_features: torch.Tensor):
+        pair_nums = pair_features.size(0)
         pair_nums = [x + 1 for x in pair_nums]
         batch_pair_features = pair_features.split(pair_nums)
 
@@ -154,7 +154,7 @@ class ListConRankerModel(PreTrainedModel):
         super().__init__(config)
         self.config = config
         self.num_labels = config.num_labels
-        self.hf_model = BertModel(config)
+        self.hf_model = BertModel(config.bert_config)
 
         self.sigmoid = nn.Sigmoid()
 
@@ -176,17 +176,8 @@ class ListConRankerModel(PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        pair_num: Optional[torch.Tensor] = None,
         **kwargs
     ) -> Union[SequenceClassifierOutput, tuple]:
-        # Handle pair_num parameter
-        if pair_num is not None:
-            pair_nums = pair_num.tolist()
-        else:
-            # Default behavior if pair_num is not provided
-            batch_size = input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
-            pair_nums = [1] * batch_size
-
         # Get device
         device = input_ids.device if input_ids is not None else inputs_embeds.device
         self.list_transformer.device = device
@@ -195,20 +186,7 @@ class ListConRankerModel(PreTrainedModel):
         if self.training:
             pass
         else:
-
-            if sum(pair_nums) > split_batch:
-                last_hidden_state_list = []
-                input_ids_list = input_ids.split(split_batch)
-                attention_mask_list = attention_mask.split(split_batch)
-                for i in range(len(input_ids_list)):
-                    last_hidden_state = self.hf_model(
-                        input_ids=input_ids_list[i],
-                        attention_mask=attention_mask_list[i],
-                        return_dict=True).hidden_states[-1]
-                    last_hidden_state_list.append(last_hidden_state)
-                last_hidden_state = torch.cat(last_hidden_state_list, dim=0)
-            else:
-                ranker_out = self.hf_model(
+            ranker_out = self.hf_model(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids,
@@ -217,12 +195,12 @@ class ListConRankerModel(PreTrainedModel):
                 inputs_embeds=inputs_embeds,
                 output_attentions=output_attentions,
                 return_dict=True)
-
+            last_hidden_state = ranker_out.last_hidden_state
 
         pair_features = self.average_pooling(last_hidden_state, attention_mask)
         pair_features = self.linear_in_embedding(pair_features)
 
-        logits, pair_features_after_list_transformer = self.list_transformer(pair_features, pair_nums)
+        logits, pair_features_after_list_transformer = self.list_transformer(pair_features)
         logits = self.sigmoid(logits)
 
         return logits
@@ -249,4 +227,4 @@ class ListConRankerModel(PreTrainedModel):
         except FileNotFoundError:
             print(f"Warning: Could not load custom weights from {model_name_or_path}")
 
-        return model
+        return model
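For context: after this commit the BERT backbone is built from config.bert_config (whose hidden_size is forced to base_hidden_size), config.hidden_size keeps sizing the ListTransformer head, and forward() no longer takes a pair_num argument. Below is a minimal sketch of the resulting config behaviour, assuming listconranker.py is importable locally and that num_attention_heads is supplied through **kwargs; neither assumption is part of the commit.

# Sketch only; illustrates the shape fix, not an official usage example.
from listconranker import ListConRankerConfig  # assumes the file is on the import path

config = ListConRankerConfig(
    hidden_size=1792,        # width of the ListTransformer head
    base_hidden_size=1024,   # width of the BERT backbone
    num_attention_heads=16,  # forwarded to BertConfig via **kwargs (assumed value)
)

print(config.hidden_size)              # 1792, used by the ListTransformer layers
print(config.bert_config.hidden_size)  # 1024, now taken from base_hidden_size

Since BertModel is now constructed from config.bert_config rather than the top-level config, the backbone no longer inherits the 1792-dimensional head width, which appears to be the shape mismatch the commit message refers to.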