lgcharpe committed (verified)
Commit 5d4f6ac · 1 Parent(s): d98cea9

Uploading patch

Files changed (1):
  1. modeling_gpt_bert.py (+6 -7)
modeling_gpt_bert.py CHANGED
@@ -138,7 +138,7 @@ class Attention(nn.Module):
             - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
         position_indices: torch.Tensor = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
         position_indices = config.position_bucket_size - 1 + position_indices
-        self.register_buffer("position_indices", position_indices, persistent=True)
+        self.register_buffer("position_indices", position_indices, persistent=False)

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.scale: float = 1.0 / math.sqrt(3 * self.head_size)
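The only edit in this hunk is persistent=True -> persistent=False on the position_indices buffer. Since the bucketed relative-position table is rebuilt deterministically in __init__ (the lines just above it), it does not need to be stored in the checkpoint; a non-persistent buffer is excluded from state_dict() but still moves with the module across devices. A minimal sketch of that PyTorch behaviour, using a toy module rather than the repo's Attention class:

import torch
import torch.nn as nn

# Toy module (illustrative, not from the repo): shows what persistent=False changes.
# A non-persistent buffer is still registered on the module (it moves with .to()/.cuda()
# and appears in named_buffers()), but it is excluded from state_dict(), so it is neither
# written to checkpoints nor expected when loading one.
class Toy(nn.Module):
    def __init__(self, n: int = 8):
        super().__init__()
        # same kind of deterministic table as position_indices: recomputed every init
        idx = torch.arange(n).unsqueeze(1) - torch.arange(n).unsqueeze(0)
        self.register_buffer("persistent_idx", idx, persistent=True)
        self.register_buffer("transient_idx", idx, persistent=False)

m = Toy()
print("persistent_idx" in m.state_dict())   # True  -> saved in the checkpoint
print("transient_idx" in m.state_dict())    # False -> simply rebuilt in __init__
print(sorted(dict(m.named_buffers()).keys()))  # both buffers remain registered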
@@ -302,14 +302,13 @@ class GPTBERT(GPTBERTPreTrainedModel):

         if attention_mask is None:
             attention_mask = input_ids.new_zeros((batch_size, seq_length), dtype=torch.bool).unsqueeze(1).unsqueeze(2)
-
-        if attention_mask is not None:
+        else:
             attention_mask = ~attention_mask.bool()

-        if len(attention_mask.size()) == 2:
-            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-        elif len(attention_mask.size()) == 3:
-            attention_mask = attention_mask.unsqueeze(1)
+            if len(attention_mask.size()) == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            elif len(attention_mask.size()) == 3:
+                attention_mask = attention_mask.unsqueeze(1)

         if self.is_causal:
             attention_mask = attention_mask | input_ids.new_ones((seq_length, seq_length), dtype=torch.bool).triu(1).unsqueeze(0).unsqueeze(0)
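The second hunk fixes the mask normalisation in GPTBERT: previously the inversion ran unconditionally, so when no mask was passed the freshly created all-False placeholder was immediately flipped to all-True, masking every position. With the else: branch, the inversion and the expand-to-4D handling apply only to a caller-supplied mask. A self-contained sketch of the patched logic; the wrapper function build_attention_mask is illustrative (not a name from the repo), and the convention after normalisation is that True marks positions that must not be attended to:

import torch

def build_attention_mask(input_ids: torch.Tensor, attention_mask=None, is_causal=True):
    """Sketch of the mask handling after this patch, isolated from the model."""
    batch_size, seq_length = input_ids.shape

    if attention_mask is None:
        # nothing to mask: all-False, already shaped to broadcast as [B, 1, 1, L]
        attention_mask = input_ids.new_zeros((batch_size, seq_length), dtype=torch.bool).unsqueeze(1).unsqueeze(2)
    else:
        # HF-style mask (1 = keep, 0 = pad) -> invert so True = blocked
        attention_mask = ~attention_mask.bool()
        if len(attention_mask.size()) == 2:     # [B, L]    -> [B, 1, 1, L]
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        elif len(attention_mask.size()) == 3:   # [B, L, L] -> [B, 1, L, L]
            attention_mask = attention_mask.unsqueeze(1)

    if is_causal:
        # additionally block all future positions (strict upper triangle)
        attention_mask = attention_mask | input_ids.new_ones((seq_length, seq_length), dtype=torch.bool).triu(1).unsqueeze(0).unsqueeze(0)

    return attention_mask

ids = torch.zeros(2, 4, dtype=torch.long)
print(build_attention_mask(ids).shape)  # torch.Size([2, 1, 4, 4]); only future tokens blocked
pad = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(build_attention_mask(ids, pad).shape)  # torch.Size([2, 1, 4, 4]); padding also blocked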
 