Warn about megablocks more clearly and less often (#20)
- warn about megablocks more clearly and less often (8a4d4d9a7f96bf4ffe71c72251432824ebfd90d4)
Co-authored-by: Cebtenzzre <[email protected]>
modeling_hf_nomic_bert.py CHANGED (+11 -6)
@@ -3,13 +3,15 @@
 # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
 # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
 
+# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
+
 import collections
+import inspect
 import logging
-
-# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
 import math
 import os
 import re
+import warnings
 from collections import OrderedDict
 from functools import partial
 from typing import List, Optional, Tuple, Union
@@ -54,8 +56,9 @@ try:
     from megablocks.layers import dmoe
     from megablocks.layers.arguments import Arguments
 except ImportError:
-    logger.warning("!!!!!!!!!!!!megablocks not available, using torch.matmul instead")
     dmoe = None
+else:
+    dmoe_is_nomic = 'attention_mask' in inspect.signature(dmoe.dMoE.forward).parameters
 
 
 
@@ -1612,7 +1615,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
         )
         self.moe = moe
         if moe:
-            if dmoe is not None:
+            if dmoe is not None and dmoe_is_nomic:
                 megablocks_args = Arguments(
                     moe_num_experts=config.num_experts,
                     moe_top_k=config.moe_top_k,
@@ -1628,6 +1631,8 @@ class NomicBertBlock(NomicBertPreTrainedModel):
                 )
                 self.mlp = dmoe.dMoE(megablocks_args)
             else:
+                warnings.warn("Install Nomic's megablocks fork for better speed: " +
+                              "`pip install git+https://github.com/nomic-ai/megablocks.git`")
                 self.mlp = NomicMoELayer(
                     config
                 )
@@ -1698,7 +1703,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
             residual = (dropped + residual) if residual is not None else dropped
             hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
             if self.moe:
-                hidden_states = self.mlp(hidden_states, torch.where(attention_mask.squeeze() == 0, 1, 0))
+                hidden_states = self.mlp(hidden_states, attention_mask=torch.where(attention_mask.squeeze() == 0, 1, 0))
             else:
                 hidden_states = self.mlp(hidden_states)
 
@@ -1715,7 +1720,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
             )
             hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
             if self.moe:
-                mlp_out = self.mlp(hidden_states, torch.where(attention_mask.squeeze() == 0, 1, 0))
+                mlp_out = self.mlp(hidden_states, attention_mask=torch.where(attention_mask.squeeze() == 0, 1, 0))
             else:
                 mlp_out = self.mlp(hidden_states)
 
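The new `else:` branch after the import guard is a small piece of feature detection: Nomic's megablocks fork adds an `attention_mask` parameter to `dMoE.forward`, so inspecting the signature tells the model whether the installed megablocks can accept the mask, without parsing version strings. A minimal, self-contained sketch of the same technique (the `UpstreamMoE` and `NomicStyleMoE` classes below are hypothetical stand-ins, not real megablocks classes):

import inspect

class UpstreamMoE:
    # Upstream-style forward: no mask argument.
    def forward(self, x):
        return x

class NomicStyleMoE:
    # Fork-style forward: accepts an attention mask.
    def forward(self, x, attention_mask=None):
        return x

def accepts_attention_mask(moe_cls) -> bool:
    # Same check as dmoe_is_nomic: look for the parameter by name.
    return "attention_mask" in inspect.signature(moe_cls.forward).parameters

print(accepts_attention_mask(UpstreamMoE))    # False
print(accepts_attention_mask(NomicStyleMoE))  # True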
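The call sites now pass the mask as a keyword argument (`attention_mask=...`), which only the fork's signature accepts; when the signature check fails, the constructor falls back to `NomicMoELayer` and emits the new warning instead. The `torch.where(attention_mask.squeeze() == 0, 1, 0)` expression converts the block's mask into a binary 0/1 mask. A small sketch of that conversion, assuming a BERT-style additive mask of shape [batch, 1, 1, seq] with 0 for real tokens and a large negative value for padding (an assumption about the mask convention, not something this diff states):

import torch

# Hypothetical additive mask for one sequence of length 4:
# 0.0 -> attend, -10000.0 -> padding (common BERT-style convention).
additive_mask = torch.tensor([[[[0.0, 0.0, -10000.0, -10000.0]]]])

# Same expression as in the diff: 1 where the mask is exactly 0 (real tokens),
# 0 everywhere else (padding).
binary_mask = torch.where(additive_mask.squeeze() == 0, 1, 0)
print(binary_mask)  # tensor([1, 1, 0, 0])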