Uploading patch

modeling_gpt_bert.py (+9 -10)
@@ -138,7 +138,7 @@ class Attention(nn.Module):
             - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
         position_indices: torch.Tensor = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
         position_indices = config.position_bucket_size - 1 + position_indices
-        self.register_buffer("position_indices", position_indices, persistent=
+        self.register_buffer("position_indices", position_indices, persistent=False)

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.scale: float = 1.0 / math.sqrt(3 * self.head_size)
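Context for the first hunk: the surrounding lines build the full table of pairwise relative positions, compress it with log-bucketing, and register it as a buffer; the patch marks that buffer persistent=False so the deterministic table is rebuilt from the config on load instead of being serialized into every checkpoint. Below is a minimal sketch of the construction, assuming make_log_bucket_position follows the standard DeBERTa-style bucketing (its exact body is not shown in this diff and may differ):

import math
import torch
from torch import nn

class RelativePositionTable(nn.Module):
    # Hypothetical stand-in for the Attention.__init__ logic around line 138.
    def __init__(self, max_position_embeddings: int = 512, position_bucket_size: int = 32):
        super().__init__()
        # Pairwise offsets: position_indices[i, j] = i - j.
        position_indices = torch.arange(max_position_embeddings, dtype=torch.long).unsqueeze(1) \
            - torch.arange(max_position_embeddings, dtype=torch.long).unsqueeze(0)
        position_indices = self.make_log_bucket_position(position_indices, position_bucket_size, max_position_embeddings)
        position_indices = position_bucket_size - 1 + position_indices
        # persistent=False: the table is deterministic given the config, so it is
        # recomputed on load rather than stored in the state_dict of every checkpoint.
        self.register_buffer("position_indices", position_indices, persistent=False)

    @staticmethod
    def make_log_bucket_position(relative_pos: torch.Tensor, bucket_size: int, max_position: int) -> torch.Tensor:
        # DeBERTa-style bucketing (assumed): exact offsets near zero,
        # logarithmically compressed buckets for distant offsets.
        sign = torch.sign(relative_pos)
        mid = bucket_size // 2
        abs_pos = torch.where(
            (relative_pos < mid) & (relative_pos > -mid),
            torch.full_like(relative_pos, mid - 1),
            relative_pos.abs(),
        )
        log_pos = torch.ceil(
            torch.log(abs_pos / mid) / math.log((max_position - 1) / mid) * (mid - 1)
        ).long() + mid
        return torch.where(abs_pos <= mid, relative_pos, log_pos * sign)

The final shift by position_bucket_size - 1 maps the signed buckets into [0, 2 * position_bucket_size - 2], so the table can index a relative-position embedding directly.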
@@ -301,18 +301,17 @@ class GPTBERT(GPTBERTPreTrainedModel):
         batch_size, seq_length = input_shape

         if attention_mask is None:
-            attention_mask = input_ids.
-
-        if attention_mask is not None:
+            attention_mask = input_ids.new_zeros((batch_size, seq_length), dtype=torch.bool).unsqueeze(1).unsqueeze(2)
+        else:
             attention_mask = ~attention_mask.bool()

-
-
-
-
+            if len(attention_mask.size()) == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            elif len(attention_mask.size()) == 3:
+                attention_mask = attention_mask.unsqueeze(1)

-
-
+        if self.is_causal:
+            attention_mask = attention_mask | input_ids.new_ones((seq_length, seq_length), dtype=torch.bool).triu(1).unsqueeze(0).unsqueeze(0)

         static_embeddings, relative_embeddings = self.embedding(input_ids.t())
         contextualized_embeddings = [static_embeddings]
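The second hunk normalizes every mask to a boolean [batch, 1, query, key] layout in which True marks positions to ignore, then ORs in a strict upper-triangular mask when the model runs causally. A self-contained sketch of the same logic follows; build_attention_mask is a hypothetical free-standing rewrite of what the patch does inline in GPTBERT:

from typing import Optional
import torch

def build_attention_mask(input_ids: torch.Tensor,
                         attention_mask: Optional[torch.Tensor],
                         is_causal: bool) -> torch.Tensor:
    batch_size, seq_length = input_ids.shape
    if attention_mask is None:
        # No padding information: start from an all-False [B, 1, 1, S] mask
        # (nothing blocked) on the same device as input_ids.
        attention_mask = input_ids.new_zeros((batch_size, seq_length), dtype=torch.bool).unsqueeze(1).unsqueeze(2)
    else:
        # Hugging Face masks use 1 = attend, 0 = padding; invert so True = masked.
        attention_mask = ~attention_mask.bool()
        if len(attention_mask.size()) == 2:        # [B, S]    -> [B, 1, 1, S]
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        elif len(attention_mask.size()) == 3:      # [B, S, S] -> [B, 1, S, S]
            attention_mask = attention_mask.unsqueeze(1)
    if is_causal:
        # OR in a strictly upper-triangular mask so position i ignores every j > i.
        causal = input_ids.new_ones((seq_length, seq_length), dtype=torch.bool).triu(1)
        attention_mask = attention_mask | causal.unsqueeze(0).unsqueeze(0)
    return attention_mask

Keeping the True-means-masked convention means the result can be applied with scores.masked_fill(attention_mask, float('-inf')) just before the softmax, which is the usual way such a mask is consumed.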