Fill-Mask · Transformers · PyTorch · Safetensors · English · nomic_bert · custom_code
zpn and Cebtenzzre committed · verified
Commit 7710840 · 1 Parent(s): e5042dc

Warn about megablocks more clearly and less often (#20)


- warn about megablocks more clearly and less often (8a4d4d9a7f96bf4ffe71c72251432824ebfd90d4)


Co-authored-by: Cebtenzzre <[email protected]>
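
The "less often" half of this change comes in part from replacing the import-time logger.warning (which fired whenever megablocks was missing, even if the MoE path was never used) with warnings.warn inside the block constructor. A minimal standalone sketch of that behaviour, not part of the commit itself; the loop merely stands in for constructing several NomicBertBlock layers:

import warnings

# With Python's default warning filters, a given warnings.warn() call site
# only prints once per process, so a model with many MoE blocks surfaces the
# install hint a single time instead of once per layer.
for _ in range(3):  # stand-in for building several MoE blocks
    warnings.warn("Install Nomic's megablocks fork for better speed: "
                  "`pip install git+https://github.com/nomic-ai/megablocks.git`")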

Files changed (1)
  1. modeling_hf_nomic_bert.py +11 -6
modeling_hf_nomic_bert.py CHANGED
@@ -3,13 +3,15 @@
 # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
 # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
 
+# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
+
 import collections
+import inspect
 import logging
-
-# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
 import math
 import os
 import re
+import warnings
 from collections import OrderedDict
 from functools import partial
 from typing import List, Optional, Tuple, Union
@@ -54,8 +56,9 @@ try:
     from megablocks.layers import dmoe
     from megablocks.layers.arguments import Arguments
 except ImportError:
-    logger.warning("!!!!!!!!!!!!megablocks not available, using torch.matmul instead")
     dmoe = None
+else:
+    dmoe_is_nomic = 'attention_mask' in inspect.signature(dmoe.dMoE.forward).parameters
 
 
 
@@ -1612,7 +1615,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
         )
         self.moe = moe
         if moe:
-            if dmoe is not None:
+            if dmoe is not None and dmoe_is_nomic:
                 megablocks_args = Arguments(
                     moe_num_experts=config.num_experts,
                     moe_top_k=config.moe_top_k,
@@ -1628,6 +1631,8 @@ class NomicBertBlock(NomicBertPreTrainedModel):
                 )
                 self.mlp = dmoe.dMoE(megablocks_args)
             else:
+                warnings.warn("Install Nomic's megablocks fork for better speed: " +
+                              "`pip install git+https://github.com/nomic-ai/megablocks.git`")
                 self.mlp = NomicMoELayer(
                     config
                 )
@@ -1698,7 +1703,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
             residual = (dropped + residual) if residual is not None else dropped
             hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
             if self.moe:
-                hidden_states = self.mlp(hidden_states, torch.where(attention_mask.squeeze() == 0, 1, 0))
+                hidden_states = self.mlp(hidden_states, attention_mask=torch.where(attention_mask.squeeze() == 0, 1, 0))
             else:
                 hidden_states = self.mlp(hidden_states)
 
@@ -1715,7 +1720,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
             )
             hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
             if self.moe:
-                mlp_out = self.mlp(hidden_states, torch.where(attention_mask.squeeze() == 0, 1, 0))
+                mlp_out = self.mlp(hidden_states, attention_mask=torch.where(attention_mask.squeeze() == 0, 1, 0))
             else:
                 mlp_out = self.mlp(hidden_states)
 
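
For the "more clearly" half, the diff gates the megablocks fast path on a signature probe rather than a bare import check. A standalone sketch of that probe, assuming (as the commit does) that only Nomic's fork of megablocks exposes an attention_mask parameter on dMoE.forward; the dmoe_is_nomic = False fallback is added here for clarity, while the model code relies on short-circuit evaluation instead:

import inspect

try:
    from megablocks.layers import dmoe
except ImportError:
    dmoe = None
    dmoe_is_nomic = False
else:
    # Upstream megablocks' dMoE.forward takes no attention_mask parameter;
    # Nomic's fork adds it, so its presence identifies the fork.
    dmoe_is_nomic = 'attention_mask' in inspect.signature(dmoe.dMoE.forward).parameters

# The MoE block only takes the megablocks path when both conditions hold;
# otherwise it falls back to NomicMoELayer and emits the install hint.
use_megablocks = dmoe is not None and dmoe_is_nomic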