lgcharpe committed
Commit 24b42f1 · verified · 1 Parent(s): c28c53e

Uploading patch

Files changed (1):
  modeling_gpt_bert.py  +9 -10
modeling_gpt_bert.py CHANGED
@@ -138,7 +138,7 @@ class Attention(nn.Module):
             - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
         position_indices: torch.Tensor = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
         position_indices = config.position_bucket_size - 1 + position_indices
-        self.register_buffer("position_indices", position_indices, persistent=True)
+        self.register_buffer("position_indices", position_indices, persistent=False)
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.scale: float = 1.0 / math.sqrt(3 * self.head_size)
@@ -301,18 +301,17 @@ class GPTBERT(GPTBERTPreTrainedModel):
         batch_size, seq_length = input_shape
 
         if attention_mask is None:
-            attention_mask = input_ids.new_ones((seq_length, seq_length), dtype=torch.bool).triu(diagonal=1).unsqueeze(0).unsqueeze(0)
-
-        if attention_mask is not None:
+            attention_mask = input_ids.new_zeros((batch_size, seq_length), dtype=torch.bool).unsqueeze(1).unsqueeze(2)
+        else:
             attention_mask = ~attention_mask.bool()
 
-            if len(attention_mask.size()) == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-            elif len(attention_mask.size()) == 3:
-                attention_mask = attention_mask.unsqueeze(1)
+            if len(attention_mask.size()) == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            elif len(attention_mask.size()) == 3:
+                attention_mask = attention_mask.unsqueeze(1)
 
-        if self.is_causal:
-            attention_mask = attention_mask | input_ids.new_ones((seq_length, seq_length), dtype=torch.bool).triu(1).unsqueeze(0).unsqueeze(0)
+        if self.is_causal:
+            attention_mask = attention_mask | input_ids.new_ones((seq_length, seq_length), dtype=torch.bool).triu(1).unsqueeze(0).unsqueeze(0)
 
         static_embeddings, relative_embeddings = self.embedding(input_ids.t())
         contextualized_embeddings = [static_embeddings]
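
The first hunk flips the position_indices buffer in Attention.__init__ from persistent=True to persistent=False. Since the table is recomputed from the config every time the module is built, a non-persistent buffer keeps the (max_position_embeddings × max_position_embeddings) index tensor out of state_dict while still letting it follow the module through .to() / .cuda(). A toy module, unrelated to the repository's Attention class, illustrating what the flag changes:

import torch
import torch.nn as nn


class WithBuffer(nn.Module):
    # Minimal stand-in: registers a small precomputed index table as a buffer.
    def __init__(self, persistent: bool):
        super().__init__()
        indices = torch.arange(8).unsqueeze(1) - torch.arange(8).unsqueeze(0)
        self.register_buffer("position_indices", indices, persistent=persistent)


print(list(WithBuffer(persistent=True).state_dict().keys()))   # ['position_indices']
print(list(WithBuffer(persistent=False).state_dict().keys()))  # []

The practical effect is smaller checkpoints: the table is simply rebuilt whenever the model is instantiated.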
 
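What the second hunk does, read straight off the diff: when no attention_mask is passed, the model now starts from an all-False padding mask of shape (batch_size, 1, 1, seq_length) instead of building a causal mask that then falls through to the inversion meant for user-supplied masks; a supplied mask is inverted so that True means "do not attend" and broadcast to 4-D; and the strictly upper-triangular causal mask is OR-ed in only when is_causal is set. Below is a minimal, self-contained sketch of that logic outside the model, assuming True = masked throughout; build_attention_mask is a hypothetical helper name, not something defined in modeling_gpt_bert.py.

from typing import Optional

import torch


def build_attention_mask(input_ids: torch.Tensor,
                         attention_mask: Optional[torch.Tensor],
                         is_causal: bool) -> torch.Tensor:
    # Hypothetical standalone version of the mask handling in the patched file.
    batch_size, seq_length = input_ids.shape

    if attention_mask is None:
        # No padding information: start with nothing masked, shape (B, 1, 1, S).
        attention_mask = input_ids.new_zeros((batch_size, seq_length), dtype=torch.bool).unsqueeze(1).unsqueeze(2)
    else:
        # HF-style masks use 1 = attend, so invert to True = "do not attend".
        attention_mask = ~attention_mask.bool()
        if len(attention_mask.size()) == 2:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)   # (B, S) -> (B, 1, 1, S)
        elif len(attention_mask.size()) == 3:
            attention_mask = attention_mask.unsqueeze(1)                # (B, S, S) -> (B, 1, S, S)

    if is_causal:
        # Strictly upper-triangular True = future positions masked, shape (1, 1, S, S).
        causal = input_ids.new_ones((seq_length, seq_length), dtype=torch.bool).triu(1).unsqueeze(0).unsqueeze(0)
        attention_mask = attention_mask | causal

    return attention_mask


ids = torch.randint(0, 100, (2, 5))
padding = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding
print(build_attention_mask(ids, padding, is_causal=True).shape)  # torch.Size([2, 1, 5, 5])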