Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- act_ckpt.py +123 -0
- added_tokens.json +28 -0
- attention.py +450 -0
- blocks.py +137 -0
- config.json +84 -0
- config_defaults.py +5 -0
- config_moe_args.py +159 -0
- configuration_mpt.py +252 -0
- custom_embedding.py +10 -0
- dmoe.py +138 -0
- fc.py +8 -0
- ffn.py +272 -0
- generation_config.json +5 -0
- layer_builders.py +33 -0
- layers_registry.py +22 -0
- merges.txt +0 -0
- model-00001-of-00007.safetensors +3 -0
- model-00002-of-00007.safetensors +3 -0
- model-00003-of-00007.safetensors +3 -0
- model-00004-of-00007.safetensors +3 -0
- model-00005-of-00007.safetensors +3 -0
- model-00006-of-00007.safetensors +3 -0
- model-00007-of-00007.safetensors +3 -0
- model.safetensors.index.json +289 -0
- modeling_mpt.py +696 -0
- mpt_param_count.py +130 -0
- norm.py +79 -0
- param_init_fns.py +448 -0
- registry_utils.py +131 -0
- special_tokens_map.json +31 -0
- tokenizer.json +3 -0
- tokenizer_config.json +240 -0
- vocab.json +0 -0
- warnings.py +72 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
act_ckpt.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
import torch
|
3 |
+
from .layers_registry import attention_classes, ffns, ffns_with_megablocks, ffns_with_norm, norms
|
4 |
+
from .blocks import FusedNormAttentionNorm, MPTBlock
|
5 |
+
|
6 |
+
def pass_on_block_idx(parent: torch.nn.Module):
|
7 |
+
if not hasattr(parent, 'block_idx') or not hasattr(parent, 'max_block_idx'):
|
8 |
+
return
|
9 |
+
for child in parent.children():
|
10 |
+
child.block_idx = parent.block_idx
|
11 |
+
child.max_block_idx = parent.max_block_idx
|
12 |
+
if child.children():
|
13 |
+
pass_on_block_idx(child)
|
14 |
+
|
15 |
+
def get_act_ckpt_module(mod_name: str) -> Any:
|
16 |
+
"""Get the module type from the module name."""
|
17 |
+
if mod_name.lower() == 'mptblock':
|
18 |
+
mod_type = MPTBlock
|
19 |
+
elif mod_name in attention_classes:
|
20 |
+
mod_type = attention_classes.get(mod_name)
|
21 |
+
elif mod_name.lower() == 'norm_attn_norm':
|
22 |
+
mod_type = FusedNormAttentionNorm
|
23 |
+
elif mod_name in ffns:
|
24 |
+
mod_type = ffns.get(mod_name)
|
25 |
+
elif mod_name in ffns_with_norm:
|
26 |
+
mod_type = ffns_with_norm.get(mod_name)
|
27 |
+
elif mod_name in ffns_with_megablocks:
|
28 |
+
mod_type = ffns_with_megablocks.get(mod_name)
|
29 |
+
elif mod_name in norms:
|
30 |
+
mod_type = norms.get(mod_name)
|
31 |
+
else:
|
32 |
+
msg = ', '.join(list(attention_classes.get_all()) + list(ffns.get_all()) + list(ffns_with_norm.get_all()) + list(ffns_with_megablocks.get_all()) + list(norms.get_all()) + ['MPTBlock'])
|
33 |
+
raise ValueError(f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.')
|
34 |
+
return mod_type
|
35 |
+
|
36 |
+
def parse_ele_str(ele: str, max_block_idx: int) -> list:
|
37 |
+
"""Parse a string in target_blocks and return a list of block ids to add.
|
38 |
+
|
39 |
+
Supported formats are: first-n, middle-m, last-k, range-i-j which correspond
|
40 |
+
to the first n, the middle m, the last k, and the range [i, j).
|
41 |
+
"""
|
42 |
+
to_add = None
|
43 |
+
if ele.startswith('first-'):
|
44 |
+
assert ele[6:].isdigit(), f'Invalid target_blocks element {ele}'
|
45 |
+
to_add = list(range(min(int(ele[6:]), max_block_idx + 1)))
|
46 |
+
elif ele.startswith('last-'):
|
47 |
+
assert ele[5:].isdigit(), f'Invalid target_blocks element {ele}'
|
48 |
+
to_add = list(range(max(max_block_idx - int(ele[5:]) + 1, 0), max_block_idx + 1))
|
49 |
+
elif ele.startswith('middle-'):
|
50 |
+
assert ele[7:].isdigit(), f'Invalid target_blocks element {ele}'
|
51 |
+
num = int(ele[7:])
|
52 |
+
start = max(max_block_idx // 2 - num // 2, 0)
|
53 |
+
end = min(start + num, max_block_idx + 1)
|
54 |
+
to_add = list(range(start, end))
|
55 |
+
elif ele.startswith('range-'):
|
56 |
+
r = ele[6:].split('-')
|
57 |
+
assert len(r) == 2, f'Invalid target_blocks element {ele}'
|
58 |
+
start, end = (int(r[0]), int(r[1]))
|
59 |
+
start = max(start, 0)
|
60 |
+
end = min(end, max_block_idx + 1)
|
61 |
+
to_add = list(range(start, end))
|
62 |
+
else:
|
63 |
+
raise ValueError(f'Invalid target_blocks element {ele}')
|
64 |
+
return to_add
|
65 |
+
|
66 |
+
def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list:
|
67 |
+
"""Parse the user input and return a list of block ids."""
|
68 |
+
candidate_block_ids = []
|
69 |
+
if isinstance(target_blocks, int):
|
70 |
+
candidate_block_ids = list(range(target_blocks))
|
71 |
+
elif isinstance(target_blocks, list):
|
72 |
+
for ele in target_blocks:
|
73 |
+
if isinstance(ele, int):
|
74 |
+
candidate_block_ids.append(ele)
|
75 |
+
elif isinstance(ele, str):
|
76 |
+
to_add = parse_ele_str(ele, max_block_idx)
|
77 |
+
candidate_block_ids.extend(to_add)
|
78 |
+
else:
|
79 |
+
raise ValueError(f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}')
|
80 |
+
elif isinstance(target_blocks, str):
|
81 |
+
target_blocks = target_blocks.replace(' ', '')
|
82 |
+
for ele in target_blocks.split(','):
|
83 |
+
to_add = parse_ele_str(ele, max_block_idx)
|
84 |
+
candidate_block_ids.extend(to_add)
|
85 |
+
else:
|
86 |
+
raise ValueError(f'target_blocks must be either a single integer, or a list of integers, or a comma separated string made of "first-n", "last-m", "middle-k", "range-i-j", or a list of mixed integers and before-mentioned strings, but got {type(target_blocks)}')
|
87 |
+
candidate_block_ids = list(set(candidate_block_ids))
|
88 |
+
return candidate_block_ids
|
89 |
+
|
90 |
+
def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
|
91 |
+
"""Check if the block ids in the mapping overlap with each other."""
|
92 |
+
all_blocks = [None] * (max_block_idx + 1)
|
93 |
+
for k, v in mapping.items():
|
94 |
+
if v == -1:
|
95 |
+
v = list(range(max_block_idx + 1))
|
96 |
+
for vv in v:
|
97 |
+
if vv < 0 or vv > max_block_idx:
|
98 |
+
continue
|
99 |
+
elif all_blocks[vv] is not None:
|
100 |
+
raise ValueError(f'Block {vv} is assigned to both {k} and {all_blocks[vv]}. Each block can only have one granularity of activation checkpointing. Make sure the target_blocks in activation_checkpointing_target do not overlap. For more details, refer to the docs of activation_checkpointing_fn.')
|
101 |
+
else:
|
102 |
+
all_blocks[vv] = k
|
103 |
+
|
104 |
+
def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any, max_block_idx: int) -> dict:
|
105 |
+
act_ckpt_mod_to_blocks = {}
|
106 |
+
if act_ckpt_target is None or act_ckpt_target == []:
|
107 |
+
mod = top_module
|
108 |
+
act_ckpt_mod_to_blocks[mod] = -1
|
109 |
+
elif isinstance(act_ckpt_target, str):
|
110 |
+
mod = get_act_ckpt_module(act_ckpt_target)
|
111 |
+
act_ckpt_mod_to_blocks[mod] = -1
|
112 |
+
elif isinstance(act_ckpt_target, list):
|
113 |
+
for target in act_ckpt_target:
|
114 |
+
mod = get_act_ckpt_module(target)
|
115 |
+
act_ckpt_mod_to_blocks[mod] = -1
|
116 |
+
elif isinstance(act_ckpt_target, dict):
|
117 |
+
for k, v in act_ckpt_target.items():
|
118 |
+
mod = get_act_ckpt_module(k)
|
119 |
+
block_ids = get_target_block_list(v, max_block_idx)
|
120 |
+
act_ckpt_mod_to_blocks[mod] = block_ids
|
121 |
+
else:
|
122 |
+
raise ValueError(f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}')
|
123 |
+
return act_ckpt_mod_to_blocks
|
added_tokens.json
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"</think>": 151668,
|
3 |
+
"</tool_call>": 151658,
|
4 |
+
"</tool_response>": 151666,
|
5 |
+
"<think>": 151667,
|
6 |
+
"<tool_call>": 151657,
|
7 |
+
"<tool_response>": 151665,
|
8 |
+
"<|box_end|>": 151649,
|
9 |
+
"<|box_start|>": 151648,
|
10 |
+
"<|endoftext|>": 151643,
|
11 |
+
"<|file_sep|>": 151664,
|
12 |
+
"<|fim_middle|>": 151660,
|
13 |
+
"<|fim_pad|>": 151662,
|
14 |
+
"<|fim_prefix|>": 151659,
|
15 |
+
"<|fim_suffix|>": 151661,
|
16 |
+
"<|im_end|>": 151645,
|
17 |
+
"<|im_start|>": 151644,
|
18 |
+
"<|image_pad|>": 151655,
|
19 |
+
"<|object_ref_end|>": 151647,
|
20 |
+
"<|object_ref_start|>": 151646,
|
21 |
+
"<|quad_end|>": 151651,
|
22 |
+
"<|quad_start|>": 151650,
|
23 |
+
"<|repo_name|>": 151663,
|
24 |
+
"<|video_pad|>": 151656,
|
25 |
+
"<|vision_end|>": 151653,
|
26 |
+
"<|vision_pad|>": 151654,
|
27 |
+
"<|vision_start|>": 151652
|
28 |
+
}
|
attention.py
ADDED
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Attention layers."""
|
2 |
+
import copy
|
3 |
+
import math
|
4 |
+
import warnings
|
5 |
+
from typing import Any, Optional
|
6 |
+
import torch
|
7 |
+
import transformers
|
8 |
+
from einops import rearrange
|
9 |
+
from packaging import version
|
10 |
+
from torch import nn
|
11 |
+
from .layers_registry import attention_classes, attention_implementations
|
12 |
+
from .layer_builders import build_fc, build_norm
|
13 |
+
from .config_defaults import fc_type_defaults
|
14 |
+
|
15 |
+
def is_flash_v2_installed(v2_version: str='2.0.0'):
|
16 |
+
assert version.parse(v2_version) >= version.parse('2.0.0')
|
17 |
+
try:
|
18 |
+
import flash_attn as flash_attn
|
19 |
+
except:
|
20 |
+
return False
|
21 |
+
return version.parse(flash_attn.__version__) >= version.parse(v2_version)
|
22 |
+
|
23 |
+
def is_flash_v1_installed():
|
24 |
+
try:
|
25 |
+
import flash_attn as flash_attn
|
26 |
+
except:
|
27 |
+
return False
|
28 |
+
return version.parse(flash_attn.__version__) < version.parse('2.0.0')
|
29 |
+
|
30 |
+
def is_transformers_version_gte(hf_version: str) -> bool:
|
31 |
+
return version.parse(transformers.__version__) >= version.parse(hf_version)
|
32 |
+
|
33 |
+
def check_alibi_support(attention_impl: str) -> bool:
|
34 |
+
return attention_impl != 'flash' or is_flash_v2_installed(v2_version='v2.4.2')
|
35 |
+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
36 |
+
|
37 |
+
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool:
|
38 |
+
if original_is_causal and num_query_tokens != num_key_tokens:
|
39 |
+
if num_query_tokens != 1:
|
40 |
+
raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.')
|
41 |
+
else:
|
42 |
+
return False
|
43 |
+
return original_is_causal
|
44 |
+
|
45 |
+
def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
|
46 |
+
"""Perform repeat of kv heads along a particular dimension.
|
47 |
+
|
48 |
+
hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
|
49 |
+
n_rep: amount of repetitions of kv_n_heads
|
50 |
+
Unlike torch.repeat_interleave, this function avoids allocating new memory.
|
51 |
+
"""
|
52 |
+
if n_rep == 1:
|
53 |
+
return hidden
|
54 |
+
b, s, kv_n_heads, d = hidden.shape
|
55 |
+
hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
|
56 |
+
return hidden.reshape(b, s, kv_n_heads * n_rep, d)
|
57 |
+
|
58 |
+
def scaled_multihead_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: int, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, attn_logit_softcapping: Optional[float]=None, sliding_window_size: int=-1) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
59 |
+
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
|
60 |
+
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
|
61 |
+
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
|
62 |
+
if past_key_value is not None:
|
63 |
+
if len(past_key_value) != 0:
|
64 |
+
k = torch.cat([past_key_value[0], k], dim=3)
|
65 |
+
v = torch.cat([past_key_value[1], v], dim=2)
|
66 |
+
past_key_value = (k, v)
|
67 |
+
b, _, s_q, d = q.shape
|
68 |
+
s_k = k.size(-1)
|
69 |
+
if kv_n_heads > 1 and kv_n_heads < n_heads:
|
70 |
+
k = repeat_kv_for_gqa(k.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
|
71 |
+
v = repeat_kv_for_gqa(v.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
|
72 |
+
if softmax_scale is None:
|
73 |
+
softmax_scale = 1 / math.sqrt(d)
|
74 |
+
attn_weight = q.matmul(k) * softmax_scale
|
75 |
+
if attn_logit_softcapping is not None:
|
76 |
+
attn_weight = attn_logit_softcapping * torch.tanh(attn_weight / attn_logit_softcapping)
|
77 |
+
if attn_bias is not None:
|
78 |
+
_s_q = max(0, attn_bias.size(2) - s_q)
|
79 |
+
_s_k = max(0, attn_bias.size(3) - s_k)
|
80 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
81 |
+
if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
|
82 |
+
raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
|
83 |
+
attn_weight = attn_weight + attn_bias
|
84 |
+
min_val = torch.finfo(q.dtype).min
|
85 |
+
if key_padding_mask is not None:
|
86 |
+
if attn_bias is not None:
|
87 |
+
warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
|
88 |
+
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
89 |
+
if is_causal and (not s_q == 1):
|
90 |
+
s = max(s_q, s_k)
|
91 |
+
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
|
92 |
+
causal_mask = causal_mask.tril()
|
93 |
+
causal_mask = causal_mask.to(torch.bool)
|
94 |
+
causal_mask = ~causal_mask
|
95 |
+
causal_mask = causal_mask[-s_q:, -s_k:]
|
96 |
+
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
|
97 |
+
if sliding_window_size != -1:
|
98 |
+
window_mask = torch.ones((s_q, s_k), dtype=torch.bool, device=attn_weight.device)
|
99 |
+
if not s_q == 1:
|
100 |
+
if s_q != s_k:
|
101 |
+
raise ValueError('Number of queries should be equal to the number of keys.')
|
102 |
+
window_mask = torch.tril(window_mask, diagonal=sliding_window_size)
|
103 |
+
window_mask = torch.triu(window_mask, diagonal=-sliding_window_size)
|
104 |
+
else:
|
105 |
+
window_mask[:, :-(sliding_window_size + 1)] = False
|
106 |
+
window_mask = ~window_mask
|
107 |
+
attn_weight = attn_weight.masked_fill(window_mask.view(1, 1, s_q, s_k), min_val)
|
108 |
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
109 |
+
if dropout_p:
|
110 |
+
attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
|
111 |
+
out = attn_weight.to(v.dtype).matmul(v)
|
112 |
+
out = rearrange(out, 'b h s d -> b s (h d)')
|
113 |
+
if needs_weights:
|
114 |
+
return (out, attn_weight, past_key_value)
|
115 |
+
return (out, None, past_key_value)
|
116 |
+
|
117 |
+
def check_valid_inputs(*tensors: torch.Tensor, valid_dtypes: Optional[list[torch.dtype]]=None):
|
118 |
+
if valid_dtypes is None:
|
119 |
+
valid_dtypes = [torch.float32, torch.float16, torch.bfloat16]
|
120 |
+
for tensor in tensors:
|
121 |
+
if tensor.dtype not in valid_dtypes:
|
122 |
+
raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
|
123 |
+
if not tensor.is_cuda:
|
124 |
+
raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
|
125 |
+
|
126 |
+
def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: int, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, should_repeat_kv_for_gqa: Optional[bool]=True, sliding_window_size: int=-1, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None, attn_logit_softcapping: Optional[float]=None) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
127 |
+
if key_padding_mask is not None:
|
128 |
+
raise ValueError('key_padding_mask should be None for flash attn.')
|
129 |
+
del key_padding_mask
|
130 |
+
if flash_attn_padding_info is None:
|
131 |
+
raise ValueError('flash_attn_padding_info is required for flash attn.')
|
132 |
+
try:
|
133 |
+
from flash_attn import bert_padding, flash_attn_interface
|
134 |
+
except:
|
135 |
+
raise RuntimeError('Please install flash-attn==1.0.9 or flash-attn==2.3.6')
|
136 |
+
check_valid_inputs(query, key, value)
|
137 |
+
if past_key_value is not None:
|
138 |
+
if len(past_key_value) != 0:
|
139 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
140 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
141 |
+
past_key_value = (key, value)
|
142 |
+
if attn_bias is not None:
|
143 |
+
raise NotImplementedError(f'attn_bias not implemented for flash attn.')
|
144 |
+
batch_size, seqlen = query.shape[:2]
|
145 |
+
indices_q = flash_attn_padding_info['indices_q'].to(query.device)
|
146 |
+
indices_k = flash_attn_padding_info['indices_k'].to(key.device)
|
147 |
+
indices_v = flash_attn_padding_info['indices_v'].to(value.device)
|
148 |
+
cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q'].to(query.device)
|
149 |
+
cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k'].to(key.device)
|
150 |
+
max_seqlen_q = flash_attn_padding_info['max_seqlen_q']
|
151 |
+
max_seqlen_k = flash_attn_padding_info['max_seqlen_k']
|
152 |
+
query_unpad = bert_padding.index_first_axis(rearrange(query, 'b s ... -> (b s) ...'), indices_q)
|
153 |
+
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
|
154 |
+
key_unpad = bert_padding.index_first_axis(rearrange(key, 'b s ... -> (b s) ...'), indices_k)
|
155 |
+
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
|
156 |
+
value_unpad = bert_padding.index_first_axis(rearrange(value, 'b s ... -> (b s) ...'), indices_v)
|
157 |
+
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
|
158 |
+
if kv_n_heads < n_heads and (not is_flash_v2_installed()) and (not should_repeat_kv_for_gqa):
|
159 |
+
raise ValueError('For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2.')
|
160 |
+
if should_repeat_kv_for_gqa:
|
161 |
+
if kv_n_heads == 1:
|
162 |
+
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
|
163 |
+
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
|
164 |
+
elif kv_n_heads < n_heads:
|
165 |
+
key_unpad = repeat_kv_for_gqa(key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1), n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
|
166 |
+
value_unpad = repeat_kv_for_gqa(value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1), n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1)
|
167 |
+
dropout_p = dropout_p if training else 0.0
|
168 |
+
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
169 |
+
if is_flash_v1_installed():
|
170 |
+
output_unpad = flash_attn_interface.flash_attn_unpadded_func(q=query_unpad, k=key_unpad, v=value_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
|
171 |
+
elif is_flash_v2_installed():
|
172 |
+
extra_attn_kwargs = {}
|
173 |
+
if check_alibi_support('flash'):
|
174 |
+
extra_attn_kwargs['alibi_slopes'] = alibi_slopes
|
175 |
+
elif alibi_slopes is not None:
|
176 |
+
raise ValueError('alibi_slopes is only supported for flash-attn>=2.4.2')
|
177 |
+
if is_flash_v2_installed(v2_version='v2.6.2') and attn_logit_softcapping is not None:
|
178 |
+
extra_attn_kwargs['softcap'] = attn_logit_softcapping
|
179 |
+
output_unpad = flash_attn_interface.flash_attn_varlen_func(q=query_unpad, k=key_unpad, v=value_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights, window_size=(sliding_window_size, sliding_window_size), **extra_attn_kwargs)
|
180 |
+
else:
|
181 |
+
raise RuntimeError('flash-attn==1.0.9 or flash-attn==2.4.2 is required.')
|
182 |
+
output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
|
183 |
+
return (output, None, past_key_value)
|
184 |
+
|
185 |
+
@attention_classes.register_class('grouped_query_attention')
|
186 |
+
class GroupedQueryAttention(nn.Module):
|
187 |
+
"""Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
|
188 |
+
|
189 |
+
and Multi-query attention (MQA).
|
190 |
+
|
191 |
+
This allows the user to set a variable of number of kv_n_heads, rather than
|
192 |
+
just n_heads or 1, as in MHA and MQA. Using torch attention implementation
|
193 |
+
enables user to also use additive bias. This class also supports
|
194 |
+
cross-attention with different `in_features` for key and value fc projections.
|
195 |
+
"""
|
196 |
+
|
197 |
+
def __init__(self, d_model: int, n_heads: int, kv_n_heads: int, attn_impl: str='flash', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, fused_qkv: bool=True, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', norm_eps: float=1e-05, fc_type: Optional[dict[str, Any]]=None, device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1, reuse_kv_layer_idx: Optional[int]=None, attn_logit_softcapping: Optional[float]=None, kv_dim: Optional[int]=None):
|
198 |
+
super().__init__()
|
199 |
+
self.attn_impl = attn_impl
|
200 |
+
self.clip_qkv = clip_qkv
|
201 |
+
self.qk_ln = qk_ln
|
202 |
+
self.qk_gn = qk_gn
|
203 |
+
self.fused_qkv = fused_qkv
|
204 |
+
self.d_model = d_model
|
205 |
+
self.n_heads = n_heads
|
206 |
+
self.kv_n_heads = kv_n_heads
|
207 |
+
self.sliding_window_size = sliding_window_size
|
208 |
+
self.reuse_kv_layer_idx = reuse_kv_layer_idx
|
209 |
+
self.attn_logit_softcapping = attn_logit_softcapping
|
210 |
+
self.kv_dim = kv_dim if kv_dim is not None else self.d_model
|
211 |
+
self.head_dim = d_model // n_heads
|
212 |
+
if fc_type is None:
|
213 |
+
fc_type = copy.deepcopy(fc_type_defaults)
|
214 |
+
fc_type['bias'] = bias
|
215 |
+
fc_type['device'] = device
|
216 |
+
fc_type_name = fc_type['name']
|
217 |
+
if self.kv_n_heads <= 0:
|
218 |
+
raise ValueError('kv_n_heads should be greater than zero.')
|
219 |
+
if self.kv_n_heads > self.n_heads:
|
220 |
+
raise ValueError('The number of KV heads should be less than or equal to Q heads.')
|
221 |
+
if self.n_heads % self.kv_n_heads != 0:
|
222 |
+
raise ValueError('Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads.')
|
223 |
+
if qk_ln and qk_gn:
|
224 |
+
raise ValueError('Only one of qk_ln and qk_gn can be set to True.')
|
225 |
+
self.softmax_scale = softmax_scale
|
226 |
+
if self.softmax_scale is None:
|
227 |
+
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
|
228 |
+
self.attn_dropout_p = attn_pdrop
|
229 |
+
if self.reuse_kv_layer_idx is not None:
|
230 |
+
self.Wq = build_fc(name=fc_type_name, in_features=self.d_model, out_features=self.d_model, fc_kwargs=fc_type)
|
231 |
+
fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
|
232 |
+
self.Wq._fused = (0, fuse_splits)
|
233 |
+
elif self.fused_qkv:
|
234 |
+
self.Wqkv = build_fc(name=fc_type_name, in_features=self.d_model, out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim, fc_kwargs=fc_type)
|
235 |
+
fuse_splits = [i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads)]
|
236 |
+
self.Wqkv._fused = (0, fuse_splits)
|
237 |
+
else:
|
238 |
+
self.Wq = build_fc(name=fc_type_name, in_features=self.d_model, out_features=self.d_model, fc_kwargs=fc_type)
|
239 |
+
self.Wk = build_fc(name=fc_type_name, in_features=self.kv_dim, out_features=self.kv_n_heads * self.head_dim, fc_kwargs=fc_type)
|
240 |
+
self.Wv = build_fc(name=fc_type_name, in_features=self.kv_dim, out_features=self.kv_n_heads * self.head_dim, fc_kwargs=fc_type)
|
241 |
+
q_fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)]
|
242 |
+
kv_fuse_splits = [i * self.head_dim for i in range(1, self.kv_n_heads)]
|
243 |
+
self.Wq._fused = (0, q_fuse_splits)
|
244 |
+
self.Wk._fused = (0, kv_fuse_splits)
|
245 |
+
self.Wv._fused = (0, kv_fuse_splits)
|
246 |
+
if self.qk_ln or self.qk_gn:
|
247 |
+
norm_size = self.head_dim if qk_gn else d_model
|
248 |
+
self.q_ln = build_norm(name=norm_type.lower(), normalized_shape=norm_size, eps=norm_eps, device=device)
|
249 |
+
if self.reuse_kv_layer_idx is None:
|
250 |
+
if qk_ln:
|
251 |
+
norm_size = self.head_dim * kv_n_heads
|
252 |
+
self.k_ln = build_norm(name=norm_type.lower(), normalized_shape=norm_size, eps=norm_eps, device=device)
|
253 |
+
self.attn_fn = attention_implementations.get(self.attn_impl)
|
254 |
+
self.out_proj = build_fc(name=fc_type_name, in_features=self.d_model, out_features=self.d_model, fc_kwargs=fc_type)
|
255 |
+
self.out_proj._is_residual = True
|
256 |
+
|
257 |
+
def forward(self, x: torch.Tensor, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, rotary_emb_w_meta_info: Optional[dict]=None, is_causal: bool=True, needs_weights: bool=False, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None, prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, key_value_states: Optional[torch.Tensor]=None) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
258 |
+
extra_kwargs = {}
|
259 |
+
if prev_layer_key_value is not None:
|
260 |
+
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
|
261 |
+
query, key, value = self.get_qkv(x=x, key_value_states=key_value_states, **extra_kwargs)
|
262 |
+
if rotary_emb_w_meta_info is not None:
|
263 |
+
query, key, value = self._apply_rotary_embeddings(rotary_emb_w_meta_info, query, key, value)
|
264 |
+
extra_attn_kwargs = self.get_implementation_specific_args(attention_mask, alibi_slopes, flash_attn_padding_info)
|
265 |
+
context, attn_weights, past_key_value = self.attn_fn(query, key, value, n_heads=self.n_heads, kv_n_heads=self.kv_n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, attn_logit_softcapping=self.attn_logit_softcapping, sliding_window_size=self.sliding_window_size, **extra_attn_kwargs)
|
266 |
+
return (self.out_proj(context), attn_weights, past_key_value)
|
267 |
+
|
268 |
+
def get_qkv(self, x: torch.Tensor, prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, key_value_states: Optional[torch.Tensor]=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
269 |
+
"""Computes and returns the query, key, and value tensors.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
x (torch.Tensor): The input query tensor.
|
273 |
+
prev_layer_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): The key value of the previous layer.
|
274 |
+
key_value_states (Optional[torch.Tensor]): The input tensor for keys and values.
|
275 |
+
|
276 |
+
Returns:
|
277 |
+
query (torch.Tensor): The query tensor.
|
278 |
+
key (torch.Tensor): The key tensor.
|
279 |
+
value (torch.Tensor): The value tensor.
|
280 |
+
"""
|
281 |
+
if self.reuse_kv_layer_idx is not None:
|
282 |
+
if prev_layer_key_value is None:
|
283 |
+
raise ValueError('prev_layer_key_value is None, cannot reuse_prev_layer_kv.')
|
284 |
+
key, value = prev_layer_key_value
|
285 |
+
if self.attn_impl == 'torch':
|
286 |
+
key = rearrange(key, 'b h d s -> b s (h d)')
|
287 |
+
value = rearrange(value, 'b h s d -> b s (h d)')
|
288 |
+
query = self.Wq(x)
|
289 |
+
if self.clip_qkv:
|
290 |
+
query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
291 |
+
if self.qk_ln or self.qk_gn:
|
292 |
+
q_shape = query.shape
|
293 |
+
if self.qk_gn:
|
294 |
+
b, s = query.shape[:2]
|
295 |
+
query = query.view(b, s, self.n_heads, -1)
|
296 |
+
dtype = query.dtype
|
297 |
+
query = self.q_ln(query).to(dtype).view(q_shape)
|
298 |
+
return (query, key, value)
|
299 |
+
if self.fused_qkv:
|
300 |
+
if key_value_states is not None:
|
301 |
+
raise ValueError('Cannot use separate hidden and key_value states when fused_qkv = True.')
|
302 |
+
qkv = self.Wqkv(x)
|
303 |
+
if self.clip_qkv:
|
304 |
+
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
305 |
+
query, key, value = qkv.split([self.d_model, self.kv_n_heads * self.head_dim, self.kv_n_heads * self.head_dim], dim=2)
|
306 |
+
else:
|
307 |
+
query = self.Wq(x)
|
308 |
+
if key_value_states is not None:
|
309 |
+
key = self.Wk(key_value_states)
|
310 |
+
value = self.Wv(key_value_states)
|
311 |
+
else:
|
312 |
+
key = self.Wk(x)
|
313 |
+
value = self.Wv(x)
|
314 |
+
if self.clip_qkv:
|
315 |
+
query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
316 |
+
key = key.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
317 |
+
value = value.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
318 |
+
if self.qk_ln or self.qk_gn:
|
319 |
+
q_shape, k_shape = (query.shape, key.shape)
|
320 |
+
if self.qk_gn:
|
321 |
+
b, s = query.shape[:2]
|
322 |
+
query = query.view(b, s, self.n_heads, -1)
|
323 |
+
key = key.view(b, s, self.kv_n_heads, -1)
|
324 |
+
dtype = query.dtype
|
325 |
+
query = self.q_ln(query).to(dtype).view(q_shape)
|
326 |
+
key = self.k_ln(key).to(dtype).view(k_shape)
|
327 |
+
return (query, key, value)
|
328 |
+
|
329 |
+
def _apply_rotary_embeddings(self, rotary_emb_w_meta_info: dict[str, Any], query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
330 |
+
if self.reuse_kv_layer_idx is not None:
|
331 |
+
orig_key, orig_value = (key, value)
|
332 |
+
key, value = (torch.empty_like(key), torch.empty_like(value))
|
333 |
+
rotary_emb = rotary_emb_w_meta_info['rotary_emb']
|
334 |
+
seq_len = rotary_emb_w_meta_info['seq_len']
|
335 |
+
offset_info = rotary_emb_w_meta_info['offset_info']
|
336 |
+
bsz, seqlen = query.shape[:2]
|
337 |
+
query = query.view(bsz, seqlen, -1, self.head_dim)
|
338 |
+
key = key.view(bsz, seqlen, -1, self.head_dim)
|
339 |
+
if rotary_emb_w_meta_info['impl'] == 'dail':
|
340 |
+
value = value.view(bsz, seqlen, -1, self.head_dim)
|
341 |
+
kv = torch.stack([key, value], dim=2)
|
342 |
+
query, kv = rotary_emb(query, kv, seqlen_offset=offset_info, max_seqlen=seq_len)
|
343 |
+
[key, value] = torch.unbind(kv, dim=2)
|
344 |
+
value = value.view(bsz, seqlen, -1)
|
345 |
+
elif rotary_emb_w_meta_info['impl'] == 'hf':
|
346 |
+
if is_transformers_version_gte('4.38'):
|
347 |
+
cos, sin = rotary_emb(x=value, position_ids=offset_info)
|
348 |
+
else:
|
349 |
+
cos, sin = rotary_emb(x=value, seq_len=seq_len)
|
350 |
+
if is_transformers_version_gte('4.38'):
|
351 |
+
cos = cos.to(query.device)
|
352 |
+
sin = sin.to(query.device)
|
353 |
+
query, key = apply_rotary_pos_emb(q=query, k=key, cos=cos, sin=sin, position_ids=None, unsqueeze_dim=2)
|
354 |
+
elif is_transformers_version_gte('4.36'):
|
355 |
+
query, key = apply_rotary_pos_emb(q=query, k=key, cos=cos, sin=sin, position_ids=offset_info, unsqueeze_dim=2)
|
356 |
+
else:
|
357 |
+
query = query.transpose(1, 2)
|
358 |
+
key = key.transpose(1, 2)
|
359 |
+
query, key = apply_rotary_pos_emb(q=query, k=key, cos=cos, sin=sin, position_ids=offset_info)
|
360 |
+
query = query.transpose(1, 2)
|
361 |
+
key = key.transpose(1, 2)
|
362 |
+
query = query.view(bsz, seqlen, -1)
|
363 |
+
key = key.view(bsz, seqlen, -1)
|
364 |
+
if self.reuse_kv_layer_idx is not None:
|
365 |
+
return (query, orig_key, orig_value)
|
366 |
+
return (query, key, value)
|
367 |
+
|
368 |
+
def get_implementation_specific_args(self, attention_mask: Optional[torch.Tensor]=None, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None) -> dict[str, Any]:
|
369 |
+
"""Returns attention implementation specific args.
|
370 |
+
|
371 |
+
Args:
|
372 |
+
attention_mask (Optional[torch.Tensor]): The attention mask.
|
373 |
+
alibi_slopes (Optional[torch.Tensor]): The alibi slopes.
|
374 |
+
flash_attn_padding_info (Optional[dict[str, torch.Tensor]]): The padding information, only required for flash attention.
|
375 |
+
|
376 |
+
Returns:
|
377 |
+
extra_attn_kwargs (dict[str, Any]): Implementation specific args.
|
378 |
+
"""
|
379 |
+
if self.attn_impl == 'flash':
|
380 |
+
extra_attn_kwargs = {'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'alibi_slopes': alibi_slopes, 'flash_attn_padding_info': flash_attn_padding_info, 'key_padding_mask': None}
|
381 |
+
else:
|
382 |
+
extra_attn_kwargs = {'key_padding_mask': attention_mask}
|
383 |
+
return extra_attn_kwargs
|
384 |
+
|
385 |
+
@attention_classes.register_class('multihead_attention')
|
386 |
+
class MultiheadAttention(GroupedQueryAttention):
|
387 |
+
"""Multi-head self attention.
|
388 |
+
|
389 |
+
Using torch attention implementation enables user to also use additive bias.
|
390 |
+
"""
|
391 |
+
|
392 |
+
def __init__(self, d_model: int, n_heads: int, attn_impl: str='flash', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, fused_qkv: bool=True, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', norm_eps: float=1e-05, fc_type: Optional[dict[str, Any]]=None, device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1, reuse_kv_layer_idx: Optional[int]=None, attn_logit_softcapping: Optional[float]=None, kv_dim: Optional[int]=None):
|
393 |
+
super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, fused_qkv=fused_qkv, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, norm_eps=norm_eps, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size, reuse_kv_layer_idx=reuse_kv_layer_idx, attn_logit_softcapping=attn_logit_softcapping, kv_dim=kv_dim)
|
394 |
+
|
395 |
+
@attention_classes.register_class('multiquery_attention')
|
396 |
+
class MultiQueryAttention(GroupedQueryAttention):
|
397 |
+
"""Multi-Query self attention.
|
398 |
+
|
399 |
+
Using torch attention implementation enables user to also use additive bias.
|
400 |
+
"""
|
401 |
+
|
402 |
+
def __init__(self, d_model: int, n_heads: int, attn_impl: str='flash', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, fused_qkv: bool=True, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', norm_eps: float=1e-05, fc_type: Optional[dict[str, Any]]=None, device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1, reuse_kv_layer_idx: Optional[int]=None, attn_logit_softcapping: Optional[float]=None, kv_dim: Optional[int]=None):
|
403 |
+
super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, fused_qkv=fused_qkv, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, norm_eps=norm_eps, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size, reuse_kv_layer_idx=reuse_kv_layer_idx, attn_logit_softcapping=attn_logit_softcapping, kv_dim=kv_dim)
|
404 |
+
|
405 |
+
def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, causal: bool, use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]:
|
406 |
+
if attn_impl == 'flash':
|
407 |
+
return None
|
408 |
+
elif attn_impl == 'torch':
|
409 |
+
if alibi:
|
410 |
+
if not causal or use_sequence_id:
|
411 |
+
return (1, n_heads, seq_len, seq_len)
|
412 |
+
return (1, n_heads, 1, seq_len)
|
413 |
+
elif use_sequence_id:
|
414 |
+
return (1, 1, seq_len, seq_len)
|
415 |
+
return None
|
416 |
+
else:
|
417 |
+
raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
|
418 |
+
|
419 |
+
def build_attn_bias(attn_impl: str, attn_bias: torch.Tensor, n_heads: int, seq_len: int, causal: bool=False, alibi: bool=False, alibi_bias_max: int=8) -> Optional[torch.Tensor]:
|
420 |
+
if attn_impl == 'flash':
|
421 |
+
return None
|
422 |
+
elif attn_impl == 'torch':
|
423 |
+
if alibi:
|
424 |
+
device, dtype = (attn_bias.device, attn_bias.dtype)
|
425 |
+
attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
|
426 |
+
return attn_bias
|
427 |
+
else:
|
428 |
+
raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
|
429 |
+
|
430 |
+
def gen_slopes(n_heads: int, alibi_bias_max: int=8, device: Optional[torch.device]=None, return_1d: bool=False) -> torch.Tensor:
|
431 |
+
_n_heads = 2 ** math.ceil(math.log2(n_heads))
|
432 |
+
m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
|
433 |
+
m = m.mul(alibi_bias_max / _n_heads)
|
434 |
+
slopes = 1.0 / torch.pow(2, m)
|
435 |
+
if _n_heads != n_heads:
|
436 |
+
slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
|
437 |
+
if return_1d:
|
438 |
+
return slopes
|
439 |
+
return slopes.view(1, n_heads, 1, 1)
|
440 |
+
|
441 |
+
def build_alibi_bias(n_heads: int, seq_len: int, full: bool=False, alibi_bias_max: int=8, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None) -> torch.Tensor:
|
442 |
+
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
|
443 |
+
if full:
|
444 |
+
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
|
445 |
+
alibi_bias = alibi_bias.abs().mul(-1)
|
446 |
+
slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
|
447 |
+
alibi_bias = alibi_bias * slopes
|
448 |
+
return alibi_bias.to(dtype=dtype)
|
449 |
+
attention_implementations.register('flash', func=flash_attn_fn)
|
450 |
+
attention_implementations.register('torch', func=scaled_multihead_dot_product_attention)
|
blocks.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""GPT Blocks used for the GPT Model."""
|
2 |
+
import copy
|
3 |
+
from typing import Any, Optional
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from .layers_registry import ffns_with_norm
|
7 |
+
from .layer_builders import build_attention_layer, build_ffn, build_norm
|
8 |
+
from .config_defaults import attn_config_defaults, fc_type_defaults
|
9 |
+
try:
|
10 |
+
from flash_attn.bert_padding import unpad_input, pad_input
|
11 |
+
except:
|
12 |
+
unpad_input, pad_input = (None, None)
|
13 |
+
|
14 |
+
class MPTBlock(nn.Module):
|
15 |
+
|
16 |
+
def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Optional[dict]=None, ffn_config: Optional[dict]=None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', norm_eps: float=1e-05, fc_type: Optional[dict[str, Any]]=None, device: Optional[str]=None, no_bias: bool=False, use_pad_tok_in_ffn: bool=True, **kwargs: Any):
|
17 |
+
if attn_config is None:
|
18 |
+
attn_config = attn_config_defaults
|
19 |
+
if ffn_config is None:
|
20 |
+
self.ffn_config: dict[str, Any] = {'ffn_type': 'mptmlp'}
|
21 |
+
else:
|
22 |
+
self.ffn_config = ffn_config
|
23 |
+
if fc_type is None:
|
24 |
+
fc_type = copy.deepcopy(fc_type_defaults)
|
25 |
+
fc_type['bias'] = not no_bias
|
26 |
+
fc_type['device'] = device
|
27 |
+
self.ffn_config['fc_type'] = fc_type
|
28 |
+
self.fuse_norm_attn_norm = kwargs.get('fuse_norm_attn_norm', False)
|
29 |
+
del kwargs
|
30 |
+
super().__init__()
|
31 |
+
ffn_type = self.ffn_config['ffn_type']
|
32 |
+
ffn_has_norm = ffn_type in ffns_with_norm
|
33 |
+
if self.fuse_norm_attn_norm:
|
34 |
+
self.norm_attn_norm = FusedNormAttentionNorm(d_model=d_model, n_heads=n_heads, args_to_exclude_in_attn_class=self.args_to_exclude_in_attn_class, attn_config=attn_config, ffn_has_norm=ffn_has_norm, fc_type=fc_type, resid_pdrop=resid_pdrop, norm_type=norm_type, norm_eps=norm_eps, device=device, no_bias=no_bias)
|
35 |
+
else:
|
36 |
+
assert isinstance(attn_config['attn_type'], str)
|
37 |
+
attn_config_subset_for_attn_class = {k: v for k, v in attn_config.items() if k not in self.args_to_exclude_in_attn_class}
|
38 |
+
self.norm_1 = build_norm(name=norm_type.lower(), normalized_shape=d_model, eps=norm_eps, device=device)
|
39 |
+
self.attn = build_attention_layer(name=attn_config['attn_type'], attn_kwargs={'d_model': d_model, 'n_heads': n_heads, 'fc_type': fc_type, 'device': device, 'bias': not no_bias, **attn_config_subset_for_attn_class})
|
40 |
+
self.norm_2 = None
|
41 |
+
if not ffn_has_norm:
|
42 |
+
self.norm_2 = build_norm(name=norm_type.lower(), normalized_shape=d_model, eps=norm_eps, device=device)
|
43 |
+
self.ffn = build_ffn(name=ffn_type, d_model=d_model, expansion_ratio=expansion_ratio, device=device, bias=not no_bias, ffn_kwargs=self.ffn_config)
|
44 |
+
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
|
45 |
+
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
|
46 |
+
self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
|
47 |
+
|
48 |
+
@property
|
49 |
+
def args_to_exclude_in_attn_class(self):
|
50 |
+
return {'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl', 'rope_dail_config', 'rope_hf_config'}
|
51 |
+
|
52 |
+
def forward(self, x: torch.Tensor, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, rotary_emb_w_meta_info: Optional[dict]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True, output_attentions: bool=False, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None, prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, key_value_states: Optional[torch.Tensor]=None) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
53 |
+
extra_kwargs = {}
|
54 |
+
if prev_layer_key_value is not None:
|
55 |
+
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
|
56 |
+
if key_value_states is not None:
|
57 |
+
extra_kwargs['key_value_states'] = key_value_states
|
58 |
+
if self.fuse_norm_attn_norm:
|
59 |
+
x, m, attn_weights, past_key_value = self.norm_attn_norm(x, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=is_causal, output_attentions=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, **extra_kwargs)
|
60 |
+
else:
|
61 |
+
a = self.norm_1(x)
|
62 |
+
b, attn_weights, past_key_value = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, **extra_kwargs)
|
63 |
+
x = x + self.resid_attn_dropout(b)
|
64 |
+
m = x
|
65 |
+
if self.norm_2 is not None:
|
66 |
+
m = self.norm_2(x)
|
67 |
+
n = self.apply_ffn(attention_mask, m)
|
68 |
+
x = x.to(device=n.device) + self.resid_ffn_dropout(n).to(device=n.device)
|
69 |
+
return (x, attn_weights, past_key_value)
|
70 |
+
|
71 |
+
def apply_ffn(self, attention_mask: Optional[torch.ByteTensor], m: torch.Tensor) -> torch.Tensor:
|
72 |
+
"""Apply feed forward layers to the input.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
attention_mask (Optional[torch.ByteTensor]): The attention mask.
|
76 |
+
m (torch.Tensor): The input.
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
n (torch.Tensor): The output.
|
80 |
+
"""
|
81 |
+
batch_size, seq_len = m.size()[:2]
|
82 |
+
indices = None
|
83 |
+
if not self.use_pad_tok_in_ffn and attention_mask is not None:
|
84 |
+
assert unpad_input is not None
|
85 |
+
attention_mask = self.slice_attention_mask(attention_mask, seq_len)
|
86 |
+
m, indices, *_ = unpad_input(m, attention_mask)
|
87 |
+
n = self.ffn(m)
|
88 |
+
if not self.use_pad_tok_in_ffn and attention_mask is not None:
|
89 |
+
assert pad_input is not None
|
90 |
+
n = pad_input(n, indices, batch_size, seq_len)
|
91 |
+
return n
|
92 |
+
|
93 |
+
def slice_attention_mask(self, attention_mask: torch.ByteTensor, seq_len: int) -> torch.ByteTensor:
|
94 |
+
"""Slice attention mask to the correct size.
|
95 |
+
|
96 |
+
Can be overridden by subclasses to apply different slicing logic.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
attention_mask (torch.ByteTensor): The attention mask.
|
100 |
+
seq_len (int): The sequence length.
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
torch.ByteTensor: The sliced attention mask.
|
104 |
+
"""
|
105 |
+
return attention_mask
|
106 |
+
|
107 |
+
class FusedNormAttentionNorm(nn.Module):
|
108 |
+
|
109 |
+
def __init__(self, d_model: int, n_heads: int, args_to_exclude_in_attn_class: set[str], attn_config: Optional[dict]=None, ffn_has_norm: bool=False, fc_type: Optional[dict[str, Any]]=None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', norm_eps: float=1e-05, device: Optional[str]=None, no_bias: bool=False, **kwargs: Any):
|
110 |
+
super().__init__()
|
111 |
+
assert attn_config is not None
|
112 |
+
assert isinstance(attn_config['attn_type'], str)
|
113 |
+
if fc_type is None:
|
114 |
+
fc_type = copy.deepcopy(fc_type_defaults)
|
115 |
+
fc_type['bias'] = not no_bias
|
116 |
+
fc_type['device'] = device
|
117 |
+
attn_config_subset_for_attn_class = {k: v for k, v in attn_config.items() if k not in args_to_exclude_in_attn_class}
|
118 |
+
self.norm_1 = build_norm(name=norm_type.lower(), normalized_shape=d_model, eps=norm_eps, device=device)
|
119 |
+
self.attn = build_attention_layer(name=attn_config['attn_type'], attn_kwargs={'d_model': d_model, 'n_heads': n_heads, 'fc_type': fc_type, 'device': device, 'bias': not no_bias, **attn_config_subset_for_attn_class})
|
120 |
+
self.norm_2 = None
|
121 |
+
if not ffn_has_norm:
|
122 |
+
self.norm_2 = build_norm(name=norm_type.lower(), normalized_shape=d_model, eps=norm_eps, device=device)
|
123 |
+
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
|
124 |
+
|
125 |
+
def forward(self, x: torch.Tensor, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, rotary_emb_w_meta_info: Optional[dict]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True, output_attentions: bool=False, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None, prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, key_value_states: Optional[torch.Tensor]=None) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
126 |
+
a = self.norm_1(x)
|
127 |
+
extra_kwargs = {}
|
128 |
+
if prev_layer_key_value is not None:
|
129 |
+
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
|
130 |
+
if key_value_states is not None:
|
131 |
+
extra_kwargs['key_value_states'] = key_value_states
|
132 |
+
b, attn_weights, past_key_value = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, **extra_kwargs)
|
133 |
+
x = x + self.resid_attn_dropout(b)
|
134 |
+
m = x
|
135 |
+
if self.norm_2 is not None:
|
136 |
+
m = self.norm_2(x)
|
137 |
+
return (x, m, attn_weights, past_key_value)
|
config.json
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"MPTForCausalLM"
|
4 |
+
],
|
5 |
+
"attn_config": {
|
6 |
+
"alibi": false,
|
7 |
+
"alibi_bias_max": 8,
|
8 |
+
"attn_impl": "torch",
|
9 |
+
"attn_logit_softcapping": null,
|
10 |
+
"attn_pdrop": 0.0,
|
11 |
+
"attn_type": "grouped_query_attention",
|
12 |
+
"attn_uses_sequence_id": true,
|
13 |
+
"clip_qkv": null,
|
14 |
+
"fused_qkv": true,
|
15 |
+
"kv_dim": null,
|
16 |
+
"kv_n_heads": 8,
|
17 |
+
"qk_gn": false,
|
18 |
+
"qk_ln": false,
|
19 |
+
"rope": true,
|
20 |
+
"rope_dail_config": {
|
21 |
+
"pos_idx_in_fp32": true,
|
22 |
+
"type": "original",
|
23 |
+
"xpos_scale_base": 512
|
24 |
+
},
|
25 |
+
"rope_hf_config": {
|
26 |
+
"factor": 1.0,
|
27 |
+
"type": "no_scaling"
|
28 |
+
},
|
29 |
+
"rope_impl": "dail",
|
30 |
+
"rope_theta": 50000,
|
31 |
+
"sliding_window_size": -1,
|
32 |
+
"softmax_scale": null
|
33 |
+
},
|
34 |
+
"auto_map": {
|
35 |
+
"AutoConfig": "configuration_mpt.MPTConfig",
|
36 |
+
"AutoModelForCausalLM": "modeling_mpt.MPTForCausalLM"
|
37 |
+
},
|
38 |
+
"block_overrides": null,
|
39 |
+
"d_model": 3072,
|
40 |
+
"emb_pdrop": 0.0,
|
41 |
+
"embedding_fraction": 1.0,
|
42 |
+
"expansion_ratio": 4,
|
43 |
+
"fc_type": {
|
44 |
+
"name": "torch"
|
45 |
+
},
|
46 |
+
"ffn_config": {
|
47 |
+
"fc_type": {
|
48 |
+
"name": "torch"
|
49 |
+
},
|
50 |
+
"ffn_act_fn": {
|
51 |
+
"name": "silu"
|
52 |
+
},
|
53 |
+
"ffn_type": "mptmlp"
|
54 |
+
},
|
55 |
+
"ffn_hidden_size": 8192,
|
56 |
+
"final_logit_softcapping": null,
|
57 |
+
"init_config": {
|
58 |
+
"emb_init_std": null,
|
59 |
+
"emb_init_uniform_lim": null,
|
60 |
+
"fan_mode": "fan_in",
|
61 |
+
"init_div_is_residual": true,
|
62 |
+
"init_gain": 0.0,
|
63 |
+
"init_nonlinearity": "relu",
|
64 |
+
"init_std": null,
|
65 |
+
"name": "kaiming_normal_"
|
66 |
+
},
|
67 |
+
"init_device": "cpu",
|
68 |
+
"layer_norm_epsilon": 1e-05,
|
69 |
+
"learned_pos_emb": false,
|
70 |
+
"logit_scale": null,
|
71 |
+
"max_seq_len": 4096,
|
72 |
+
"model_type": "mpt",
|
73 |
+
"n_heads": 24,
|
74 |
+
"n_layers": 28,
|
75 |
+
"no_bias": false,
|
76 |
+
"norm_eps": 1e-05,
|
77 |
+
"norm_type": "rmsnorm",
|
78 |
+
"resid_pdrop": 0.0,
|
79 |
+
"torch_dtype": "bfloat16",
|
80 |
+
"transformers_version": "4.51.3",
|
81 |
+
"use_cache": false,
|
82 |
+
"use_pad_tok_in_ffn": true,
|
83 |
+
"vocab_size": 152000
|
84 |
+
}
|
config_defaults.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Defaults for MPT model component configs."""
|
2 |
+
ffn_config_defaults: dict = {'ffn_type': 'mptmlp'}
|
3 |
+
attn_config_defaults: dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'flash', 'qk_ln': False, 'qk_gn': False, 'fused_qkv': True, 'clip_qkv': None, 'softmax_scale': None, 'attn_uses_sequence_id': False, 'sliding_window_size': -1, 'attn_logit_softcapping': None, 'alibi': False, 'alibi_bias_max': 8, 'rope': False, 'rope_theta': 10000, 'rope_impl': 'dail', 'rope_dail_config': {'type': 'original', 'pos_idx_in_fp32': True, 'xpos_scale_base': 512}, 'rope_hf_config': {'type': 'no_scaling', 'factor': 1.0}, 'kv_dim': None}
|
4 |
+
init_config_defaults: dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
|
5 |
+
fc_type_defaults: dict = {'name': 'torch'}
|
config_moe_args.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Helper function to configure MPT with MoEs."""
|
2 |
+
import inspect
|
3 |
+
from typing import Callable, Optional, Union
|
4 |
+
import torch
|
5 |
+
from packaging import version
|
6 |
+
from torch import distributed
|
7 |
+
from torch.distributed._tensor import DeviceMesh
|
8 |
+
from .layers_registry import ffns_with_megablocks
|
9 |
+
from .ffn import resolve_ffn_hidden_size
|
10 |
+
|
11 |
+
def create_process_group_ranks(ranks: tuple[int, ...]):
|
12 |
+
"""Creates a new distributed group.
|
13 |
+
|
14 |
+
Used in create_set_process_group and create_mod_process_group methods below.
|
15 |
+
|
16 |
+
This function is an alternative to `distributed.new_group(ranks)`.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
ranks (tuple[int, ...]): Tuple of ranks of group members.
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
A handle of distributed group that can be given to collective calls.
|
23 |
+
"""
|
24 |
+
ranks_gather_list = [None for _ in range(distributed.get_world_size())]
|
25 |
+
distributed.all_gather_object(ranks_gather_list, ranks)
|
26 |
+
ranks_per_subgroup = list(set(ranks_gather_list))
|
27 |
+
group, _ = distributed.distributed_c10d.new_subgroups_by_enumeration(ranks_per_subgroup)
|
28 |
+
return group
|
29 |
+
|
30 |
+
def create_set_process_group(k: int):
|
31 |
+
"""Creates a new distributed group using sets of k GPUs.
|
32 |
+
|
33 |
+
For example, if you have 16 GPUs and input k=4, the resulting process groups
|
34 |
+
will have ranks:
|
35 |
+
process group 0 ranks: [ 0, 1, 2, 3]
|
36 |
+
process group 1 ranks: [ 4, 5, 6, 7]
|
37 |
+
process group 2 ranks: [ 8, 9, 10, 11]
|
38 |
+
process group 3 ranks: [12, 13, 14, 15]
|
39 |
+
|
40 |
+
Args:
|
41 |
+
k (int): Number of GPUs to use in set size.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
A handle of distributed group that can be given to collective calls.
|
45 |
+
"""
|
46 |
+
world_size = distributed.get_world_size()
|
47 |
+
if world_size % k != 0:
|
48 |
+
raise RuntimeError(f'world_size={world_size!r} must be divisible by k={k!r}.')
|
49 |
+
start = distributed.get_rank() // k * k
|
50 |
+
ranks = tuple(range(start, start + k))
|
51 |
+
return create_process_group_ranks(ranks)
|
52 |
+
|
53 |
+
def get_megablocks_device_mesh(device_mesh_cfg: Optional[tuple[int, ...]], moe_world_size: int, world_size: int) -> DeviceMesh:
|
54 |
+
"""Helper function to get the device mesh for MegaBlocks MoE.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
device_mesh_cfg (Optional[tuple[int, ...]]): The device mesh configuration specification.
|
58 |
+
moe_world_size (int): The MoE world size.
|
59 |
+
world_size (int): The world size.
|
60 |
+
|
61 |
+
Raises:
|
62 |
+
ValueError: If the device mesh configuration is not valid.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
The device mesh for MegaBlocks MoE.
|
66 |
+
"""
|
67 |
+
from torch.distributed._tensor.device_mesh import init_device_mesh
|
68 |
+
if device_mesh_cfg is None or len(device_mesh_cfg) == 1:
|
69 |
+
if device_mesh_cfg is not None:
|
70 |
+
world_size = device_mesh_cfg[0]
|
71 |
+
sharding_group_dim = world_size // moe_world_size
|
72 |
+
device_mesh = init_device_mesh('cuda', (sharding_group_dim, moe_world_size), mesh_dim_names=('weight_parallel', 'expert_parallel'))
|
73 |
+
else:
|
74 |
+
raise ValueError(f'device_mesh_cfg={device_mesh_cfg!r} must be length 1')
|
75 |
+
return device_mesh
|
76 |
+
|
77 |
+
def config_megablocks_moe_args(ffn_config: dict, d_model: int, expansion_ratio: Union[int, float], n_layers: int, get_device_mesh: Callable) -> dict:
|
78 |
+
"""Configures `ffn_config` for MegaBlocks MoE.
|
79 |
+
|
80 |
+
We prepare all necessary arguments for `megablocks.layers.arguments.Arguments` so that process
|
81 |
+
groups can be initialized and shared across all blocks in the network.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
ffn_config (dict): FFN configuration before the MegaBlocks MoE is configured.
|
85 |
+
d_model (int): Hidden size of the network.
|
86 |
+
expansion_ratio (Union[int, float]): Expansion ratio in FFN.
|
87 |
+
n_layers (int): Number of blocks used in the network.
|
88 |
+
get_device_mesh (Callable): Function to get the device mesh. Takes in the device mesh config and the MoE world size.
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
ffn_config (dict): FFN configuration with MegaBlocks MoE configured.
|
92 |
+
"""
|
93 |
+
try:
|
94 |
+
import megablocks
|
95 |
+
except:
|
96 |
+
raise RuntimeError('Requirements for MegaBlocks not installed; see install instructions in `README.md`.')
|
97 |
+
ffn_config.setdefault('fp16', False)
|
98 |
+
ffn_config.setdefault('bf16', False)
|
99 |
+
ffn_config['num_layers'] = n_layers
|
100 |
+
ffn_type = ffn_config.pop('ffn_type')
|
101 |
+
fc_type = ffn_config.pop('fc_type')
|
102 |
+
ffn_act_fn = ffn_config.pop('ffn_act_fn', None)
|
103 |
+
world_size = 1
|
104 |
+
moe_world_size = ffn_config.pop('moe_world_size')
|
105 |
+
device_mesh = None
|
106 |
+
device_mesh_cfg = ffn_config.pop('device_mesh', None)
|
107 |
+
if moe_world_size > 1:
|
108 |
+
if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.2.0'):
|
109 |
+
raise RuntimeError('MoE world size > 1 is not supported in torch version {torch.__version__}<2.2.')
|
110 |
+
world_size = distributed.get_world_size()
|
111 |
+
if world_size < moe_world_size or world_size % moe_world_size:
|
112 |
+
raise ValueError(f'Invalid world size configuration: world_size={world_size!r} and moe_world_size={moe_world_size!r}')
|
113 |
+
device_mesh = get_device_mesh(device_mesh_cfg=device_mesh_cfg, moe_world_size=moe_world_size, world_size=world_size)
|
114 |
+
ffn_config['moe_expert_model_parallelism'] = True
|
115 |
+
ffn_config['expert_parallel_group'] = device_mesh['expert_parallel'].get_group(0)
|
116 |
+
lbl_process_group = ffn_config.get('lbl_process_group', None)
|
117 |
+
if lbl_process_group is not None:
|
118 |
+
if lbl_process_group == 'expert_group':
|
119 |
+
lbl_process_group = ffn_config['expert_parallel_group']
|
120 |
+
elif lbl_process_group == 'global_group':
|
121 |
+
lbl_process_group = distributed.group.WORLD
|
122 |
+
elif isinstance(lbl_process_group, int):
|
123 |
+
if lbl_process_group > 1:
|
124 |
+
lbl_process_group = create_set_process_group(lbl_process_group)
|
125 |
+
else:
|
126 |
+
lbl_process_group = None
|
127 |
+
elif not isinstance(lbl_process_group, distributed.ProcessGroup):
|
128 |
+
raise ValueError(f'Unknown lbl_process_group={lbl_process_group!r}. Options are: none | a process group | ``expert_group`` | ``global_group`` | <GROUP_SIZE>.')
|
129 |
+
ffn_config['lbl_process_group'] = lbl_process_group
|
130 |
+
ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio)
|
131 |
+
ffn_config.setdefault('ffn_hidden_size', ffn_hidden_size)
|
132 |
+
args_to_keep_in_ffn_config = inspect.signature(megablocks.layers.arguments.Arguments).parameters
|
133 |
+
ffn_config = {k: v for k, v in ffn_config.items() if k in args_to_keep_in_ffn_config}
|
134 |
+
args = megablocks.layers.arguments.Arguments(hidden_size=d_model, **ffn_config)
|
135 |
+
ffn_config['args'] = args
|
136 |
+
ffn_config['device_mesh'] = device_mesh
|
137 |
+
ffn_config['moe_world_size'] = moe_world_size
|
138 |
+
ffn_config['ffn_type'] = ffn_type
|
139 |
+
ffn_config['fc_type'] = fc_type
|
140 |
+
ffn_config['ffn_act_fn'] = ffn_act_fn
|
141 |
+
return ffn_config
|
142 |
+
|
143 |
+
def config_moe_args(ffn_config: dict, d_model: int, expansion_ratio: Union[int, float], n_layers: int) -> dict:
|
144 |
+
"""Configures `ffn_config` for MoE.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
ffn_config (dict): FFN configuration before the MoE is configured.
|
148 |
+
d_model (int): Hidden size of the network.
|
149 |
+
expansion_ratio (int, float): Expansion ratio in FFN.
|
150 |
+
n_layers (int): Number of blocks used in the network.
|
151 |
+
|
152 |
+
Returns:
|
153 |
+
ffn_config (dict): FFN configuration with MoE configured.
|
154 |
+
"""
|
155 |
+
if ffn_config['ffn_type'] in ffns_with_megablocks:
|
156 |
+
return config_megablocks_moe_args(ffn_config=ffn_config, d_model=d_model, expansion_ratio=expansion_ratio, n_layers=n_layers, get_device_mesh=get_megablocks_device_mesh)
|
157 |
+
else:
|
158 |
+
ffn_type = ffn_config['ffn_type']
|
159 |
+
raise ValueError(f'Invalid ffn_type ({ffn_type}).')
|
configuration_mpt.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A HuggingFace-style model configuration."""
|
2 |
+
import copy
|
3 |
+
from .dmoe import _UniformExpertAssignment
|
4 |
+
from .blocks import MPTBlock
|
5 |
+
from .fc import *
|
6 |
+
from .ffn import quickgelu_activation
|
7 |
+
from .config_moe_args import create_process_group_ranks
|
8 |
+
from .layer_builders import build_norm
|
9 |
+
from .mpt_param_count import module_n_params
|
10 |
+
from .norm import _cast_if_autocast_enabled
|
11 |
+
from .act_ckpt import pass_on_block_idx
|
12 |
+
from .custom_embedding import SharedEmbedding
|
13 |
+
from .registry_utils import TypedRegistry
|
14 |
+
from .param_init_fns import torch_default_param_init_fn_
|
15 |
+
import warnings
|
16 |
+
from typing import Any, Optional, Union
|
17 |
+
from transformers import PretrainedConfig
|
18 |
+
from .layers_registry import ffns_with_megablocks
|
19 |
+
from .attention import check_alibi_support, is_flash_v2_installed
|
20 |
+
from .config_defaults import attn_config_defaults, fc_type_defaults, ffn_config_defaults, init_config_defaults
|
21 |
+
from .warnings import ExperimentalWarning
|
22 |
+
|
23 |
+
class MPTConfig(PretrainedConfig):
|
24 |
+
model_type = 'mpt'
|
25 |
+
|
26 |
+
def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: Union[int, float]=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Optional[dict]=None, ffn_config: Optional[dict]=None, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', norm_eps: float=1e-05, use_cache: bool=False, init_config: Optional[dict]=None, fc_type: Union[str, dict]='torch', tie_word_embeddings: bool=True, use_pad_tok_in_ffn: bool=True, block_overrides: Optional[dict[str, Any]]=None, final_logit_softcapping: Optional[float]=None, **kwargs: Any):
|
27 |
+
"""The MPT configuration class.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
d_model (int): The size of the embedding dimension of the model.
|
31 |
+
n_heads (int): The number of attention heads.
|
32 |
+
n_layers (int): The number of layers in the model.
|
33 |
+
expansion_ratio (Union[int, float]): The ratio of the up/down scale in the ffn.
|
34 |
+
max_seq_len (int): The maximum sequence length of the model.
|
35 |
+
vocab_size (int): The size of the vocabulary.
|
36 |
+
resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
|
37 |
+
emb_pdrop (float): The dropout probability for the embedding layer.
|
38 |
+
learned_pos_emb (bool): Whether to use learned positional embeddings
|
39 |
+
attn_config (Dict): A dictionary used to configure the model's attention module:
|
40 |
+
attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
|
41 |
+
attn_pdrop (float): The dropout probability for the attention layers.
|
42 |
+
attn_impl (str): The attention implementation to use. One of 'torch' or 'flash'.
|
43 |
+
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
|
44 |
+
qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
|
45 |
+
fused_qkv (bool): Whether to fuse the Wq, Wk, and Wv weight matrices in the attention layer. If True, the weights are fused into a single
|
46 |
+
Wqkv matrix, which can be faster for matmuls. If False, the weights are kept separate. Defaults to True.
|
47 |
+
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
|
48 |
+
this value.
|
49 |
+
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
|
50 |
+
use the default scale of ``1/sqrt(d_keys)``.
|
51 |
+
attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
|
52 |
+
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
|
53 |
+
which sub-sequence each token belongs to.
|
54 |
+
Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
|
55 |
+
sliding_window_size (int): Window size for sliding window local attention. Defaults to -1, which means no sliding window. Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size, i + seqlen_k - seqlen_q + window_size] inclusive. Only works for flash attention v2.3.0 or higher.
|
56 |
+
alibi (bool): Whether to use the alibi bias instead of position embeddings.
|
57 |
+
alibi_bias_max (int): The maximum value of the alibi bias.
|
58 |
+
rope (bool): Whether to use rotary positional embeddings.
|
59 |
+
rope_theta (int): The base frequency for rope.
|
60 |
+
rope_impl (str): The implementation of rope to use. One of 'hf' (to use the implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) or 'dail' (to use the implementation from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py).
|
61 |
+
rope_dail_config (Dict): The configuration for the dail implementation of rope.
|
62 |
+
type (str): The type of rotary position embedding to use. Options: 'original' (for https://arxiv.org/pdf/2104.09864.pdf), 'xpos' (for https://arxiv.org/pdf/2212.10554.pdf).
|
63 |
+
pos_idx_in_fp32 (bool): If True, the position indices [0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. A consequence could be, for example, that bf16 rounds position 1995 to 2000, which leads to them having the same positional embedding.
|
64 |
+
xpos_scale_base (float): The scale base for XPos (if using XPos).
|
65 |
+
rope_hf_config (Dict): A dictionary used to configure rope's scaling behavior (when scaling beyond the training length).
|
66 |
+
type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla.
|
67 |
+
factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
|
68 |
+
kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
|
69 |
+
kv_dim (Optional[int]): For cross-attention only, allow user to specify different input dimensions for kv projections.
|
70 |
+
ffn_config (Dict): A dictionary used to configure the model's ffn module:
|
71 |
+
ffn_type (str): type of ffn to use. Options: mptmlp, mptglu, te_ln_mlp
|
72 |
+
init_device (str): The device to use for parameter initialization.
|
73 |
+
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
|
74 |
+
no_bias (bool): Whether to use bias in all layers.
|
75 |
+
embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
|
76 |
+
norm_type (str): choose type of norm to use
|
77 |
+
norm_eps (float): epsilon value for norm layer
|
78 |
+
use_cache (bool): Whether or not the model should return the last key/values attentions
|
79 |
+
init_config (Dict): A dictionary used to configure the model initialization:
|
80 |
+
init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
|
81 |
+
'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
|
82 |
+
'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
|
83 |
+
init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
|
84 |
+
emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
|
85 |
+
emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
|
86 |
+
used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
|
87 |
+
init_std (float): The standard deviation of the normal distribution used to initialize the model,
|
88 |
+
if using the baseline_ parameter initialization scheme.
|
89 |
+
init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
|
90 |
+
fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
|
91 |
+
init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
|
92 |
+
---
|
93 |
+
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
|
94 |
+
fc_type (str | Dict): Choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs. Can
|
95 |
+
also be a dictionary that specifies the fc layer name and any kwargs for the fc layer.
|
96 |
+
tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
|
97 |
+
use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks.
|
98 |
+
block_overrides: This allows for overriding default block configs for certain layers. This must contain `overrides` and `order`. `order` is a nested list which describes the order of the layers. For each kind of layer, specify the `overrides` in the overrides config (default refers to a layer that does not apply any overrides).
|
99 |
+
To specify this model (https://research.character.ai/optimizing-inference/) , the following config will be needed:
|
100 |
+
block_overrides:
|
101 |
+
order:
|
102 |
+
- name: default
|
103 |
+
- repeat: 2
|
104 |
+
order:
|
105 |
+
- name: sliding_window_layer
|
106 |
+
- name: sliding_window_layer_reuse
|
107 |
+
- name: sliding_window_layer
|
108 |
+
- repeat: 2
|
109 |
+
name: sliding_window_layer_reuse
|
110 |
+
- name: reuse_kv_layer
|
111 |
+
overrides:
|
112 |
+
sliding_window_layer:
|
113 |
+
attn_config:
|
114 |
+
sliding_window_size: 1024
|
115 |
+
sliding_window_layer_reuse:
|
116 |
+
attn_config:
|
117 |
+
sliding_window_size: 1024
|
118 |
+
reuse_kv_layer_idx: -1 # Relative index of the layer whose kv cache to reuse
|
119 |
+
reuse_kv_layer:
|
120 |
+
attn_config:
|
121 |
+
reuse_kv_layer_idx: -6 # Relative index of the layer whose kv cache to reuse
|
122 |
+
final_logit_softcapping (float | None): Softcapping threshold for final logit. Set to None to disable (default value None). Please see https://arxiv.org/pdf/2403.08295 for more details.
|
123 |
+
kwargs (Any): Other relevant keyword arguments.
|
124 |
+
"""
|
125 |
+
self.d_model = d_model
|
126 |
+
self.n_heads = n_heads
|
127 |
+
self.n_layers = n_layers
|
128 |
+
self.expansion_ratio = expansion_ratio
|
129 |
+
if max_seq_len != int(max_seq_len):
|
130 |
+
raise ValueError('max_seq_len must be an integer')
|
131 |
+
self.max_seq_len = int(max_seq_len)
|
132 |
+
self.vocab_size = vocab_size
|
133 |
+
self.resid_pdrop = resid_pdrop
|
134 |
+
self.emb_pdrop = emb_pdrop
|
135 |
+
self.learned_pos_emb = learned_pos_emb
|
136 |
+
self.attn_config = attn_config if attn_config is not None else copy.deepcopy(attn_config_defaults)
|
137 |
+
self.ffn_config = ffn_config if ffn_config is not None else copy.deepcopy(ffn_config_defaults)
|
138 |
+
self.init_device = init_device
|
139 |
+
self.logit_scale = logit_scale
|
140 |
+
self.no_bias = no_bias
|
141 |
+
self.embedding_fraction = embedding_fraction
|
142 |
+
self.norm_type = norm_type
|
143 |
+
self.norm_eps = norm_eps
|
144 |
+
self.use_cache = use_cache
|
145 |
+
self.init_config = init_config if init_config is not None else copy.deepcopy(init_config_defaults)
|
146 |
+
if block_overrides is not None:
|
147 |
+
self._validate_block_overrides(block_overrides)
|
148 |
+
self.block_overrides = block_overrides
|
149 |
+
self.final_logit_softcapping = final_logit_softcapping
|
150 |
+
if isinstance(fc_type, str):
|
151 |
+
fc_type = {'name': fc_type}
|
152 |
+
self.fc_type = fc_type
|
153 |
+
self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
|
154 |
+
if 'name' in kwargs:
|
155 |
+
del kwargs['name']
|
156 |
+
if 'loss_fn' in kwargs:
|
157 |
+
del kwargs['loss_fn']
|
158 |
+
if self.attn_config.get('alibi', False) or self.attn_config.get('rope', False):
|
159 |
+
self.learned_pos_emb = False
|
160 |
+
warnings.warn(f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`')
|
161 |
+
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
162 |
+
self._validate_config()
|
163 |
+
|
164 |
+
def _validate_block_overrides(self, block_overrides: dict[str, Any]):
|
165 |
+
warnings.warn(ExperimentalWarning('block_overrides'))
|
166 |
+
if 'order' not in block_overrides:
|
167 |
+
raise ValueError('`order` should be defined in block_overrides')
|
168 |
+
if 'overrides' not in block_overrides:
|
169 |
+
raise ValueError('`overrides` should be defined in block_overrides')
|
170 |
+
if 'default' in block_overrides['overrides'].keys():
|
171 |
+
raise ValueError('block overrides cannot be named "default".')
|
172 |
+
|
173 |
+
def _set_config_defaults(self, config: dict[str, Any], config_defaults: dict[str, Any]) -> dict[str, Any]:
|
174 |
+
for k, v in config_defaults.items():
|
175 |
+
if k not in config:
|
176 |
+
config[k] = v
|
177 |
+
elif isinstance(v, dict):
|
178 |
+
config[k] = self._set_config_defaults(config[k] if config[k] is not None else {}, v)
|
179 |
+
return config
|
180 |
+
|
181 |
+
def validate_attention_config(self) -> None:
|
182 |
+
if 'seq_parallel_world_size' in self.attn_config and self.attn_config['seq_parallel_world_size'] is None:
|
183 |
+
del self.attn_config['seq_parallel_world_size']
|
184 |
+
if self.attn_config.get('seq_parallel_world_size', 1) > 1:
|
185 |
+
raise NotImplementedError('Sequence Parallelism is not supported.')
|
186 |
+
|
187 |
+
def _validate_config(self) -> None:
|
188 |
+
self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
|
189 |
+
self.ffn_config = self._set_config_defaults(self.ffn_config, ffn_config_defaults)
|
190 |
+
self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
|
191 |
+
self.fc_type = self._set_config_defaults(self.fc_type, fc_type_defaults)
|
192 |
+
if self.d_model % self.n_heads != 0:
|
193 |
+
raise ValueError('d_model must be divisible by n_heads')
|
194 |
+
if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
|
195 |
+
raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
|
196 |
+
if self.attn_config['attn_impl'] not in ['torch', 'flash']:
|
197 |
+
raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
|
198 |
+
if self.attn_config['alibi'] and (not check_alibi_support(self.attn_config['attn_impl'])):
|
199 |
+
raise NotImplementedError('alibi only implemented with torch and flash (v2.4.2 or higher) attention.')
|
200 |
+
if self.attn_config['attn_uses_sequence_id'] and (not (self.attn_config['attn_impl'] == 'torch' or (self.attn_config['attn_impl'] == 'flash' and is_flash_v2_installed(v2_version='v2.1.2')))):
|
201 |
+
raise NotImplementedError('attn_uses_sequence_id only implemented with torch and flash (v2.1.2 or higher) attention.')
|
202 |
+
if self.attn_config['rope'] and self.attn_config['rope_impl'] not in ['dail', 'hf']:
|
203 |
+
raise ValueError('If rope is being used then rope_impl should be either "dail", or "hf".')
|
204 |
+
if self.attn_config['rope'] and self.attn_config['rope_impl'] == 'hf' and (self.attn_config['rope_hf_config']['type'] not in ['no_scaling', 'linear', 'dynamic', 'llama3']):
|
205 |
+
raise ValueError('If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".')
|
206 |
+
if self.attn_config['rope'] and self.attn_config['rope_impl'] == 'dail':
|
207 |
+
if self.attn_config['rope_dail_config']['type'] not in ['original', 'xpos']:
|
208 |
+
raise ValueError('If using the dail implementation of rope, the type should be one of "original" or "xpos".')
|
209 |
+
if not is_flash_v2_installed(v2_version='2.0.1'):
|
210 |
+
raise ImportError('If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support')
|
211 |
+
if self.attn_config['sliding_window_size'] != -1 and self.attn_config['attn_impl'] == 'flash' and (not is_flash_v2_installed(v2_version='v2.3.0')):
|
212 |
+
raise NotImplementedError('sliding window attention only implemented for torch attention and flash attention (v2.3.0 or higher).')
|
213 |
+
if self.attn_config['attn_logit_softcapping'] is not None:
|
214 |
+
if self.attn_config['attn_logit_softcapping'] <= 0:
|
215 |
+
raise ValueError('Attention attn_logit_softcapping should be positive.')
|
216 |
+
if self.attn_config['attn_impl'] == 'flash' and (not is_flash_v2_installed(v2_version='v2.6.2')):
|
217 |
+
raise NotImplementedError('Attention attn_logit_softcapping is only implemented with torch attention or flash attention v2.6.2 (or higher).')
|
218 |
+
if self.attn_config['kv_dim'] is not None and self.attn_config['fused_qkv']:
|
219 |
+
raise ValueError('fused_qkv should be False when "kv_dim" is specified.')
|
220 |
+
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
|
221 |
+
raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
|
222 |
+
if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
|
223 |
+
raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
|
224 |
+
if self.init_config.get('name', None) is None:
|
225 |
+
raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
|
226 |
+
if not (self.learned_pos_emb or self.attn_config['alibi'] or self.attn_config['rope']):
|
227 |
+
warnings.warn(f'Positional information not being provided to the model using either learned_pos_emb or alibi or rope.')
|
228 |
+
if self.fc_type['name'] == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp':
|
229 |
+
try:
|
230 |
+
import transformer_engine.pytorch as te
|
231 |
+
del te
|
232 |
+
except:
|
233 |
+
raise ImportError('TransformerEngine import failed. `fc_type: te` requires TransformerEngine be installed, ', 'e.g. pip install transformer-engine[pytorch]')
|
234 |
+
self.ffn_config['fc_type'] = self.fc_type
|
235 |
+
if self.ffn_config['ffn_type'] == 'mptgeglu':
|
236 |
+
raise ValueError('API CHANGE: `ffn_type=="mptgeglu"` changed to `ffn_type=="mptglu"`. ' + 'See [#829](https://github.com/mosaicml/llm-foundry/pull/829) for details.')
|
237 |
+
elif self.ffn_config['ffn_type'] in ffns_with_megablocks:
|
238 |
+
self.ffn_config['return_bias'] = False
|
239 |
+
elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
|
240 |
+
self.ffn_config['bias'] = not self.no_bias
|
241 |
+
if 'ffn_act_fn' in self.ffn_config.keys():
|
242 |
+
raise ValueError(f'Transformer Engine block does not support custom activation functions.')
|
243 |
+
if not self.use_pad_tok_in_ffn:
|
244 |
+
try:
|
245 |
+
from flash_attn.bert_padding import unpad_input, pad_input
|
246 |
+
except:
|
247 |
+
raise ImportError('In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.6')
|
248 |
+
self.validate_attention_config()
|
249 |
+
|
250 |
+
@property
|
251 |
+
def allowed_block_overrides(self):
|
252 |
+
return {'attn_config': {'sliding_window_size': None, 'reuse_kv_layer_idx': None}}
|
custom_embedding.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import Tensor
|
4 |
+
|
5 |
+
class SharedEmbedding(nn.Embedding):
|
6 |
+
|
7 |
+
def forward(self, input: Tensor, unembed: bool=False) -> Tensor:
|
8 |
+
if unembed:
|
9 |
+
return F.linear(input, self.weight)
|
10 |
+
return super().forward(input)
|
dmoe.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from typing import Callable, Optional, Union
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
|
6 |
+
|
7 |
+
class _UniformExpertAssignment(torch.autograd.Function):
|
8 |
+
|
9 |
+
@staticmethod
|
10 |
+
def forward(ctx, x: torch.Tensor, num_experts: int):
|
11 |
+
out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
|
12 |
+
out = torch.remainder(out, num_experts)
|
13 |
+
return out.view(x.shape)
|
14 |
+
|
15 |
+
class LearnedRouter(torch.nn.Module):
|
16 |
+
|
17 |
+
def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int, moe_jitter_eps: Optional[float], moe_normalize_expert_weights: Optional[Union[int, float]], uniform_expert_assignment: bool, device: Optional[torch.device]) -> None:
|
18 |
+
super().__init__()
|
19 |
+
self.hidden_size: int = hidden_size
|
20 |
+
self.moe_num_experts: int = moe_num_experts
|
21 |
+
self.moe_top_k: int = moe_top_k
|
22 |
+
self.moe_jitter_eps: Optional[float] = moe_jitter_eps
|
23 |
+
self.moe_normalize_expert_weights: Optional[Union[int, float]] = moe_normalize_expert_weights
|
24 |
+
self.uniform_expert_assignment: bool = uniform_expert_assignment
|
25 |
+
self.layer: torch.nn.Module = torch.nn.Linear(hidden_size, moe_num_experts, bias=False, device=device)
|
26 |
+
|
27 |
+
def jitter(self, x: torch.Tensor) -> torch.Tensor:
|
28 |
+
assert self.moe_jitter_eps is not None
|
29 |
+
low: float = 1.0 - self.moe_jitter_eps
|
30 |
+
high: float = 1.0 + self.moe_jitter_eps
|
31 |
+
noise: torch.Tensor = torch.rand(x.size(), dtype=x.dtype, device=x.device)
|
32 |
+
return low + noise * (high - low)
|
33 |
+
|
34 |
+
def _top_k(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
35 |
+
if self.moe_top_k == 1:
|
36 |
+
values, indices = scores.max(dim=-1)
|
37 |
+
return (values.unsqueeze(-1), indices.unsqueeze(-1))
|
38 |
+
return torch.topk(scores, self.moe_top_k, dim=-1)
|
39 |
+
|
40 |
+
def forward(self, x: torch.Tensor):
|
41 |
+
if self.training and self.moe_jitter_eps is not None:
|
42 |
+
x = x * self.jitter(x)
|
43 |
+
scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
|
44 |
+
expert_weights, top_experts = self._top_k(scores)
|
45 |
+
if self.moe_normalize_expert_weights:
|
46 |
+
expert_weights = expert_weights / torch.norm(expert_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True)
|
47 |
+
top_experts = _UniformExpertAssignment.apply(top_experts, self.moe_num_experts) if self.uniform_expert_assignment else top_experts
|
48 |
+
scores = scores.to(x.dtype)
|
49 |
+
expert_weights = expert_weights.to(x.dtype)
|
50 |
+
return (scores, expert_weights, top_experts)
|
51 |
+
|
52 |
+
class MLP(torch.nn.Module):
|
53 |
+
|
54 |
+
def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, activation_fn: Callable, device: Optional[torch.device]) -> None:
|
55 |
+
super().__init__()
|
56 |
+
self.moe_num_experts: int = moe_num_experts
|
57 |
+
self.ffn_hidden_size: int = ffn_hidden_size
|
58 |
+
self.hidden_size: int = hidden_size
|
59 |
+
self.activation_fn: Callable = activation_fn
|
60 |
+
self.w1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
|
61 |
+
self.w2 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
|
62 |
+
self.activation_fn = activation_fn
|
63 |
+
|
64 |
+
def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
|
65 |
+
expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
|
66 |
+
expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
|
67 |
+
before_activation = x @ expert_w1.t()
|
68 |
+
layer_1_output = self.activation_fn(before_activation)
|
69 |
+
output = layer_1_output @ expert_w2
|
70 |
+
return output
|
71 |
+
|
72 |
+
class GLU(torch.nn.Module):
|
73 |
+
|
74 |
+
def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, activation_fn: Callable, device: Optional[torch.device]):
|
75 |
+
super().__init__()
|
76 |
+
self.hidden_size = hidden_size
|
77 |
+
self.ffn_hidden_size = ffn_hidden_size
|
78 |
+
self.moe_num_experts = moe_num_experts
|
79 |
+
self.w1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
|
80 |
+
self.v1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
|
81 |
+
self.w2 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
|
82 |
+
self.activation_fn = activation_fn
|
83 |
+
|
84 |
+
def forward(self, x: torch.Tensor, expert_idx: torch.Tensor):
|
85 |
+
expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
|
86 |
+
expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
|
87 |
+
expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
|
88 |
+
x1 = x.matmul(expert_w1.t())
|
89 |
+
x2 = x.matmul(expert_v1.t())
|
90 |
+
x1 = self.activation_fn(x1)
|
91 |
+
x1 = x1 * x2
|
92 |
+
x1 = x1.matmul(expert_w2)
|
93 |
+
return x1
|
94 |
+
|
95 |
+
class DroplessMLP(torch.nn.Module):
|
96 |
+
|
97 |
+
def __init__(self, hidden_size: int, ffn_hidden_size: int, mlp_type: str, moe_num_experts: int, activation_fn: Callable, bias: bool, device: Optional[torch.device]):
|
98 |
+
super().__init__()
|
99 |
+
self.moe_num_experts = moe_num_experts
|
100 |
+
if mlp_type == 'mlp':
|
101 |
+
self.mlp = MLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, moe_num_experts=moe_num_experts, activation_fn=activation_fn, device=device)
|
102 |
+
elif mlp_type == 'glu':
|
103 |
+
self.mlp = GLU(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, moe_num_experts=moe_num_experts, activation_fn=activation_fn, device=device)
|
104 |
+
else:
|
105 |
+
raise ValueError(f'Received unknown mlp_type={mlp_type!r}')
|
106 |
+
|
107 |
+
def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
|
108 |
+
in_shape = x.shape
|
109 |
+
hidden_size = in_shape[-1]
|
110 |
+
x = x.view(-1, hidden_size)
|
111 |
+
out = torch.zeros_like(x)
|
112 |
+
expert_mask = torch.nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
|
113 |
+
for expert_idx in range(0, self.moe_num_experts):
|
114 |
+
topk_idx, token_idx = torch.where(expert_mask[expert_idx])
|
115 |
+
if token_idx.shape[0] == 0:
|
116 |
+
continue
|
117 |
+
token_list = token_idx.tolist()
|
118 |
+
topk_list = topk_idx.tolist()
|
119 |
+
expert_tokens = x[None, token_list].reshape(-1, hidden_size)
|
120 |
+
mlp_output = self.mlp(expert_tokens, expert_idx)
|
121 |
+
expert_weights = expert_weights.to(mlp_output.device)
|
122 |
+
expert_out = mlp_output * expert_weights[token_list, topk_list, None]
|
123 |
+
out = out.to(mlp_output.device)
|
124 |
+
token_idx = token_idx.to(mlp_output.device)
|
125 |
+
out.index_add_(0, token_idx, expert_out)
|
126 |
+
out = out.view(in_shape)
|
127 |
+
return out
|
128 |
+
|
129 |
+
class dMoE(torch.nn.Module):
|
130 |
+
|
131 |
+
def __init__(self, device: Optional[torch.device], hidden_size: int=1024, ffn_hidden_size: int=4096, moe_num_experts: int=1, moe_top_k: int=1, mlp_type: str='mlp', activation_fn: Callable=DEFAULT_ACTIVATION_FN, moe_jitter_eps: Optional[float]=None, moe_normalize_expert_weights: Optional[Union[int, float]]=None, uniform_expert_assignment: bool=False, bias: bool=True):
|
132 |
+
super().__init__()
|
133 |
+
self.router = LearnedRouter(hidden_size, moe_num_experts=moe_num_experts, moe_top_k=moe_top_k, moe_jitter_eps=moe_jitter_eps, moe_normalize_expert_weights=moe_normalize_expert_weights, uniform_expert_assignment=uniform_expert_assignment, device=device)
|
134 |
+
self.experts = DroplessMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, mlp_type=mlp_type, moe_num_experts=moe_num_experts, activation_fn=activation_fn, bias=bias, device=device)
|
135 |
+
|
136 |
+
def forward(self, x: torch.Tensor):
|
137 |
+
scores, expert_weights, top_experts = self.router(x)
|
138 |
+
return self.experts(x, scores, expert_weights, top_experts)
|
fc.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
from .layers_registry import fcs
|
3 |
+
fcs.register('torch', func=nn.Linear)
|
4 |
+
try:
|
5 |
+
import transformer_engine.pytorch as te
|
6 |
+
fcs.register('te', func=te.Linear)
|
7 |
+
except:
|
8 |
+
pass
|
ffn.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MPT Blocks used for the MPT Model."""
|
2 |
+
import logging
|
3 |
+
from copy import deepcopy
|
4 |
+
from functools import partial
|
5 |
+
from typing import Any, Callable, Optional, Union
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.distributed import ProcessGroup
|
9 |
+
from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard
|
10 |
+
from .layers_registry import ffns, ffns_with_megablocks, ffns_with_norm
|
11 |
+
from .dmoe import dMoE
|
12 |
+
from .layer_builders import build_fc
|
13 |
+
from .config_defaults import fc_type_defaults
|
14 |
+
try:
|
15 |
+
import transformer_engine.pytorch as te
|
16 |
+
is_te_imported = True
|
17 |
+
except ModuleNotFoundError:
|
18 |
+
is_te_imported = False
|
19 |
+
try:
|
20 |
+
import megablocks
|
21 |
+
is_megablocks_imported = True
|
22 |
+
except ModuleNotFoundError:
|
23 |
+
is_megablocks_imported = False
|
24 |
+
log = logging.getLogger(__name__)
|
25 |
+
_FFN_ACT_FN_DEFAULT = {'name': 'gelu', 'approximate': 'none'}
|
26 |
+
|
27 |
+
def quickgelu_activation(input: torch.Tensor) -> torch.Tensor:
|
28 |
+
"""Applies GELU approximation that is fast but somewhat inaccurate.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
input (torch.Tensor): Input tensor of shape(*), where * means any
|
32 |
+
number of dimensions
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
torch.Tensor: Tensor with same shape as input tensor
|
36 |
+
"""
|
37 |
+
return input * torch.sigmoid(1.702 * input)
|
38 |
+
|
39 |
+
def resolve_ffn_act_fn(config: Optional[dict]=None) -> Callable[[torch.Tensor], torch.Tensor]:
|
40 |
+
"""Resolve the activation function for the feed-forward network.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
config (Optional[dict]): The configuration dictionary for the activation function.
|
44 |
+
The dict config must specify the 'name' of a torch.nn.functional activation
|
45 |
+
function. All of other key values pairs are bound to the function as a partial.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
Callable[[torch.Tensor], torch.Tensor]: The activation function.
|
49 |
+
"""
|
50 |
+
if config is None:
|
51 |
+
config = _FFN_ACT_FN_DEFAULT
|
52 |
+
config = deepcopy(config)
|
53 |
+
name = config.pop('name')
|
54 |
+
if name == 'quick_gelu':
|
55 |
+
return quickgelu_activation
|
56 |
+
else:
|
57 |
+
if not hasattr(torch.nn.functional, name):
|
58 |
+
raise ValueError(f'Unrecognized activation function name ({name}).')
|
59 |
+
act = getattr(torch.nn.functional, name)
|
60 |
+
return partial(act, **config)
|
61 |
+
_DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)
|
62 |
+
|
63 |
+
def resolve_ffn_hidden_size(d_model: int, expansion_ratio: Union[int, float], ffn_hidden_size: Optional[int]=None) -> int:
|
64 |
+
"""Resolve the hidden size of the feed-forward network.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
d_model (int): The dimension of the input and output of the feed-forward network.
|
68 |
+
expansion_ratio (Union[int, float]): The expansion ratio of the feed-forward network.
|
69 |
+
ffn_hidden_size (Optional[int]): The hidden size of the feed-forward network.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
int: The hidden size of the feed-forward network.
|
73 |
+
"""
|
74 |
+
if ffn_hidden_size is not None:
|
75 |
+
log.info(f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.')
|
76 |
+
else:
|
77 |
+
ffn_hidden_size = int(d_model * expansion_ratio)
|
78 |
+
if ffn_hidden_size != d_model * expansion_ratio:
|
79 |
+
raise ValueError(f'`d_model * expansion_ratio` must be an integer (d_model={d_model!r}; expansion_ratio={expansion_ratio!r}; d_model * expansion_ratio={d_model * expansion_ratio!r}).')
|
80 |
+
return ffn_hidden_size
|
81 |
+
|
82 |
+
def dtensorify_param(param: nn.Parameter, mesh: DeviceMesh, placements: list[Placement]):
|
83 |
+
"""Construct a DTensor from an already sharded local parameter."""
|
84 |
+
param_dtensor = DTensor.from_local(param.data, device_mesh=mesh, placements=placements, run_check=False)
|
85 |
+
return nn.Parameter(param_dtensor)
|
86 |
+
|
87 |
+
class MPTMLP(nn.Module):
|
88 |
+
|
89 |
+
def __init__(self, d_model: int, expansion_ratio: Union[int, float], fc_type: Optional[dict[str, Any]]=None, ffn_hidden_size: Optional[int]=None, act_fn: Callable[[torch.Tensor], torch.Tensor]=_DEFAULT_ACT_FN, device: Optional[str]=None, bias: bool=True):
|
90 |
+
super().__init__()
|
91 |
+
ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size)
|
92 |
+
if fc_type is None:
|
93 |
+
fc_type = fc_type_defaults
|
94 |
+
fc_type['bias'] = bias
|
95 |
+
fc_type['device'] = device
|
96 |
+
self.fc_type = fc_type
|
97 |
+
self.fc_type_name = self.fc_type['name']
|
98 |
+
self.up_proj = build_fc(name=self.fc_type_name, in_features=d_model, out_features=ffn_hidden_size, fc_kwargs=self.fc_type)
|
99 |
+
self.act = act_fn
|
100 |
+
self.down_proj = build_fc(name=self.fc_type_name, in_features=ffn_hidden_size, out_features=d_model, fc_kwargs=self.fc_type)
|
101 |
+
self.down_proj._is_residual = True
|
102 |
+
|
103 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
104 |
+
return self.down_proj(self.act(self.up_proj(x)))
|
105 |
+
|
106 |
+
class MPTGLU(MPTMLP):
|
107 |
+
|
108 |
+
def __init__(self, d_model: int, expansion_ratio: Union[int, float], fc_type: Optional[dict[str, Any]]=None, ffn_hidden_size: Optional[int]=None, act_fn: Callable[[torch.Tensor], torch.Tensor]=_DEFAULT_ACT_FN, device: Optional[str]=None, bias: bool=True):
|
109 |
+
super().__init__(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, ffn_hidden_size=ffn_hidden_size, act_fn=act_fn, device=device, bias=bias)
|
110 |
+
self.gate_proj = build_fc(name=self.fc_type_name, in_features=d_model, out_features=self.up_proj.out_features, fc_kwargs=self.fc_type)
|
111 |
+
|
112 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
113 |
+
return self.down_proj(self.act(self.gate_proj(x)).to(device=x.device) * self.up_proj(x))
|
114 |
+
|
115 |
+
def build_mptglu(d_model: int, expansion_ratio: Union[int, float], fc_type: Optional[dict[str, Any]]=None, ffn_hidden_size: Optional[int]=None, ffn_act_fn: Optional[dict]=None, device: Optional[str]=None, bias: bool=True) -> nn.Module:
|
116 |
+
return MPTGLU(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, act_fn=resolve_ffn_act_fn(ffn_act_fn), ffn_hidden_size=ffn_hidden_size, device=device, bias=bias)
|
117 |
+
|
118 |
+
def build_mptmlp(d_model: int, expansion_ratio: Union[int, float], fc_type: Optional[dict[str, Any]]=None, ffn_hidden_size: Optional[int]=None, ffn_act_fn: Optional[dict]=None, device: Optional[str]=None, bias: bool=True) -> nn.Module:
|
119 |
+
return MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, act_fn=resolve_ffn_act_fn(ffn_act_fn), ffn_hidden_size=ffn_hidden_size, device=device, bias=bias)
|
120 |
+
|
121 |
+
def build_te_ln_mlp(d_model: int, expansion_ratio: Union[int, float], fc_type: Optional[dict[str, Any]]=None, ffn_hidden_size: Optional[int]=None, ffn_act_fn: Optional[dict]=None, device: Optional[str]=None, bias: bool=True, **kwargs: Any) -> nn.Module:
|
122 |
+
assert te is not None
|
123 |
+
ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size)
|
124 |
+
if ffn_act_fn is not None:
|
125 |
+
raise ValueError(f'Transformer Engine block does not support custom activation functions.')
|
126 |
+
return te.LayerNormMLP(hidden_size=d_model, ffn_hidden_size=ffn_hidden_size, bias=bias, **kwargs)
|
127 |
+
|
128 |
+
def build_torch_dmoe(d_model: int, expansion_ratio: Union[int, float], ffn_hidden_size: Optional[int]=None, ffn_act_fn: Optional[dict]=None, device: Optional[str]=None, bias: bool=True, **kwargs: Any) -> nn.Module:
|
129 |
+
moe_num_experts = kwargs.pop('moe_num_experts')
|
130 |
+
moe_top_k = kwargs.pop('moe_top_k')
|
131 |
+
mlp_type = kwargs.pop('mlp_type')
|
132 |
+
moe_jitter_eps = kwargs.pop('moe_jitter_eps')
|
133 |
+
moe_normalize_expert_weights = kwargs.pop('moe_normalize_expert_weights')
|
134 |
+
uniform_expert_assignment = kwargs.pop('uniform_expert_assignment')
|
135 |
+
fc_type = kwargs.pop('fc_type', None)
|
136 |
+
del fc_type
|
137 |
+
if len(kwargs) > 0:
|
138 |
+
raise ValueError(f'Invalid arguments to torch dmoe: {kwargs}.')
|
139 |
+
return dMoE(hidden_size=d_model, ffn_hidden_size=resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size), moe_num_experts=moe_num_experts, moe_top_k=moe_top_k, mlp_type=mlp_type, bias=bias, moe_jitter_eps=moe_jitter_eps, activation_fn=resolve_ffn_act_fn(ffn_act_fn), moe_normalize_expert_weights=moe_normalize_expert_weights, uniform_expert_assignment=uniform_expert_assignment, device=torch.device(device) if device is not None else None)
|
140 |
+
|
141 |
+
def mb_setup_args(d_model: int, expansion_ratio: Union[int, float], ffn_hidden_size: Optional[int], ffn_act_fn: Optional[dict], device: Optional[str], bias: bool, kwargs: dict[str, Any]) -> tuple['megablocks.layers.arguments.Arguments', int, ProcessGroup]:
|
142 |
+
"""Setup the MegaBlocks args.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
d_model (int): The dimension of the input and output of the FFN.
|
146 |
+
expansion_ratio (Union[int, float]): The expansion ratio of the FFN.
|
147 |
+
ffn_hidden_size (Optional[int]): The hidden size of the FFN.
|
148 |
+
ffn_act_fn (Optional[dict]): The activation function of the FFN.
|
149 |
+
device (Optional[str]): The device to run the FFN on.
|
150 |
+
bias (bool): Whether to include bias in the FFN.
|
151 |
+
kwargs (dict[str, Any]): Additional kwargs.
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
tuple['megablocks.layers.arguments.Arguments', int, ProcessGroup]:
|
155 |
+
The MegaBlocks args, the MoE world size, and the expert parallel group.
|
156 |
+
"""
|
157 |
+
if megablocks is None:
|
158 |
+
raise RuntimeError('Requirements for megablocks not installed; see install instructions in `README.md`.')
|
159 |
+
args = kwargs['args']
|
160 |
+
args.bias = bias
|
161 |
+
args.hidden_size = d_model
|
162 |
+
args.device = device
|
163 |
+
ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size)
|
164 |
+
args.ffn_hidden_size = ffn_hidden_size
|
165 |
+
if ffn_act_fn is not None:
|
166 |
+
args.activation_fn = resolve_ffn_act_fn(ffn_act_fn)
|
167 |
+
moe_world_size = 1
|
168 |
+
expert_parallel_group = args.expert_parallel_group
|
169 |
+
if expert_parallel_group is not None:
|
170 |
+
moe_world_size = expert_parallel_group.size()
|
171 |
+
if kwargs.get('moe_world_size') != moe_world_size:
|
172 |
+
raise RuntimeError(f'MoE expert_parallel_group configured with incorrect world size.')
|
173 |
+
return (args, moe_world_size, expert_parallel_group)
|
174 |
+
|
175 |
+
def attach_ffn_mb_args(ffn: nn.Module, expert_parallel_group: ProcessGroup, args: 'megablocks.layers.arguments.Arguments'):
|
176 |
+
"""Attach arguments used in parameter initialization to the FFN.
|
177 |
+
|
178 |
+
Args:
|
179 |
+
ffn (nn.Module): The FFN module.
|
180 |
+
expert_parallel_group (ProcessGroup): The expert parallel process group.
|
181 |
+
args (megablocks.layers.arguments.Arguments): The arguments for MegaBlocks.
|
182 |
+
"""
|
183 |
+
ffn.experts.mlp.hidden_size = args.ffn_hidden_size
|
184 |
+
ffn.experts.mlp.expert_parallel_group = expert_parallel_group
|
185 |
+
|
186 |
+
def get_fsdp_submesh_2d(device_mesh: DeviceMesh):
|
187 |
+
"""Get the submesh for FSDP.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
device_mesh (DeviceMesh): The full device mesh.
|
191 |
+
|
192 |
+
Returns:
|
193 |
+
DeviceMesh: The submesh for FSDP.
|
194 |
+
"""
|
195 |
+
if device_mesh.mesh.ndim == 2:
|
196 |
+
submesh = device_mesh['weight_parallel']
|
197 |
+
elif device_mesh.mesh.ndim == 3:
|
198 |
+
raise RuntimeError(f'HSDP + MoE is not supported.')
|
199 |
+
else:
|
200 |
+
raise ValueError(f'device_mesh.mesh.ndim={device_mesh.mesh.ndim!r} not supported for MoE.')
|
201 |
+
return submesh
|
202 |
+
|
203 |
+
def set_ffn_device_mesh(ffn: nn.Module, moe_world_size: int, device_mesh: DeviceMesh, get_fsdp_submesh: Callable[[DeviceMesh], DeviceMesh]):
|
204 |
+
"""Sets the device mesh in FSDP kwargs.
|
205 |
+
|
206 |
+
Args:
|
207 |
+
ffn (nn.Module): The FFN module.
|
208 |
+
moe_world_size (int): The MoE world size.
|
209 |
+
device_mesh (DeviceMesh): The full device mesh.
|
210 |
+
get_fsdp_submesh (Callable[[DeviceMesh], DeviceMesh]): A function to get the fsdp submesh.
|
211 |
+
|
212 |
+
Raises:
|
213 |
+
RuntimeError: If the device mesh is 3D.
|
214 |
+
ValueError: If the device mesh is not 2D or 3D.
|
215 |
+
"""
|
216 |
+
if moe_world_size > 1:
|
217 |
+
expert_mesh = device_mesh['expert_parallel']
|
218 |
+
expert_placements: list[Placement] = [Shard(0)]
|
219 |
+
dtensorified_params = [(name, dtensorify_param(param=parameter, mesh=expert_mesh, placements=expert_placements)) for name, parameter in ffn.experts.mlp.named_parameters()]
|
220 |
+
for name, dtensorified_param in dtensorified_params:
|
221 |
+
ffn.experts.mlp.register_parameter(name, dtensorified_param)
|
222 |
+
submesh = get_fsdp_submesh(device_mesh)
|
223 |
+
ffn.experts._fsdp_kwargs_dict = {'device_mesh': submesh}
|
224 |
+
|
225 |
+
def moe_fused_init_setup(ffn: nn.Module):
|
226 |
+
"""Attach the _stack_dim attribute to the FFN.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
ffn (nn.Module): The FFN module.
|
230 |
+
"""
|
231 |
+
ffn.experts.mlp._stack_dim = 0
|
232 |
+
|
233 |
+
def build_mb_moe(d_model: int, expansion_ratio: Union[int, float], ffn_hidden_size: Optional[int]=None, ffn_act_fn: Optional[dict]=None, device: Optional[str]=None, bias: bool=True, **kwargs: Any) -> nn.Module:
|
234 |
+
if not is_megablocks_imported:
|
235 |
+
raise RuntimeError('Requirements for megablocks not installed; see install instructions in `README.md`.')
|
236 |
+
args, moe_world_size, expert_parallel_group = mb_setup_args(d_model=d_model, expansion_ratio=expansion_ratio, ffn_hidden_size=ffn_hidden_size, ffn_act_fn=ffn_act_fn, device=device, bias=bias, kwargs=kwargs)
|
237 |
+
ffn = megablocks.layers.moe.MoE(args)
|
238 |
+
moe_fused_init_setup(ffn=ffn)
|
239 |
+
attach_ffn_mb_args(ffn=ffn, expert_parallel_group=expert_parallel_group, args=args)
|
240 |
+
set_ffn_device_mesh(ffn=ffn, moe_world_size=moe_world_size, device_mesh=kwargs['device_mesh'], get_fsdp_submesh=get_fsdp_submesh_2d)
|
241 |
+
return ffn
|
242 |
+
|
243 |
+
def dmoe_fused_init_setup(ffn: nn.Module, args: 'megablocks.layers.arguments.Arguments', moe_world_size: int):
|
244 |
+
"""Attach the _fused attribute to the dMoE model.
|
245 |
+
|
246 |
+
This is used for parameter initialization.
|
247 |
+
|
248 |
+
Args:
|
249 |
+
ffn (nn.Module): The FFN module.
|
250 |
+
args (megablocks.layers.arguments.Arguments): The arguments for MegaBlocks.
|
251 |
+
moe_world_size (int): The MoE world size.
|
252 |
+
"""
|
253 |
+
n_exp = min(1, args.moe_num_experts // moe_world_size)
|
254 |
+
ffn.experts.mlp._fused = (0, [(n + 1) * args.ffn_hidden_size for n in range(n_exp - 1)])
|
255 |
+
|
256 |
+
def build_mb_dmoe(d_model: int, expansion_ratio: Union[int, float], ffn_hidden_size: Optional[int]=None, ffn_act_fn: Optional[dict]=None, device: Optional[str]=None, bias: bool=True, **kwargs: Any) -> nn.Module:
|
257 |
+
if not is_megablocks_imported:
|
258 |
+
raise RuntimeError('Requirements for megablocks not installed; see install instructions in `README.md`.')
|
259 |
+
args, moe_world_size, expert_parallel_group = mb_setup_args(d_model=d_model, expansion_ratio=expansion_ratio, ffn_hidden_size=ffn_hidden_size, ffn_act_fn=ffn_act_fn, device=device, bias=bias, kwargs=kwargs)
|
260 |
+
ffn = megablocks.layers.dmoe.dMoE(args)
|
261 |
+
dmoe_fused_init_setup(ffn=ffn, args=args, moe_world_size=moe_world_size)
|
262 |
+
attach_ffn_mb_args(ffn=ffn, expert_parallel_group=expert_parallel_group, args=args)
|
263 |
+
set_ffn_device_mesh(ffn=ffn, moe_world_size=moe_world_size, device_mesh=kwargs['device_mesh'], get_fsdp_submesh=get_fsdp_submesh_2d)
|
264 |
+
return ffn
|
265 |
+
ffns.register('mptglu', func=build_mptglu)
|
266 |
+
ffns.register('mptmlp', func=build_mptmlp)
|
267 |
+
ffns.register('torch_dmoe', func=build_torch_dmoe)
|
268 |
+
if is_te_imported:
|
269 |
+
ffns_with_norm.register('te_ln_mlp', func=build_te_ln_mlp)
|
270 |
+
if is_megablocks_imported:
|
271 |
+
ffns_with_megablocks.register('mb_moe', func=build_mb_moe)
|
272 |
+
ffns_with_megablocks.register('mb_dmoe', func=build_mb_dmoe)
|
generation_config.json
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"transformers_version": "4.51.3",
|
4 |
+
"use_cache": false
|
5 |
+
}
|
layer_builders.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Optional, Union
|
2 |
+
import torch
|
3 |
+
from .layers_registry import attention_classes, fcs, ffns, ffns_with_megablocks, ffns_with_norm, norms
|
4 |
+
from .registry_utils import construct_from_registry
|
5 |
+
|
6 |
+
def build_norm(name: str, normalized_shape: Union[int, list[int], torch.Size], eps: Optional[float]=1e-05, device: Optional[str]=None):
|
7 |
+
kwargs = {'normalized_shape': normalized_shape, 'eps': eps, 'device': device}
|
8 |
+
return construct_from_registry(name=name, registry=norms, pre_validation_function=torch.nn.Module, kwargs=kwargs)
|
9 |
+
|
10 |
+
def build_ffn(name: str, d_model: int, expansion_ratio: float, device: Optional[str], bias: bool, ffn_kwargs: dict[str, Any]):
|
11 |
+
registry_to_use = ffns
|
12 |
+
if name in ffns_with_norm:
|
13 |
+
registry_to_use = ffns_with_norm
|
14 |
+
if name in ffns_with_megablocks:
|
15 |
+
registry_to_use = ffns_with_megablocks
|
16 |
+
kwargs = {'d_model': d_model, 'expansion_ratio': expansion_ratio, 'device': device, 'bias': bias, **{k: v for k, v in ffn_kwargs.items() if k != 'ffn_type'}}
|
17 |
+
|
18 |
+
def _validation_function(maybe_module: Any):
|
19 |
+
if not isinstance(maybe_module, torch.nn.Module):
|
20 |
+
raise ValueError(f'Function {name} must return a torch.nn.Module.')
|
21 |
+
result = construct_from_registry(name=name, registry=registry_to_use, post_validation_function=_validation_function, partial_function=False, kwargs=kwargs)
|
22 |
+
if name in ffns_with_norm:
|
23 |
+
result._has_norm = True
|
24 |
+
if name in ffns_with_megablocks:
|
25 |
+
result._uses_megablocks = True
|
26 |
+
return result
|
27 |
+
|
28 |
+
def build_attention_layer(name: str, attn_kwargs: dict[str, Any]):
|
29 |
+
return construct_from_registry(name=name, registry=attention_classes, pre_validation_function=torch.nn.Module, kwargs=attn_kwargs)
|
30 |
+
|
31 |
+
def build_fc(name: str, in_features: int, out_features: int, fc_kwargs: dict[str, Any]):
|
32 |
+
kwargs = {'in_features': in_features, 'out_features': out_features, **{k: v for k, v in fc_kwargs.items() if k != 'name'}}
|
33 |
+
return construct_from_registry(name=name, registry=fcs, pre_validation_function=torch.nn.Module, kwargs=kwargs)
|
layers_registry.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable
|
2 |
+
import torch
|
3 |
+
from .registry_utils import create_registry
|
4 |
+
_norms_description = 'The norms registry is used to register classes that implement normalization layers.\n\n One example of this is torch.nn.LayerNorm. See norm.py for examples.\n\n Args:\n normalized_shape Union[int, List[int], torch.Size]: The shape of the input tensor.\n device: Optional[torch.device]: The device to use for the normalization layer.\n\n Returns:\n torch.nn.Module: The normalization layer.\n '
|
5 |
+
norms = create_registry('llmfoundry', 'norms', generic_type=type[torch.nn.Module], entry_points=True, description=_norms_description)
|
6 |
+
_fcs_description = 'The fcs registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear).\n\n See fc.py for examples.\n\n Args:\n in_features: int: The number of input features.\n out_features: int: The number of output features.\n kwargs: Dict[str, Any]: Additional keyword arguments to pass to the layer.\n\n Returns:\n torch.nn.Module: The fully connected layer.\n '
|
7 |
+
fcs = create_registry('llmfoundry', 'fcs', generic_type=type[torch.nn.Module], entry_points=True, description=_fcs_description)
|
8 |
+
_ffns_description = 'The ffns registry is used to register functions that build FFN layers.\n\n These layers are generally composed of fc layers and activation functions.\n One example is MPTMLP. See ffn.py for examples.\n\n Args:\n d_model: int: The size of the input and output tensors.\n expansion_ratio: float: The expansion ratio for the hidden layer.\n device: Optional[str]: The device to use for the layer.\n bias: bool: Whether or not to include a bias term.\n kwargs: Dict[str, Any]: Additional keyword arguments to pass to the layer.\n\n Returns:\n torch.nn.Module: The FFN layer.\n '
|
9 |
+
ffns = create_registry('llmfoundry', 'ffns', generic_type=Callable, entry_points=True, description=_ffns_description)
|
10 |
+
_ffns_with_norm_description = 'The ffns_with_norm registry is used to register functions that build FFN layers with normalization.\n\n The resulting layer will have ._has_norm set on it.\n One example is te.LayerNormMLP. See ffn.py for examples.\n\n Args:\n d_model: int: The size of the input and output tensors.\n expansion_ratio: float: The expansion ratio for the hidden layer.\n device: Optional[str]: The device to use for the layer.\n bias: bool: Whether or not to include a bias term.\n kwargs: Dict[str, Any]: Additional keyword arguments to pass to the layer.\n\n Returns:\n torch.nn.Module: The FFN layer.\n '
|
11 |
+
ffns_with_norm = create_registry('llmfoundry', 'ffns_with_norm', generic_type=Callable, entry_points=True, description=_ffns_with_norm_description)
|
12 |
+
_ffns_with_megablocks_description = 'The ffns_with_megablocks registry is used to register functions that build ffn layers using MegaBlocks.' + 'See ffn.py for examples.'
|
13 |
+
_ffns_with_megablocks_description = 'The ffns_with_megablocks registry is used to register functions that build FFN layers using MegaBlocks.\n\n The resulting layer will have ._uses_megablocks set on it.\n One example is megablocks.layers.dmoe.dMoE. See ffn.py for examples.\n\n Returns:\n torch.nn.Module: The FFN layer.\n '
|
14 |
+
ffns_with_megablocks = create_registry('llmfoundry', 'ffns_with_megablocks', generic_type=Callable, entry_points=True, description=_ffns_with_megablocks_description)
|
15 |
+
_attention_classes_description = 'The attention_classes registry is used to register classes that implement attention layers.\n\n The kwargs are passed directly to the constructor of the class.\n One example is GroupedQueryAttention. See attention.py for examples.\n\n Args:\n kwargs: Dict[str, Any]: Additional keyword arguments to pass to the layer.\n\n Returns:\n torch.nn.Module: The attention layer.\n '
|
16 |
+
attention_classes = create_registry('llmfoundry', 'attention_classes', generic_type=type[torch.nn.Module], entry_points=True, description=_attention_classes_description)
|
17 |
+
_attention_implementations_description = "The attention_implementations registry is used to register functions that implement the attention operation.\n\n One example is 'flash'. See attention.py for examples.\n\n Args:\n query (torch.Tensor): The query tensor.\n key (torch.Tensor): The key tensor.\n value (torch.Tensor): The value tensor.\n n_heads (int): The number of attention heads.\n kv_n_heads (int): The number of attention heads for the key and value tensors.\n past_key_value (Optional[tuple[torch.Tensor, torch.Tensor]]): The past key and value tensors.\n softmax_scale (Optional[float]) = None\n attn_bias (Optional[torch.Tensor]) = None\n is_causal (bool) = False\n dropout_p (float) = 0.0\n training (bool) = True\n needs_weights (bool) = False\n kwargs: Dict[str, Any]: Additional keyword arguments the implementation accepts.\n\n Returns:\n tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:\n The output tensor, the attention weights, and the past key and value tensors.\n "
|
18 |
+
attention_implementations = create_registry('llmfoundry', 'attention_implementations', generic_type=Callable, entry_points=True, description=_attention_implementations_description)
|
19 |
+
_param_init_fns_description = "The param_init_fns registry is used to register functions that initialize parameters.\n\n These functions should take in a torch.nn.Module, additional kwargs, and initialize the parameters of the module.\n Generally they can call generic_param_init_fn_ with an appropriate partial function. See param_init_fns.py for examples.\n\n Note: These functions should take in arbitrary kwargs, and discard any they don't need.\n\n Args:\n module: torch.nn.Module: The module to initialize.\n kwargs: Dict[str, Any]: Additional keyword arguments to use for initialization.\n "
|
20 |
+
param_init_fns = create_registry('llmfoundry', 'param_init_fns', generic_type=Callable[..., None], entry_points=True, description=_param_init_fns_description)
|
21 |
+
_module_init_fns_description = 'The module_init_fns registry is used to register functions that initialize specific modules.\n\n These functions should return True if they initialize the module, and False otherwise.\n This allows them to be called without knowing their contents. They should take in the module and additional kwargs.\n If multiple functions can initialize the module, the one that is registered first will be used, so it is recommended to\n override an existing function if you want to change existing initialization behavior, and add new functions if you have new\n layer types. See param_init_fns.py for details.\n\n Args:\n module: torch.nn.Module: The module to initialize.\n kwargs: Dict[str, Any]: Additional keyword arguments to use for initialization.\n\n Returns:\n bool: Whether or not the module was initialized.\n '
|
22 |
+
module_init_fns = create_registry('llmfoundry', 'module_init_fns', generic_type=Callable[..., bool], entry_points=True, description=_module_init_fns_description)
|
merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model-00001-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:81285ba85106fae1bd4888299ae53c8b0d9d7931fd4fe55b63586217f4f17686
|
3 |
+
size 984249104
|
model-00002-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:01e48f4e42af9c0ca90ce8126815ff3b908cd88550c651bc7ce6d9fe8d66257c
|
3 |
+
size 988048544
|
model-00003-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f8bc0d3fcaca8bdfccf6550339ccb9c0409b15a5396261ca181f76a02eb43947
|
3 |
+
size 975467784
|
model-00004-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3d4f3ef10def6e5d48bb5adf4d016574116c410cbfe90568abd417e57c63abde
|
3 |
+
size 931425488
|
model-00005-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2352c5ee8b3fa13b3ec3d3b07e6e2caa9b2a9990bcccaf57e02d6b89b667eafe
|
3 |
+
size 931413312
|
model-00006-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6cef8dff3c993f5a28c5e3ea92430a77f5a2ade4ed7547307ffb7b2063f9d977
|
3 |
+
size 988048592
|
model-00007-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d24632b50ec49b2253c989a627771dfabe2e9ecb42e455deffbe198244fbaa1e
|
3 |
+
size 774080720
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 6572701696
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"transformer.blocks.0.attn.Wqkv.bias": "model-00001-of-00007.safetensors",
|
7 |
+
"transformer.blocks.0.attn.Wqkv.weight": "model-00001-of-00007.safetensors",
|
8 |
+
"transformer.blocks.0.attn.out_proj.bias": "model-00001-of-00007.safetensors",
|
9 |
+
"transformer.blocks.0.attn.out_proj.weight": "model-00001-of-00007.safetensors",
|
10 |
+
"transformer.blocks.0.ffn.down_proj.bias": "model-00002-of-00007.safetensors",
|
11 |
+
"transformer.blocks.0.ffn.down_proj.weight": "model-00002-of-00007.safetensors",
|
12 |
+
"transformer.blocks.0.ffn.up_proj.bias": "model-00002-of-00007.safetensors",
|
13 |
+
"transformer.blocks.0.ffn.up_proj.weight": "model-00002-of-00007.safetensors",
|
14 |
+
"transformer.blocks.0.norm_1.weight": "model-00001-of-00007.safetensors",
|
15 |
+
"transformer.blocks.0.norm_2.weight": "model-00001-of-00007.safetensors",
|
16 |
+
"transformer.blocks.1.attn.Wqkv.bias": "model-00002-of-00007.safetensors",
|
17 |
+
"transformer.blocks.1.attn.Wqkv.weight": "model-00002-of-00007.safetensors",
|
18 |
+
"transformer.blocks.1.attn.out_proj.bias": "model-00002-of-00007.safetensors",
|
19 |
+
"transformer.blocks.1.attn.out_proj.weight": "model-00002-of-00007.safetensors",
|
20 |
+
"transformer.blocks.1.ffn.down_proj.bias": "model-00002-of-00007.safetensors",
|
21 |
+
"transformer.blocks.1.ffn.down_proj.weight": "model-00002-of-00007.safetensors",
|
22 |
+
"transformer.blocks.1.ffn.up_proj.bias": "model-00002-of-00007.safetensors",
|
23 |
+
"transformer.blocks.1.ffn.up_proj.weight": "model-00002-of-00007.safetensors",
|
24 |
+
"transformer.blocks.1.norm_1.weight": "model-00002-of-00007.safetensors",
|
25 |
+
"transformer.blocks.1.norm_2.weight": "model-00002-of-00007.safetensors",
|
26 |
+
"transformer.blocks.10.attn.Wqkv.bias": "model-00004-of-00007.safetensors",
|
27 |
+
"transformer.blocks.10.attn.Wqkv.weight": "model-00004-of-00007.safetensors",
|
28 |
+
"transformer.blocks.10.attn.out_proj.bias": "model-00004-of-00007.safetensors",
|
29 |
+
"transformer.blocks.10.attn.out_proj.weight": "model-00004-of-00007.safetensors",
|
30 |
+
"transformer.blocks.10.ffn.down_proj.bias": "model-00004-of-00007.safetensors",
|
31 |
+
"transformer.blocks.10.ffn.down_proj.weight": "model-00004-of-00007.safetensors",
|
32 |
+
"transformer.blocks.10.ffn.up_proj.bias": "model-00004-of-00007.safetensors",
|
33 |
+
"transformer.blocks.10.ffn.up_proj.weight": "model-00004-of-00007.safetensors",
|
34 |
+
"transformer.blocks.10.norm_1.weight": "model-00003-of-00007.safetensors",
|
35 |
+
"transformer.blocks.10.norm_2.weight": "model-00004-of-00007.safetensors",
|
36 |
+
"transformer.blocks.11.attn.Wqkv.bias": "model-00004-of-00007.safetensors",
|
37 |
+
"transformer.blocks.11.attn.Wqkv.weight": "model-00004-of-00007.safetensors",
|
38 |
+
"transformer.blocks.11.attn.out_proj.bias": "model-00004-of-00007.safetensors",
|
39 |
+
"transformer.blocks.11.attn.out_proj.weight": "model-00004-of-00007.safetensors",
|
40 |
+
"transformer.blocks.11.ffn.down_proj.bias": "model-00004-of-00007.safetensors",
|
41 |
+
"transformer.blocks.11.ffn.down_proj.weight": "model-00004-of-00007.safetensors",
|
42 |
+
"transformer.blocks.11.ffn.up_proj.bias": "model-00004-of-00007.safetensors",
|
43 |
+
"transformer.blocks.11.ffn.up_proj.weight": "model-00004-of-00007.safetensors",
|
44 |
+
"transformer.blocks.11.norm_1.weight": "model-00004-of-00007.safetensors",
|
45 |
+
"transformer.blocks.11.norm_2.weight": "model-00004-of-00007.safetensors",
|
46 |
+
"transformer.blocks.12.attn.Wqkv.bias": "model-00004-of-00007.safetensors",
|
47 |
+
"transformer.blocks.12.attn.Wqkv.weight": "model-00004-of-00007.safetensors",
|
48 |
+
"transformer.blocks.12.attn.out_proj.bias": "model-00004-of-00007.safetensors",
|
49 |
+
"transformer.blocks.12.attn.out_proj.weight": "model-00004-of-00007.safetensors",
|
50 |
+
"transformer.blocks.12.ffn.down_proj.bias": "model-00004-of-00007.safetensors",
|
51 |
+
"transformer.blocks.12.ffn.down_proj.weight": "model-00004-of-00007.safetensors",
|
52 |
+
"transformer.blocks.12.ffn.up_proj.bias": "model-00004-of-00007.safetensors",
|
53 |
+
"transformer.blocks.12.ffn.up_proj.weight": "model-00004-of-00007.safetensors",
|
54 |
+
"transformer.blocks.12.norm_1.weight": "model-00004-of-00007.safetensors",
|
55 |
+
"transformer.blocks.12.norm_2.weight": "model-00004-of-00007.safetensors",
|
56 |
+
"transformer.blocks.13.attn.Wqkv.bias": "model-00004-of-00007.safetensors",
|
57 |
+
"transformer.blocks.13.attn.Wqkv.weight": "model-00004-of-00007.safetensors",
|
58 |
+
"transformer.blocks.13.attn.out_proj.bias": "model-00004-of-00007.safetensors",
|
59 |
+
"transformer.blocks.13.attn.out_proj.weight": "model-00004-of-00007.safetensors",
|
60 |
+
"transformer.blocks.13.ffn.down_proj.bias": "model-00004-of-00007.safetensors",
|
61 |
+
"transformer.blocks.13.ffn.down_proj.weight": "model-00004-of-00007.safetensors",
|
62 |
+
"transformer.blocks.13.ffn.up_proj.bias": "model-00004-of-00007.safetensors",
|
63 |
+
"transformer.blocks.13.ffn.up_proj.weight": "model-00004-of-00007.safetensors",
|
64 |
+
"transformer.blocks.13.norm_1.weight": "model-00004-of-00007.safetensors",
|
65 |
+
"transformer.blocks.13.norm_2.weight": "model-00004-of-00007.safetensors",
|
66 |
+
"transformer.blocks.14.attn.Wqkv.bias": "model-00004-of-00007.safetensors",
|
67 |
+
"transformer.blocks.14.attn.Wqkv.weight": "model-00004-of-00007.safetensors",
|
68 |
+
"transformer.blocks.14.attn.out_proj.bias": "model-00004-of-00007.safetensors",
|
69 |
+
"transformer.blocks.14.attn.out_proj.weight": "model-00004-of-00007.safetensors",
|
70 |
+
"transformer.blocks.14.ffn.down_proj.bias": "model-00005-of-00007.safetensors",
|
71 |
+
"transformer.blocks.14.ffn.down_proj.weight": "model-00005-of-00007.safetensors",
|
72 |
+
"transformer.blocks.14.ffn.up_proj.bias": "model-00004-of-00007.safetensors",
|
73 |
+
"transformer.blocks.14.ffn.up_proj.weight": "model-00004-of-00007.safetensors",
|
74 |
+
"transformer.blocks.14.norm_1.weight": "model-00004-of-00007.safetensors",
|
75 |
+
"transformer.blocks.14.norm_2.weight": "model-00004-of-00007.safetensors",
|
76 |
+
"transformer.blocks.15.attn.Wqkv.bias": "model-00005-of-00007.safetensors",
|
77 |
+
"transformer.blocks.15.attn.Wqkv.weight": "model-00005-of-00007.safetensors",
|
78 |
+
"transformer.blocks.15.attn.out_proj.bias": "model-00005-of-00007.safetensors",
|
79 |
+
"transformer.blocks.15.attn.out_proj.weight": "model-00005-of-00007.safetensors",
|
80 |
+
"transformer.blocks.15.ffn.down_proj.bias": "model-00005-of-00007.safetensors",
|
81 |
+
"transformer.blocks.15.ffn.down_proj.weight": "model-00005-of-00007.safetensors",
|
82 |
+
"transformer.blocks.15.ffn.up_proj.bias": "model-00005-of-00007.safetensors",
|
83 |
+
"transformer.blocks.15.ffn.up_proj.weight": "model-00005-of-00007.safetensors",
|
84 |
+
"transformer.blocks.15.norm_1.weight": "model-00005-of-00007.safetensors",
|
85 |
+
"transformer.blocks.15.norm_2.weight": "model-00005-of-00007.safetensors",
|
86 |
+
"transformer.blocks.16.attn.Wqkv.bias": "model-00005-of-00007.safetensors",
|
87 |
+
"transformer.blocks.16.attn.Wqkv.weight": "model-00005-of-00007.safetensors",
|
88 |
+
"transformer.blocks.16.attn.out_proj.bias": "model-00005-of-00007.safetensors",
|
89 |
+
"transformer.blocks.16.attn.out_proj.weight": "model-00005-of-00007.safetensors",
|
90 |
+
"transformer.blocks.16.ffn.down_proj.bias": "model-00005-of-00007.safetensors",
|
91 |
+
"transformer.blocks.16.ffn.down_proj.weight": "model-00005-of-00007.safetensors",
|
92 |
+
"transformer.blocks.16.ffn.up_proj.bias": "model-00005-of-00007.safetensors",
|
93 |
+
"transformer.blocks.16.ffn.up_proj.weight": "model-00005-of-00007.safetensors",
|
94 |
+
"transformer.blocks.16.norm_1.weight": "model-00005-of-00007.safetensors",
|
95 |
+
"transformer.blocks.16.norm_2.weight": "model-00005-of-00007.safetensors",
|
96 |
+
"transformer.blocks.17.attn.Wqkv.bias": "model-00005-of-00007.safetensors",
|
97 |
+
"transformer.blocks.17.attn.Wqkv.weight": "model-00005-of-00007.safetensors",
|
98 |
+
"transformer.blocks.17.attn.out_proj.bias": "model-00005-of-00007.safetensors",
|
99 |
+
"transformer.blocks.17.attn.out_proj.weight": "model-00005-of-00007.safetensors",
|
100 |
+
"transformer.blocks.17.ffn.down_proj.bias": "model-00005-of-00007.safetensors",
|
101 |
+
"transformer.blocks.17.ffn.down_proj.weight": "model-00005-of-00007.safetensors",
|
102 |
+
"transformer.blocks.17.ffn.up_proj.bias": "model-00005-of-00007.safetensors",
|
103 |
+
"transformer.blocks.17.ffn.up_proj.weight": "model-00005-of-00007.safetensors",
|
104 |
+
"transformer.blocks.17.norm_1.weight": "model-00005-of-00007.safetensors",
|
105 |
+
"transformer.blocks.17.norm_2.weight": "model-00005-of-00007.safetensors",
|
106 |
+
"transformer.blocks.18.attn.Wqkv.bias": "model-00005-of-00007.safetensors",
|
107 |
+
"transformer.blocks.18.attn.Wqkv.weight": "model-00005-of-00007.safetensors",
|
108 |
+
"transformer.blocks.18.attn.out_proj.bias": "model-00005-of-00007.safetensors",
|
109 |
+
"transformer.blocks.18.attn.out_proj.weight": "model-00005-of-00007.safetensors",
|
110 |
+
"transformer.blocks.18.ffn.down_proj.bias": "model-00005-of-00007.safetensors",
|
111 |
+
"transformer.blocks.18.ffn.down_proj.weight": "model-00005-of-00007.safetensors",
|
112 |
+
"transformer.blocks.18.ffn.up_proj.bias": "model-00005-of-00007.safetensors",
|
113 |
+
"transformer.blocks.18.ffn.up_proj.weight": "model-00005-of-00007.safetensors",
|
114 |
+
"transformer.blocks.18.norm_1.weight": "model-00005-of-00007.safetensors",
|
115 |
+
"transformer.blocks.18.norm_2.weight": "model-00005-of-00007.safetensors",
|
116 |
+
"transformer.blocks.19.attn.Wqkv.bias": "model-00005-of-00007.safetensors",
|
117 |
+
"transformer.blocks.19.attn.Wqkv.weight": "model-00005-of-00007.safetensors",
|
118 |
+
"transformer.blocks.19.attn.out_proj.bias": "model-00005-of-00007.safetensors",
|
119 |
+
"transformer.blocks.19.attn.out_proj.weight": "model-00005-of-00007.safetensors",
|
120 |
+
"transformer.blocks.19.ffn.down_proj.bias": "model-00006-of-00007.safetensors",
|
121 |
+
"transformer.blocks.19.ffn.down_proj.weight": "model-00006-of-00007.safetensors",
|
122 |
+
"transformer.blocks.19.ffn.up_proj.bias": "model-00006-of-00007.safetensors",
|
123 |
+
"transformer.blocks.19.ffn.up_proj.weight": "model-00006-of-00007.safetensors",
|
124 |
+
"transformer.blocks.19.norm_1.weight": "model-00005-of-00007.safetensors",
|
125 |
+
"transformer.blocks.19.norm_2.weight": "model-00005-of-00007.safetensors",
|
126 |
+
"transformer.blocks.2.attn.Wqkv.bias": "model-00002-of-00007.safetensors",
|
127 |
+
"transformer.blocks.2.attn.Wqkv.weight": "model-00002-of-00007.safetensors",
|
128 |
+
"transformer.blocks.2.attn.out_proj.bias": "model-00002-of-00007.safetensors",
|
129 |
+
"transformer.blocks.2.attn.out_proj.weight": "model-00002-of-00007.safetensors",
|
130 |
+
"transformer.blocks.2.ffn.down_proj.bias": "model-00002-of-00007.safetensors",
|
131 |
+
"transformer.blocks.2.ffn.down_proj.weight": "model-00002-of-00007.safetensors",
|
132 |
+
"transformer.blocks.2.ffn.up_proj.bias": "model-00002-of-00007.safetensors",
|
133 |
+
"transformer.blocks.2.ffn.up_proj.weight": "model-00002-of-00007.safetensors",
|
134 |
+
"transformer.blocks.2.norm_1.weight": "model-00002-of-00007.safetensors",
|
135 |
+
"transformer.blocks.2.norm_2.weight": "model-00002-of-00007.safetensors",
|
136 |
+
"transformer.blocks.20.attn.Wqkv.bias": "model-00006-of-00007.safetensors",
|
137 |
+
"transformer.blocks.20.attn.Wqkv.weight": "model-00006-of-00007.safetensors",
|
138 |
+
"transformer.blocks.20.attn.out_proj.bias": "model-00006-of-00007.safetensors",
|
139 |
+
"transformer.blocks.20.attn.out_proj.weight": "model-00006-of-00007.safetensors",
|
140 |
+
"transformer.blocks.20.ffn.down_proj.bias": "model-00006-of-00007.safetensors",
|
141 |
+
"transformer.blocks.20.ffn.down_proj.weight": "model-00006-of-00007.safetensors",
|
142 |
+
"transformer.blocks.20.ffn.up_proj.bias": "model-00006-of-00007.safetensors",
|
143 |
+
"transformer.blocks.20.ffn.up_proj.weight": "model-00006-of-00007.safetensors",
|
144 |
+
"transformer.blocks.20.norm_1.weight": "model-00006-of-00007.safetensors",
|
145 |
+
"transformer.blocks.20.norm_2.weight": "model-00006-of-00007.safetensors",
|
146 |
+
"transformer.blocks.21.attn.Wqkv.bias": "model-00006-of-00007.safetensors",
|
147 |
+
"transformer.blocks.21.attn.Wqkv.weight": "model-00006-of-00007.safetensors",
|
148 |
+
"transformer.blocks.21.attn.out_proj.bias": "model-00006-of-00007.safetensors",
|
149 |
+
"transformer.blocks.21.attn.out_proj.weight": "model-00006-of-00007.safetensors",
|
150 |
+
"transformer.blocks.21.ffn.down_proj.bias": "model-00006-of-00007.safetensors",
|
151 |
+
"transformer.blocks.21.ffn.down_proj.weight": "model-00006-of-00007.safetensors",
|
152 |
+
"transformer.blocks.21.ffn.up_proj.bias": "model-00006-of-00007.safetensors",
|
153 |
+
"transformer.blocks.21.ffn.up_proj.weight": "model-00006-of-00007.safetensors",
|
154 |
+
"transformer.blocks.21.norm_1.weight": "model-00006-of-00007.safetensors",
|
155 |
+
"transformer.blocks.21.norm_2.weight": "model-00006-of-00007.safetensors",
|
156 |
+
"transformer.blocks.22.attn.Wqkv.bias": "model-00006-of-00007.safetensors",
|
157 |
+
"transformer.blocks.22.attn.Wqkv.weight": "model-00006-of-00007.safetensors",
|
158 |
+
"transformer.blocks.22.attn.out_proj.bias": "model-00006-of-00007.safetensors",
|
159 |
+
"transformer.blocks.22.attn.out_proj.weight": "model-00006-of-00007.safetensors",
|
160 |
+
"transformer.blocks.22.ffn.down_proj.bias": "model-00006-of-00007.safetensors",
|
161 |
+
"transformer.blocks.22.ffn.down_proj.weight": "model-00006-of-00007.safetensors",
|
162 |
+
"transformer.blocks.22.ffn.up_proj.bias": "model-00006-of-00007.safetensors",
|
163 |
+
"transformer.blocks.22.ffn.up_proj.weight": "model-00006-of-00007.safetensors",
|
164 |
+
"transformer.blocks.22.norm_1.weight": "model-00006-of-00007.safetensors",
|
165 |
+
"transformer.blocks.22.norm_2.weight": "model-00006-of-00007.safetensors",
|
166 |
+
"transformer.blocks.23.attn.Wqkv.bias": "model-00006-of-00007.safetensors",
|
167 |
+
"transformer.blocks.23.attn.Wqkv.weight": "model-00006-of-00007.safetensors",
|
168 |
+
"transformer.blocks.23.attn.out_proj.bias": "model-00006-of-00007.safetensors",
|
169 |
+
"transformer.blocks.23.attn.out_proj.weight": "model-00006-of-00007.safetensors",
|
170 |
+
"transformer.blocks.23.ffn.down_proj.bias": "model-00006-of-00007.safetensors",
|
171 |
+
"transformer.blocks.23.ffn.down_proj.weight": "model-00006-of-00007.safetensors",
|
172 |
+
"transformer.blocks.23.ffn.up_proj.bias": "model-00006-of-00007.safetensors",
|
173 |
+
"transformer.blocks.23.ffn.up_proj.weight": "model-00006-of-00007.safetensors",
|
174 |
+
"transformer.blocks.23.norm_1.weight": "model-00006-of-00007.safetensors",
|
175 |
+
"transformer.blocks.23.norm_2.weight": "model-00006-of-00007.safetensors",
|
176 |
+
"transformer.blocks.24.attn.Wqkv.bias": "model-00006-of-00007.safetensors",
|
177 |
+
"transformer.blocks.24.attn.Wqkv.weight": "model-00006-of-00007.safetensors",
|
178 |
+
"transformer.blocks.24.attn.out_proj.bias": "model-00007-of-00007.safetensors",
|
179 |
+
"transformer.blocks.24.attn.out_proj.weight": "model-00007-of-00007.safetensors",
|
180 |
+
"transformer.blocks.24.ffn.down_proj.bias": "model-00007-of-00007.safetensors",
|
181 |
+
"transformer.blocks.24.ffn.down_proj.weight": "model-00007-of-00007.safetensors",
|
182 |
+
"transformer.blocks.24.ffn.up_proj.bias": "model-00007-of-00007.safetensors",
|
183 |
+
"transformer.blocks.24.ffn.up_proj.weight": "model-00007-of-00007.safetensors",
|
184 |
+
"transformer.blocks.24.norm_1.weight": "model-00006-of-00007.safetensors",
|
185 |
+
"transformer.blocks.24.norm_2.weight": "model-00007-of-00007.safetensors",
|
186 |
+
"transformer.blocks.25.attn.Wqkv.bias": "model-00007-of-00007.safetensors",
|
187 |
+
"transformer.blocks.25.attn.Wqkv.weight": "model-00007-of-00007.safetensors",
|
188 |
+
"transformer.blocks.25.attn.out_proj.bias": "model-00007-of-00007.safetensors",
|
189 |
+
"transformer.blocks.25.attn.out_proj.weight": "model-00007-of-00007.safetensors",
|
190 |
+
"transformer.blocks.25.ffn.down_proj.bias": "model-00007-of-00007.safetensors",
|
191 |
+
"transformer.blocks.25.ffn.down_proj.weight": "model-00007-of-00007.safetensors",
|
192 |
+
"transformer.blocks.25.ffn.up_proj.bias": "model-00007-of-00007.safetensors",
|
193 |
+
"transformer.blocks.25.ffn.up_proj.weight": "model-00007-of-00007.safetensors",
|
194 |
+
"transformer.blocks.25.norm_1.weight": "model-00007-of-00007.safetensors",
|
195 |
+
"transformer.blocks.25.norm_2.weight": "model-00007-of-00007.safetensors",
|
196 |
+
"transformer.blocks.26.attn.Wqkv.bias": "model-00007-of-00007.safetensors",
|
197 |
+
"transformer.blocks.26.attn.Wqkv.weight": "model-00007-of-00007.safetensors",
|
198 |
+
"transformer.blocks.26.attn.out_proj.bias": "model-00007-of-00007.safetensors",
|
199 |
+
"transformer.blocks.26.attn.out_proj.weight": "model-00007-of-00007.safetensors",
|
200 |
+
"transformer.blocks.26.ffn.down_proj.bias": "model-00007-of-00007.safetensors",
|
201 |
+
"transformer.blocks.26.ffn.down_proj.weight": "model-00007-of-00007.safetensors",
|
202 |
+
"transformer.blocks.26.ffn.up_proj.bias": "model-00007-of-00007.safetensors",
|
203 |
+
"transformer.blocks.26.ffn.up_proj.weight": "model-00007-of-00007.safetensors",
|
204 |
+
"transformer.blocks.26.norm_1.weight": "model-00007-of-00007.safetensors",
|
205 |
+
"transformer.blocks.26.norm_2.weight": "model-00007-of-00007.safetensors",
|
206 |
+
"transformer.blocks.27.attn.Wqkv.bias": "model-00007-of-00007.safetensors",
|
207 |
+
"transformer.blocks.27.attn.Wqkv.weight": "model-00007-of-00007.safetensors",
|
208 |
+
"transformer.blocks.27.attn.out_proj.bias": "model-00007-of-00007.safetensors",
|
209 |
+
"transformer.blocks.27.attn.out_proj.weight": "model-00007-of-00007.safetensors",
|
210 |
+
"transformer.blocks.27.ffn.down_proj.bias": "model-00007-of-00007.safetensors",
|
211 |
+
"transformer.blocks.27.ffn.down_proj.weight": "model-00007-of-00007.safetensors",
|
212 |
+
"transformer.blocks.27.ffn.up_proj.bias": "model-00007-of-00007.safetensors",
|
213 |
+
"transformer.blocks.27.ffn.up_proj.weight": "model-00007-of-00007.safetensors",
|
214 |
+
"transformer.blocks.27.norm_1.weight": "model-00007-of-00007.safetensors",
|
215 |
+
"transformer.blocks.27.norm_2.weight": "model-00007-of-00007.safetensors",
|
216 |
+
"transformer.blocks.3.attn.Wqkv.bias": "model-00002-of-00007.safetensors",
|
217 |
+
"transformer.blocks.3.attn.Wqkv.weight": "model-00002-of-00007.safetensors",
|
218 |
+
"transformer.blocks.3.attn.out_proj.bias": "model-00002-of-00007.safetensors",
|
219 |
+
"transformer.blocks.3.attn.out_proj.weight": "model-00002-of-00007.safetensors",
|
220 |
+
"transformer.blocks.3.ffn.down_proj.bias": "model-00002-of-00007.safetensors",
|
221 |
+
"transformer.blocks.3.ffn.down_proj.weight": "model-00002-of-00007.safetensors",
|
222 |
+
"transformer.blocks.3.ffn.up_proj.bias": "model-00002-of-00007.safetensors",
|
223 |
+
"transformer.blocks.3.ffn.up_proj.weight": "model-00002-of-00007.safetensors",
|
224 |
+
"transformer.blocks.3.norm_1.weight": "model-00002-of-00007.safetensors",
|
225 |
+
"transformer.blocks.3.norm_2.weight": "model-00002-of-00007.safetensors",
|
226 |
+
"transformer.blocks.4.attn.Wqkv.bias": "model-00002-of-00007.safetensors",
|
227 |
+
"transformer.blocks.4.attn.Wqkv.weight": "model-00002-of-00007.safetensors",
|
228 |
+
"transformer.blocks.4.attn.out_proj.bias": "model-00002-of-00007.safetensors",
|
229 |
+
"transformer.blocks.4.attn.out_proj.weight": "model-00002-of-00007.safetensors",
|
230 |
+
"transformer.blocks.4.ffn.down_proj.bias": "model-00002-of-00007.safetensors",
|
231 |
+
"transformer.blocks.4.ffn.down_proj.weight": "model-00002-of-00007.safetensors",
|
232 |
+
"transformer.blocks.4.ffn.up_proj.bias": "model-00002-of-00007.safetensors",
|
233 |
+
"transformer.blocks.4.ffn.up_proj.weight": "model-00002-of-00007.safetensors",
|
234 |
+
"transformer.blocks.4.norm_1.weight": "model-00002-of-00007.safetensors",
|
235 |
+
"transformer.blocks.4.norm_2.weight": "model-00002-of-00007.safetensors",
|
236 |
+
"transformer.blocks.5.attn.Wqkv.bias": "model-00002-of-00007.safetensors",
|
237 |
+
"transformer.blocks.5.attn.Wqkv.weight": "model-00002-of-00007.safetensors",
|
238 |
+
"transformer.blocks.5.attn.out_proj.bias": "model-00003-of-00007.safetensors",
|
239 |
+
"transformer.blocks.5.attn.out_proj.weight": "model-00003-of-00007.safetensors",
|
240 |
+
"transformer.blocks.5.ffn.down_proj.bias": "model-00003-of-00007.safetensors",
|
241 |
+
"transformer.blocks.5.ffn.down_proj.weight": "model-00003-of-00007.safetensors",
|
242 |
+
"transformer.blocks.5.ffn.up_proj.bias": "model-00003-of-00007.safetensors",
|
243 |
+
"transformer.blocks.5.ffn.up_proj.weight": "model-00003-of-00007.safetensors",
|
244 |
+
"transformer.blocks.5.norm_1.weight": "model-00002-of-00007.safetensors",
|
245 |
+
"transformer.blocks.5.norm_2.weight": "model-00003-of-00007.safetensors",
|
246 |
+
"transformer.blocks.6.attn.Wqkv.bias": "model-00003-of-00007.safetensors",
|
247 |
+
"transformer.blocks.6.attn.Wqkv.weight": "model-00003-of-00007.safetensors",
|
248 |
+
"transformer.blocks.6.attn.out_proj.bias": "model-00003-of-00007.safetensors",
|
249 |
+
"transformer.blocks.6.attn.out_proj.weight": "model-00003-of-00007.safetensors",
|
250 |
+
"transformer.blocks.6.ffn.down_proj.bias": "model-00003-of-00007.safetensors",
|
251 |
+
"transformer.blocks.6.ffn.down_proj.weight": "model-00003-of-00007.safetensors",
|
252 |
+
"transformer.blocks.6.ffn.up_proj.bias": "model-00003-of-00007.safetensors",
|
253 |
+
"transformer.blocks.6.ffn.up_proj.weight": "model-00003-of-00007.safetensors",
|
254 |
+
"transformer.blocks.6.norm_1.weight": "model-00003-of-00007.safetensors",
|
255 |
+
"transformer.blocks.6.norm_2.weight": "model-00003-of-00007.safetensors",
|
256 |
+
"transformer.blocks.7.attn.Wqkv.bias": "model-00003-of-00007.safetensors",
|
257 |
+
"transformer.blocks.7.attn.Wqkv.weight": "model-00003-of-00007.safetensors",
|
258 |
+
"transformer.blocks.7.attn.out_proj.bias": "model-00003-of-00007.safetensors",
|
259 |
+
"transformer.blocks.7.attn.out_proj.weight": "model-00003-of-00007.safetensors",
|
260 |
+
"transformer.blocks.7.ffn.down_proj.bias": "model-00003-of-00007.safetensors",
|
261 |
+
"transformer.blocks.7.ffn.down_proj.weight": "model-00003-of-00007.safetensors",
|
262 |
+
"transformer.blocks.7.ffn.up_proj.bias": "model-00003-of-00007.safetensors",
|
263 |
+
"transformer.blocks.7.ffn.up_proj.weight": "model-00003-of-00007.safetensors",
|
264 |
+
"transformer.blocks.7.norm_1.weight": "model-00003-of-00007.safetensors",
|
265 |
+
"transformer.blocks.7.norm_2.weight": "model-00003-of-00007.safetensors",
|
266 |
+
"transformer.blocks.8.attn.Wqkv.bias": "model-00003-of-00007.safetensors",
|
267 |
+
"transformer.blocks.8.attn.Wqkv.weight": "model-00003-of-00007.safetensors",
|
268 |
+
"transformer.blocks.8.attn.out_proj.bias": "model-00003-of-00007.safetensors",
|
269 |
+
"transformer.blocks.8.attn.out_proj.weight": "model-00003-of-00007.safetensors",
|
270 |
+
"transformer.blocks.8.ffn.down_proj.bias": "model-00003-of-00007.safetensors",
|
271 |
+
"transformer.blocks.8.ffn.down_proj.weight": "model-00003-of-00007.safetensors",
|
272 |
+
"transformer.blocks.8.ffn.up_proj.bias": "model-00003-of-00007.safetensors",
|
273 |
+
"transformer.blocks.8.ffn.up_proj.weight": "model-00003-of-00007.safetensors",
|
274 |
+
"transformer.blocks.8.norm_1.weight": "model-00003-of-00007.safetensors",
|
275 |
+
"transformer.blocks.8.norm_2.weight": "model-00003-of-00007.safetensors",
|
276 |
+
"transformer.blocks.9.attn.Wqkv.bias": "model-00003-of-00007.safetensors",
|
277 |
+
"transformer.blocks.9.attn.Wqkv.weight": "model-00003-of-00007.safetensors",
|
278 |
+
"transformer.blocks.9.attn.out_proj.bias": "model-00003-of-00007.safetensors",
|
279 |
+
"transformer.blocks.9.attn.out_proj.weight": "model-00003-of-00007.safetensors",
|
280 |
+
"transformer.blocks.9.ffn.down_proj.bias": "model-00003-of-00007.safetensors",
|
281 |
+
"transformer.blocks.9.ffn.down_proj.weight": "model-00003-of-00007.safetensors",
|
282 |
+
"transformer.blocks.9.ffn.up_proj.bias": "model-00003-of-00007.safetensors",
|
283 |
+
"transformer.blocks.9.ffn.up_proj.weight": "model-00003-of-00007.safetensors",
|
284 |
+
"transformer.blocks.9.norm_1.weight": "model-00003-of-00007.safetensors",
|
285 |
+
"transformer.blocks.9.norm_2.weight": "model-00003-of-00007.safetensors",
|
286 |
+
"transformer.norm_f.weight": "model-00007-of-00007.safetensors",
|
287 |
+
"transformer.wte.weight": "model-00001-of-00007.safetensors"
|
288 |
+
}
|
289 |
+
}
|
modeling_mpt.py
ADDED
@@ -0,0 +1,696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A simple, flexible implementation of a GPT model.
|
2 |
+
|
3 |
+
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
|
4 |
+
"""
|
5 |
+
from __future__ import annotations
|
6 |
+
from .dmoe import _UniformExpertAssignment
|
7 |
+
from .ffn import quickgelu_activation
|
8 |
+
from .config_defaults import *
|
9 |
+
from .registry_utils import TypedRegistry
|
10 |
+
from .warnings import VersionedDeprecationWarning
|
11 |
+
import copy
|
12 |
+
import math
|
13 |
+
import warnings
|
14 |
+
from functools import cached_property
|
15 |
+
from typing import Any, Mapping, MutableMapping, Optional, Union
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from tabulate import tabulate
|
20 |
+
from .layers_registry import ffns_with_megablocks
|
21 |
+
from .attention import is_flash_v2_installed
|
22 |
+
if is_flash_v2_installed():
|
23 |
+
try:
|
24 |
+
from flash_attn import bert_padding
|
25 |
+
from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding
|
26 |
+
except Exception as e:
|
27 |
+
raise e
|
28 |
+
import logging
|
29 |
+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
30 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
31 |
+
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding
|
32 |
+
from .layers_registry import norms, param_init_fns
|
33 |
+
from .attention import attn_bias_shape, build_attn_bias, gen_slopes
|
34 |
+
from .blocks import MPTBlock
|
35 |
+
from .custom_embedding import SharedEmbedding
|
36 |
+
from .layer_builders import build_norm
|
37 |
+
from .configuration_mpt import MPTConfig
|
38 |
+
from .act_ckpt import build_act_ckpt_mod_to_blocks, check_mapping_blocks_overlap, pass_on_block_idx
|
39 |
+
from .config_moe_args import config_moe_args
|
40 |
+
from .mpt_param_count import mpt_get_active_params, mpt_get_total_params
|
41 |
+
from .fc import fcs
|
42 |
+
from .param_init_fns import generic_param_init_fn_
|
43 |
+
from .norm import LPLayerNorm
|
44 |
+
log = logging.getLogger(__name__)
|
45 |
+
CROSS_ENTROPY_IGNORE_INDEX = -100
|
46 |
+
|
47 |
+
class InvalidConfigAccessError(KeyError):
|
48 |
+
pass
|
49 |
+
_ALLOWED_LLAMA_CONFIG_KEYS = {'rope_scaling', 'rope_theta', 'max_position_embeddings', 'hidden_size', 'num_attention_heads', 'partial_rotary_factor', 'head_dim', '_get_generation_defaults', 'label2id', 'id2label', 'torch_dtype', 'problem_type', '__class__', '_get_global_generation_defaults'}
|
50 |
+
|
51 |
+
class PartialLlamaConfig(LlamaConfig):
|
52 |
+
"""Holds the rope config for Llama models and throws.
|
53 |
+
|
54 |
+
an `InvalidConfigAccessError` if any other config elements are read. This
|
55 |
+
class is necessary because the `LlamaRotaryEmbedding` class takes a full
|
56 |
+
`LlamaConfig` now instead of the old keyword arguments.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __getattribute__(self, key: str):
|
60 |
+
if key not in _ALLOWED_LLAMA_CONFIG_KEYS:
|
61 |
+
raise InvalidConfigAccessError(key)
|
62 |
+
return super().__getattribute__(key)
|
63 |
+
|
64 |
+
def __getitem__(self, key: str):
|
65 |
+
if key not in _ALLOWED_LLAMA_CONFIG_KEYS:
|
66 |
+
raise InvalidConfigAccessError(key)
|
67 |
+
return super().__getitem__(key)
|
68 |
+
|
69 |
+
def gen_rotary_embedding(rope_impl: str, rope_theta: int, rope_dail_config: dict, rope_hf_config: dict, max_seq_len: int, d_model: int, n_heads: int):
|
70 |
+
rope_head_dim = d_model // n_heads
|
71 |
+
if rope_impl == 'dail':
|
72 |
+
return DAILRotaryEmbedding(dim=rope_head_dim, base=rope_theta, interleaved=False, scale_base=rope_dail_config['xpos_scale_base'] if rope_dail_config['type'] == 'xpos' else None, pos_idx_in_fp32=rope_dail_config['pos_idx_in_fp32'], device='cpu')
|
73 |
+
elif rope_impl == 'hf':
|
74 |
+
llama_rope_config = {**rope_hf_config}
|
75 |
+
llama_rope_config['rope_type'] = llama_rope_config.pop('type')
|
76 |
+
if llama_rope_config['rope_type'] == 'no_scaling':
|
77 |
+
llama_rope_config['rope_type'] = 'default'
|
78 |
+
partial_llama_config = PartialLlamaConfig(rope_scaling=llama_rope_config, rope_theta=rope_theta, max_position_embeddings=max_seq_len, hidden_size=d_model, num_attention_heads=n_heads)
|
79 |
+
return LlamaRotaryEmbeddingFoundry(config=partial_llama_config)
|
80 |
+
raise ValueError('rope_impl needs to be either dail or hf')
|
81 |
+
|
82 |
+
def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, attn_uses_sequence_id: bool, attn_impl: str, attention_mask: Union[torch.Tensor, None]):
|
83 |
+
"""Generates the attention mask used for sequence masking in FA v2.
|
84 |
+
|
85 |
+
Only supports sequence id based sparse attention for no attention masking or attention masking with right padding.
|
86 |
+
In case of left padding:
|
87 |
+
1. Training with left padding is not supported in MPT (see https://github.com/mosaicml/llm-foundry/blob/1eecd4cb8e734499f77f6a35f657b8b20c0adfcb/llmfoundry/models/mpt/modeling_mpt.py#L407).
|
88 |
+
2. For generation with left padding, we only have a single sequence id per sample, so we don't need sequence id based sparse attention.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
sequence_id (Union[None, torch.Tensor]): Tensor containing the sequence id for each token. Shape (batch_size, seq_len).
|
92 |
+
S (int): Sequence length
|
93 |
+
attn_uses_sequence_id (bool): Whether the attention uses sequence id based masking.
|
94 |
+
attn_impl (str): Attention implementation. This function is only creates attention_mask_in_length for flash attention.
|
95 |
+
attention_mask (Union[torch.Tensor, None]): Attention mask tensor of shape (batch_size, seq_len)
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
|
99 |
+
```
|
100 |
+
[
|
101 |
+
[2, 3, 0, 0, 0, 0],
|
102 |
+
[3, 2, 0, 0, 0, 0],
|
103 |
+
[6, 0, 0, 0, 0, 0]
|
104 |
+
]
|
105 |
+
```
|
106 |
+
, which refers to the 3D-attention mask:
|
107 |
+
```
|
108 |
+
[
|
109 |
+
[
|
110 |
+
[1, 0, 0, 0, 0, 0],
|
111 |
+
[1, 1, 0, 0, 0, 0],
|
112 |
+
[0, 0, 1, 0, 0, 0],
|
113 |
+
[0, 0, 1, 1, 0, 0],
|
114 |
+
[0, 0, 1, 1, 1, 0],
|
115 |
+
[0, 0, 0, 0, 0, 1]
|
116 |
+
],
|
117 |
+
[
|
118 |
+
[1, 0, 0, 0, 0, 0],
|
119 |
+
[1, 1, 0, 0, 0, 0],
|
120 |
+
[1, 1, 1, 0, 0, 0],
|
121 |
+
[0, 0, 0, 1, 0, 0],
|
122 |
+
[0, 0, 0, 1, 1, 0],
|
123 |
+
[0, 0, 0, 0, 0, 1]
|
124 |
+
],
|
125 |
+
[
|
126 |
+
[1, 0, 0, 0, 0, 0],
|
127 |
+
[1, 1, 0, 0, 0, 0],
|
128 |
+
[1, 1, 1, 0, 0, 0],
|
129 |
+
[1, 1, 1, 1, 0, 0],
|
130 |
+
[1, 1, 1, 1, 1, 0],
|
131 |
+
[1, 1, 1, 1, 1, 1]
|
132 |
+
]
|
133 |
+
]
|
134 |
+
```.
|
135 |
+
(The description above is taken verbatim from https://github.com/Dao-AILab/flash-attention/blob/9356a1c0389660d7e231ff3163c1ac17d9e3824a/flash_attn/bert_padding.py#L125 .)
|
136 |
+
"""
|
137 |
+
attention_mask_in_length = None
|
138 |
+
if sequence_id is not None and attn_uses_sequence_id and (attn_impl == 'flash'):
|
139 |
+
if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0]:
|
140 |
+
raise NotImplementedError('Left padding is not supported with flash attention when attn_uses_sequence_id is set to True.')
|
141 |
+
if S != sequence_id.shape[-1]:
|
142 |
+
raise ValueError(f'Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]}).')
|
143 |
+
if attention_mask is not None:
|
144 |
+
sequence_id = sequence_id.masked_fill(~attention_mask, 0)
|
145 |
+
attention_mask_in_length = torch.nn.functional.one_hot(sequence_id)
|
146 |
+
if attention_mask is not None:
|
147 |
+
attention_mask_in_length = attention_mask_in_length.masked_fill(~attention_mask.unsqueeze(-1), 0)
|
148 |
+
attention_mask_in_length = attention_mask_in_length.sum(dim=1)
|
149 |
+
attention_mask_in_length = torch.nn.functional.pad(attention_mask_in_length, (0, S - attention_mask_in_length.shape[-1]), mode='constant', value=0)
|
150 |
+
return attention_mask_in_length
|
151 |
+
|
152 |
+
def gen_flash_attn_padding_info(bsz: int, S: int, past_key_len: int, device: torch.device, attention_mask_in_length: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None):
|
153 |
+
flash_attn_padding_info = {}
|
154 |
+
if attention_mask_in_length is None:
|
155 |
+
key_padding_mask = attention_mask
|
156 |
+
if key_padding_mask is None:
|
157 |
+
key_padding_mask = torch.ones((bsz, past_key_len + S), dtype=torch.bool, device=device)
|
158 |
+
query_padding_mask = key_padding_mask[:, -S:]
|
159 |
+
unpadding_function = bert_padding.unpad_input
|
160 |
+
else:
|
161 |
+
key_padding_mask = attention_mask_in_length
|
162 |
+
query_padding_mask = attention_mask_in_length
|
163 |
+
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
|
164 |
+
_, indices_q, cu_seqlens_q, max_seqlen_q, *_ = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
|
165 |
+
_, indices_k, cu_seqlens_k, max_seqlen_k, *_ = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
|
166 |
+
_, indices_v, *_ = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
|
167 |
+
flash_attn_padding_info['indices_q'] = indices_q
|
168 |
+
flash_attn_padding_info['indices_k'] = indices_k
|
169 |
+
flash_attn_padding_info['indices_v'] = indices_v
|
170 |
+
flash_attn_padding_info['cu_seqlens_q'] = cu_seqlens_q
|
171 |
+
flash_attn_padding_info['cu_seqlens_k'] = cu_seqlens_k
|
172 |
+
flash_attn_padding_info['max_seqlen_q'] = max_seqlen_q
|
173 |
+
flash_attn_padding_info['max_seqlen_k'] = max_seqlen_k
|
174 |
+
return flash_attn_padding_info
|
175 |
+
|
176 |
+
def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor, max_seq_len: int) -> torch.Tensor:
|
177 |
+
seq_len = sequence_id.shape[-1]
|
178 |
+
if seq_len > max_seq_len:
|
179 |
+
raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={max_seq_len}')
|
180 |
+
attn_bias = attn_bias[..., :seq_len, :seq_len]
|
181 |
+
cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
|
182 |
+
min_val = torch.finfo(attn_bias.dtype).min
|
183 |
+
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
|
184 |
+
return attn_bias
|
185 |
+
|
186 |
+
class LlamaRotaryEmbeddingFoundry(LlamaRotaryEmbedding):
|
187 |
+
|
188 |
+
@torch.no_grad()
|
189 |
+
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
190 |
+
self.inv_freq = self.inv_freq.to(position_ids.device)
|
191 |
+
return super().forward(x=x, position_ids=position_ids)
|
192 |
+
|
193 |
+
class MPTPreTrainedModel(PreTrainedModel):
|
194 |
+
config_class = MPTConfig
|
195 |
+
base_model_prefix = 'model'
|
196 |
+
_no_split_modules = ['MPTBlock']
|
197 |
+
|
198 |
+
def _fsdp_wrap_fn(self: Union[MPTModel, MPTForCausalLM], module: nn.Module) -> bool:
|
199 |
+
if hasattr(module, '_fsdp_kwargs_dict'):
|
200 |
+
return module._fsdp_kwargs_dict
|
201 |
+
return isinstance(module, MPTBlock)
|
202 |
+
|
203 |
+
class MPTModel(MPTPreTrainedModel):
|
204 |
+
|
205 |
+
def __init__(self, config: MPTConfig):
|
206 |
+
config._validate_config()
|
207 |
+
super().__init__(config)
|
208 |
+
self.attn_impl = config.attn_config['attn_impl']
|
209 |
+
self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
|
210 |
+
self.alibi = config.attn_config['alibi']
|
211 |
+
self.alibi_bias_max = config.attn_config['alibi_bias_max']
|
212 |
+
self.learned_pos_emb = config.learned_pos_emb
|
213 |
+
if config.init_device == 'mixed':
|
214 |
+
if dist.get_local_rank() == 0:
|
215 |
+
config.init_device = 'cpu'
|
216 |
+
else:
|
217 |
+
config.init_device = 'meta'
|
218 |
+
if config.norm_type.lower() not in norms.get_all():
|
219 |
+
norm_options = ' | '.join(norms.get_all())
|
220 |
+
raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
|
221 |
+
self.embedding_fraction = config.embedding_fraction
|
222 |
+
self.wte = SharedEmbedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id, device=config.init_device)
|
223 |
+
if self.learned_pos_emb:
|
224 |
+
self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
|
225 |
+
self.emb_drop = nn.Dropout(config.emb_pdrop)
|
226 |
+
self.mb_args = None
|
227 |
+
self.shift_labels = True
|
228 |
+
self.blocks = self.construct_blocks(config=config)
|
229 |
+
for i, block in enumerate(self.blocks):
|
230 |
+
block.block_idx = i
|
231 |
+
block.max_block_idx = config.n_layers - 1
|
232 |
+
pass_on_block_idx(block)
|
233 |
+
self.norm_f = build_norm(name=config.norm_type.lower(), normalized_shape=config.d_model, eps=config.norm_eps, device=config.init_device)
|
234 |
+
self.rope = config.attn_config['rope']
|
235 |
+
self.rope_impl = None
|
236 |
+
if self.rope:
|
237 |
+
self.rope_impl = config.attn_config['rope_impl']
|
238 |
+
self.rotary_embedding = gen_rotary_embedding(rope_impl=self.rope_impl, rope_theta=config.attn_config['rope_theta'], rope_dail_config=config.attn_config['rope_dail_config'], rope_hf_config=config.attn_config['rope_hf_config'], max_seq_len=self.config.max_seq_len, d_model=config.d_model, n_heads=config.n_heads)
|
239 |
+
if config.init_device != 'meta':
|
240 |
+
log.info(f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.')
|
241 |
+
self.apply(self.param_init_fn)
|
242 |
+
self.is_causal = True
|
243 |
+
self._attn_bias_initialized = False
|
244 |
+
self.attn_bias = None
|
245 |
+
self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
|
246 |
+
if config.no_bias:
|
247 |
+
for module in self.modules():
|
248 |
+
if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
|
249 |
+
log.debug(f'Removing bias from module={module!r}.')
|
250 |
+
module.register_parameter('bias', None)
|
251 |
+
if hasattr(module, 'use_bias') and module.use_bias is True:
|
252 |
+
log.debug(f'Setting use_bias=False for module={module!r}.')
|
253 |
+
module.use_bias = False
|
254 |
+
log.debug(self)
|
255 |
+
init_config_name = self.config.init_config['name']
|
256 |
+
log.debug(f'Using {init_config_name} initialization.')
|
257 |
+
|
258 |
+
@property
|
259 |
+
def block_class(self) -> type[MPTBlock]:
|
260 |
+
return MPTBlock
|
261 |
+
|
262 |
+
def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
|
263 |
+
"""Construct the nn.ModuleList with the Transformer blocks.
|
264 |
+
|
265 |
+
Args:
|
266 |
+
config (MPTConfig): The configuration object.
|
267 |
+
|
268 |
+
Returns:
|
269 |
+
nn.ModuleList: The list of Transformer blocks.
|
270 |
+
"""
|
271 |
+
block_args = self.extract_block_args(config.to_dict())
|
272 |
+
self.kv_cache_layers = set()
|
273 |
+
self.blocks_fuse_norm_attn_norm = block_args.get('fuse_norm_attn_norm', False)
|
274 |
+
if config.block_overrides is not None:
|
275 |
+
block_args_list = self._get_override_block_args_list(config, block_args)
|
276 |
+
else:
|
277 |
+
block_args_list = [block_args for _ in range(config.n_layers)]
|
278 |
+
return nn.ModuleList([self.block_class(device=config.init_device, **block_args_i) for block_args_i in block_args_list])
|
279 |
+
|
280 |
+
def _get_override_block_args_list(self, config: MPTConfig, block_args: dict[str, Any]) -> list[dict[str, Any]]:
|
281 |
+
if config.block_overrides is None:
|
282 |
+
raise ValueError('config.block_overrides should not be None when calling _get_override_block_args_list.')
|
283 |
+
repeat = config.block_overrides.get('repeat', 1)
|
284 |
+
model_modules_order_expanded = MPTModel._get_modules_order_expanded(config.block_overrides['order']) * repeat
|
285 |
+
if len(model_modules_order_expanded) != config.n_layers:
|
286 |
+
raise ValueError(f'The specified block overrides do not match the number of layers: {len(model_modules_order_expanded)} vs {config.n_layers}.')
|
287 |
+
new_block_args_list = []
|
288 |
+
layer_description_list = []
|
289 |
+
reuse_kv_layer_idx_dict = {}
|
290 |
+
for b_idx in range(config.n_layers):
|
291 |
+
module_name = model_modules_order_expanded[b_idx]
|
292 |
+
override_config = {}
|
293 |
+
if module_name != 'default':
|
294 |
+
override_config = copy.deepcopy(config.block_overrides['overrides'][module_name])
|
295 |
+
if 'reuse_kv_layer_idx' in override_config.get('attn_config', {}):
|
296 |
+
reuse_kv_layer_idx = MPTModel._resolve_reuse_kv_layer_idx(overrides_definition=config.block_overrides['overrides'], model_modules_order_expanded=model_modules_order_expanded, b_idx=b_idx, override_config=override_config, reuse_kv_layer_idx_dict=reuse_kv_layer_idx_dict)
|
297 |
+
override_config['attn_config']['reuse_kv_layer_idx'] = reuse_kv_layer_idx
|
298 |
+
self.kv_cache_layers.add(reuse_kv_layer_idx)
|
299 |
+
layer_description_list.append([b_idx, module_name, override_config])
|
300 |
+
new_block_args_list.append(MPTModel._override_block_args(block_args, override_config, config.allowed_block_overrides))
|
301 |
+
log.info('The following is a summary of overrides per layer.\n' + tabulate(layer_description_list, headers=['idx', 'name', 'overrides']))
|
302 |
+
return new_block_args_list
|
303 |
+
|
304 |
+
@staticmethod
|
305 |
+
def _resolve_reuse_kv_layer_idx(overrides_definition: dict[str, Any], model_modules_order_expanded: list[str], b_idx: int, override_config: dict[str, Any], reuse_kv_layer_idx_dict: dict[int, int]) -> int:
|
306 |
+
override_attn_config = override_config['attn_config']
|
307 |
+
if override_attn_config['reuse_kv_layer_idx'] >= 0:
|
308 |
+
reuse_kv_layer_idx = override_attn_config['reuse_kv_layer_idx']
|
309 |
+
raise ValueError(f"The relative index of kv layer to reuse, override_attn_config['reuse_kv_layer_idx']={reuse_kv_layer_idx}, should be negative.")
|
310 |
+
reuse_kv_layer_idx = b_idx + override_attn_config['reuse_kv_layer_idx']
|
311 |
+
if reuse_kv_layer_idx < 0:
|
312 |
+
raise ValueError(f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.')
|
313 |
+
if reuse_kv_layer_idx in reuse_kv_layer_idx_dict:
|
314 |
+
reuse_kv_layer_idx = reuse_kv_layer_idx_dict[reuse_kv_layer_idx]
|
315 |
+
reuse_kv_layer_idx_dict[b_idx] = reuse_kv_layer_idx
|
316 |
+
parent_layer_name = model_modules_order_expanded[reuse_kv_layer_idx]
|
317 |
+
parent_config = {} if parent_layer_name == 'default' else copy.deepcopy(overrides_definition[parent_layer_name])
|
318 |
+
if 'attn_config' not in parent_config:
|
319 |
+
parent_config['attn_config'] = {}
|
320 |
+
parent_config['attn_config']['reuse_kv_layer_idx'] = override_config['attn_config']['reuse_kv_layer_idx']
|
321 |
+
if override_config != parent_config and (not ('allow_mismatch' in override_config and override_config['allow_mismatch'])):
|
322 |
+
raise ValueError('For reusing the kv cache of a previous layer, the previous layer should match the block config as the current layer.')
|
323 |
+
return reuse_kv_layer_idx
|
324 |
+
|
325 |
+
@staticmethod
|
326 |
+
def _get_modules_order_expanded(order: list[dict[str, Any]]) -> list[str]:
|
327 |
+
model_modules_order_expanded = []
|
328 |
+
for item in order:
|
329 |
+
repeat = item['repeat'] if 'repeat' in item else 1
|
330 |
+
if ('name' in item) == ('order' in item):
|
331 |
+
raise ValueError('Exactly one of `order` or `name` must be specified for each block override.')
|
332 |
+
if 'name' in item:
|
333 |
+
model_modules_order_expanded.extend([item['name']] * repeat)
|
334 |
+
else:
|
335 |
+
model_modules_order_expanded.extend(MPTModel._get_modules_order_expanded(item['order']) * repeat)
|
336 |
+
return model_modules_order_expanded
|
337 |
+
|
338 |
+
@staticmethod
|
339 |
+
def _override_block_args(block_args: dict[str, Any], override_config: dict[str, Any], allowed_block_overrides: dict[str, Any]) -> dict[str, Any]:
|
340 |
+
unpermitted_keys = override_config.keys() - allowed_block_overrides.keys()
|
341 |
+
if len(unpermitted_keys):
|
342 |
+
raise KeyError(f'Overriding {unpermitted_keys} is not supported.')
|
343 |
+
new_block_args = override_config | block_args
|
344 |
+
common_keys = override_config.keys() & block_args.keys()
|
345 |
+
for k in common_keys:
|
346 |
+
if type(override_config[k]) != type(block_args[k]):
|
347 |
+
raise ValueError(f'Override config should have same value types as the original config. Found override_config[{k}]={override_config[k]} vs block_args[{k}]={block_args[k]}.')
|
348 |
+
if isinstance(override_config[k], dict):
|
349 |
+
new_block_args[k] = MPTModel._override_block_args(block_args[k], override_config[k], allowed_block_overrides[k])
|
350 |
+
else:
|
351 |
+
new_block_args[k] = override_config[k]
|
352 |
+
return new_block_args
|
353 |
+
|
354 |
+
def extract_block_args(self, block_args: dict[str, Any]) -> dict[str, Any]:
|
355 |
+
"""Sets the block args."""
|
356 |
+
if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks:
|
357 |
+
block_args['ffn_config'] = config_moe_args(block_args['ffn_config'], block_args['d_model'], block_args['expansion_ratio'], block_args['n_layers'])
|
358 |
+
self.mb_args = block_args['ffn_config'].get('args')
|
359 |
+
return block_args
|
360 |
+
|
361 |
+
def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
|
362 |
+
return self.wte
|
363 |
+
|
364 |
+
def set_input_embeddings(self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
|
365 |
+
self.wte = value
|
366 |
+
|
367 |
+
@torch.no_grad()
|
368 |
+
def _attn_bias(self, device: torch.device, dtype: torch.dtype, attention_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None) -> tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]:
|
369 |
+
if not self._attn_bias_initialized:
|
370 |
+
if self.attn_bias_shape:
|
371 |
+
self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
|
372 |
+
self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max)
|
373 |
+
self._attn_bias_initialized = True
|
374 |
+
if self.attn_impl == 'flash':
|
375 |
+
return (self.attn_bias, attention_mask)
|
376 |
+
if self.attn_bias is not None:
|
377 |
+
self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
|
378 |
+
attn_bias = self.attn_bias
|
379 |
+
if self.attn_uses_sequence_id and sequence_id is not None:
|
380 |
+
assert isinstance(attn_bias, torch.Tensor)
|
381 |
+
attn_bias = apply_sequence_id(attn_bias, sequence_id, self.config.max_seq_len)
|
382 |
+
if attention_mask is not None:
|
383 |
+
s_k = attention_mask.shape[-1]
|
384 |
+
if attn_bias is None:
|
385 |
+
attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
|
386 |
+
else:
|
387 |
+
_s_k = max(0, attn_bias.size(-1) - s_k)
|
388 |
+
attn_bias = attn_bias[:, :, :, _s_k:]
|
389 |
+
min_val = torch.finfo(attn_bias.dtype).min
|
390 |
+
attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
|
391 |
+
return (attn_bias, attention_mask)
|
392 |
+
|
393 |
+
def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[list[tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None, position_ids: Optional[torch.LongTensor]=None) -> BaseModelOutputWithPast:
|
394 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
395 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
396 |
+
if attention_mask is not None:
|
397 |
+
attention_mask = attention_mask.bool()
|
398 |
+
if not return_dict:
|
399 |
+
raise NotImplementedError('return_dict False is not implemented yet for MPT')
|
400 |
+
if output_attentions:
|
401 |
+
if self.attn_impl != 'torch':
|
402 |
+
raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash`.')
|
403 |
+
if self.training and attention_mask is not None and (attention_mask[:, 0].sum() != attention_mask.shape[0]):
|
404 |
+
raise NotImplementedError('MPT does not support training with left padding.')
|
405 |
+
if self.training:
|
406 |
+
if self.attn_uses_sequence_id and sequence_id is None:
|
407 |
+
raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
|
408 |
+
elif self.attn_uses_sequence_id is False and sequence_id is not None:
|
409 |
+
warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
|
410 |
+
if input_ids is not None and inputs_embeds is not None:
|
411 |
+
raise ValueError('You cannot specify both input_ids and inputs_embeds.')
|
412 |
+
elif input_ids is not None:
|
413 |
+
bsz = input_ids.size(0)
|
414 |
+
x = self.wte(input_ids)
|
415 |
+
input_device = input_ids.device
|
416 |
+
elif inputs_embeds is not None:
|
417 |
+
bsz = inputs_embeds.size(0)
|
418 |
+
x = inputs_embeds
|
419 |
+
input_device = inputs_embeds.device
|
420 |
+
else:
|
421 |
+
raise ValueError('You must specify input_ids or inputs_embeds')
|
422 |
+
S = self.get_sequence_length(x)
|
423 |
+
assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
|
424 |
+
rotary_emb_w_meta_info = None
|
425 |
+
past_position = 0
|
426 |
+
if past_key_values is not None:
|
427 |
+
if len(past_key_values) != self.config.n_layers:
|
428 |
+
raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
|
429 |
+
past_position = past_key_values[0][0].size(1)
|
430 |
+
if self.attn_impl == 'torch':
|
431 |
+
past_position = past_key_values[0][0].size(3)
|
432 |
+
if self.learned_pos_emb or self.rope:
|
433 |
+
if self.learned_pos_emb and S + past_position > self.config.max_seq_len:
|
434 |
+
raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length ' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
|
435 |
+
if self.learned_pos_emb or (self.rope and self.rope_impl == 'hf'):
|
436 |
+
if position_ids is None:
|
437 |
+
pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_device).unsqueeze(0)
|
438 |
+
else:
|
439 |
+
pos = position_ids
|
440 |
+
if attention_mask is not None:
|
441 |
+
pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
|
442 |
+
if self.learned_pos_emb:
|
443 |
+
x = x + self.wpe(pos)
|
444 |
+
elif self.rope and self.rope_impl == 'hf':
|
445 |
+
rotary_emb_w_meta_info = {'impl': self.rope_impl, 'rotary_emb': self.rotary_embedding, 'offset_info': pos, 'seq_len': S + past_position}
|
446 |
+
elif self.rope and self.rope_impl == 'dail':
|
447 |
+
rotary_emb_w_meta_info = {'impl': self.rope_impl, 'rotary_emb': self.rotary_embedding, 'offset_info': past_position, 'seq_len': S + past_position}
|
448 |
+
if self.embedding_fraction == 1:
|
449 |
+
x = self.emb_drop(x)
|
450 |
+
else:
|
451 |
+
x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
|
452 |
+
assert isinstance(self.emb_drop, nn.Module)
|
453 |
+
x = self.emb_drop(x_shrunk)
|
454 |
+
attn_bias, attention_mask = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, sequence_id=sequence_id)
|
455 |
+
attention_mask_in_length = gen_attention_mask_in_length(sequence_id=sequence_id, S=S, attn_uses_sequence_id=self.attn_uses_sequence_id, attn_impl=self.attn_impl, attention_mask=attention_mask)
|
456 |
+
alibi_slopes = None
|
457 |
+
if self.alibi and self.attn_impl == 'flash':
|
458 |
+
alibi_slopes = gen_slopes(n_heads=self.config.n_heads, alibi_bias_max=self.alibi_bias_max, device=x.device, return_1d=True)
|
459 |
+
presents = () if use_cache else None
|
460 |
+
if (use_cache or len(self.kv_cache_layers) > 0) and past_key_values is None:
|
461 |
+
past_key_values = [() for _ in range(self.config.n_layers)]
|
462 |
+
all_hidden_states = () if output_hidden_states else None
|
463 |
+
all_self_attns = () if output_attentions else None
|
464 |
+
flash_attn_padding_info = {}
|
465 |
+
if self.attn_impl == 'flash':
|
466 |
+
flash_attn_padding_info = gen_flash_attn_padding_info(bsz, S, past_position, x.device, attention_mask_in_length, attention_mask)
|
467 |
+
layer_kv_cache_dict = {}
|
468 |
+
for b_idx, block in enumerate(self.blocks):
|
469 |
+
attn_block = block.norm_attn_norm.attn if self.blocks_fuse_norm_attn_norm else block.attn
|
470 |
+
if attn_block.reuse_kv_layer_idx is not None:
|
471 |
+
if attn_block.reuse_kv_layer_idx not in layer_kv_cache_dict:
|
472 |
+
raise KeyError(f'kv cache for layer {block.reuse_kv_layer_idx} not found in layer_kv_cache_dict={layer_kv_cache_dict!r}.')
|
473 |
+
prev_layer_key_value = layer_kv_cache_dict[attn_block.reuse_kv_layer_idx]
|
474 |
+
else:
|
475 |
+
prev_layer_key_value = None
|
476 |
+
if output_hidden_states:
|
477 |
+
assert all_hidden_states is not None
|
478 |
+
all_hidden_states = all_hidden_states + (x,)
|
479 |
+
past_key_value = past_key_values[b_idx] if past_key_values is not None else None
|
480 |
+
extra_kwargs = {}
|
481 |
+
if prev_layer_key_value is not None:
|
482 |
+
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
|
483 |
+
x, attn_weights, present = block(x, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, **extra_kwargs)
|
484 |
+
if presents is not None:
|
485 |
+
presents += (present,)
|
486 |
+
if b_idx in self.kv_cache_layers:
|
487 |
+
layer_kv_cache_dict[b_idx] = [present[0][:, past_position:], present[1][:, past_position:]]
|
488 |
+
if output_attentions:
|
489 |
+
assert all_self_attns is not None
|
490 |
+
all_self_attns = all_self_attns + (attn_weights,)
|
491 |
+
x = self.norm_f(x)
|
492 |
+
if output_hidden_states:
|
493 |
+
assert all_hidden_states is not None
|
494 |
+
all_hidden_states = all_hidden_states + (x,)
|
495 |
+
return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attns)
|
496 |
+
|
497 |
+
def get_sequence_length(self, x: torch.Tensor) -> int:
|
498 |
+
"""Returns the sequence length.
|
499 |
+
|
500 |
+
Args:
|
501 |
+
x (torch.Tensor): The input Tensor.
|
502 |
+
|
503 |
+
Returns:
|
504 |
+
S (int): The sequence length.
|
505 |
+
"""
|
506 |
+
return x.size(1)
|
507 |
+
|
508 |
+
def param_init_fn(self, module: nn.Module) -> None:
|
509 |
+
init_fn_name = self.config.init_config['name']
|
510 |
+
param_init_fns.get(init_fn_name)(module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
|
511 |
+
|
512 |
+
def fsdp_wrap_fn(self, module: nn.Module) -> bool:
|
513 |
+
return _fsdp_wrap_fn(self, module)
|
514 |
+
|
515 |
+
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
|
516 |
+
return isinstance(module, MPTBlock)
|
517 |
+
|
518 |
+
class MPTForCausalLM(MPTPreTrainedModel):
|
519 |
+
_tied_weights_keys = ['lm_head.weight']
|
520 |
+
_tp_plan = {'lm_head': 'colwise_rep'}
|
521 |
+
_pp_plan = {'lm_head': (['hidden_states'], ['logits'])}
|
522 |
+
|
523 |
+
def __init__(self, config: MPTConfig):
|
524 |
+
super().__init__(config)
|
525 |
+
log.info(f'Instantiating an MPTForCausalLM model from {__file__}')
|
526 |
+
self.transformer: MPTModel = self.backbone_model_class(config)
|
527 |
+
self.lm_head = None
|
528 |
+
if not config.tie_word_embeddings:
|
529 |
+
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False, device=config.init_device)
|
530 |
+
self.lm_head._fsdp_wrap = True
|
531 |
+
for child in self.transformer.children():
|
532 |
+
if isinstance(child, torch.nn.ModuleList):
|
533 |
+
continue
|
534 |
+
if isinstance(child, torch.nn.Module):
|
535 |
+
child._fsdp_wrap = True
|
536 |
+
self.logit_scale = None
|
537 |
+
if config.logit_scale is not None:
|
538 |
+
logit_scale = config.logit_scale
|
539 |
+
if isinstance(logit_scale, str):
|
540 |
+
if logit_scale == 'inv_sqrt_d_model':
|
541 |
+
logit_scale = 1 / math.sqrt(config.d_model)
|
542 |
+
else:
|
543 |
+
raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
|
544 |
+
self.logit_scale = logit_scale
|
545 |
+
self.final_logit_softcapping = config.final_logit_softcapping
|
546 |
+
|
547 |
+
@property
|
548 |
+
def backbone_model_class(self) -> type[MPTModel]:
|
549 |
+
return MPTModel
|
550 |
+
|
551 |
+
def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
|
552 |
+
return self.transformer.get_input_embeddings()
|
553 |
+
|
554 |
+
def set_input_embeddings(self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
|
555 |
+
self.transformer.set_input_embeddings(value)
|
556 |
+
|
557 |
+
def get_output_embeddings(self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]:
|
558 |
+
if self.lm_head is not None:
|
559 |
+
return self.lm_head
|
560 |
+
return self.transformer.get_input_embeddings()
|
561 |
+
|
562 |
+
def set_output_embeddings(self, new_embeddings: Union[SharedEmbedding, nn.Embedding, nn.Linear]) -> None:
|
563 |
+
if self.lm_head is not None:
|
564 |
+
self.lm_head = new_embeddings
|
565 |
+
else:
|
566 |
+
if not isinstance(new_embeddings, (SharedEmbedding, nn.Embedding)):
|
567 |
+
raise ValueError('new_embeddings must be an instance of SharedEmbedding ' + f'or nn.Embedding, but got {type(new_embeddings)}.')
|
568 |
+
warnings.warn('Using `set_output_embeddings` to set the embedding layer of ' + 'MPTForCausalLM with tied weights. Given weights are tied, ' + 'using `set_input_embeddings` is recommended over using ' + '`set_output_embeddings`.')
|
569 |
+
self.transformer.set_input_embeddings(new_embeddings)
|
570 |
+
|
571 |
+
def tie_weights(self) -> None:
|
572 |
+
if getattr(self.config, 'tie_word_embeddings', True):
|
573 |
+
self.lm_head = None
|
574 |
+
|
575 |
+
def set_decoder(self, decoder: MPTModel) -> None:
|
576 |
+
self.transformer = decoder
|
577 |
+
|
578 |
+
def get_decoder(self) -> MPTModel:
|
579 |
+
return self.transformer
|
580 |
+
|
581 |
+
def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[list[tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None, position_ids: Optional[torch.LongTensor]=None) -> CausalLMOutputWithPast:
|
582 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
583 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
584 |
+
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, inputs_embeds=inputs_embeds, position_ids=position_ids)
|
585 |
+
if self.lm_head is not None:
|
586 |
+
logits = self.lm_head(outputs.last_hidden_state)
|
587 |
+
else:
|
588 |
+
out = outputs.last_hidden_state
|
589 |
+
out = out.to(self.transformer.wte.weight.device)
|
590 |
+
logits = self.transformer.wte(out, True)
|
591 |
+
if self.logit_scale is not None:
|
592 |
+
if self.logit_scale == 0:
|
593 |
+
warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
|
594 |
+
logits *= self.logit_scale
|
595 |
+
if self.final_logit_softcapping is not None:
|
596 |
+
logits = self.final_logit_softcapping * torch.tanh(logits / self.final_logit_softcapping)
|
597 |
+
loss = None
|
598 |
+
if labels is not None:
|
599 |
+
_labels = torch.roll(labels, shifts=-1)
|
600 |
+
_labels[:, -1] = CROSS_ENTROPY_IGNORE_INDEX
|
601 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1))
|
602 |
+
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
603 |
+
|
604 |
+
def param_init_fn(self, module: nn.Module) -> None:
|
605 |
+
init_fn_name = self.config.init_config['name']
|
606 |
+
param_init_fns.get(init_fn_name)(module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
|
607 |
+
|
608 |
+
def fsdp_wrap_fn(self, module: nn.Module) -> bool:
|
609 |
+
return _fsdp_wrap_fn(self, module)
|
610 |
+
|
611 |
+
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
|
612 |
+
"""The MPT activation checkpointing (act ckpt) function.
|
613 |
+
|
614 |
+
When `activation_checkpointing` in fsdp_config is set to true, this function will be called on all the modules in the FSDP wrapped model and determine whether a given module should be activation checkpointed. It checks the checkpointing target (`activation_checkpointing_target` in `model`) which can be specified as below:
|
615 |
+
1. null (or no such field): The whole MPTBlock will be activation checkpointed on all layers
|
616 |
+
2. a list of modules to act ckpt on all layers, e.g.,
|
617 |
+
activation_checkpointing_target:
|
618 |
+
- grouped_query_attention
|
619 |
+
- mptmlp
|
620 |
+
3. a dictionary of module name with target_blocks, e.g.,
|
621 |
+
activation_checkpointing_target:
|
622 |
+
{
|
623 |
+
"mptblock": target_blocks_1,
|
624 |
+
"grouped_query_attention": target_blocks_2
|
625 |
+
}
|
626 |
+
target_blocks (target_blocks_1, target_blocks_2 above) can be:
|
627 |
+
- a single integer n: the first n transformer block will be activation checkpointed
|
628 |
+
- a string of first-n, middle-m, last-k, range-i-j: the first n, the middle m, the last k, or the range [i, j) layers will be activation checkpointed. E.g, 'first-2, last-2' means the first 2 and last 2 transformer blocks will be activation checkpointed
|
629 |
+
middle-m is range [start, end) where ``start = max(max_block_idx // 2 - m // 2, 0), end = min(start + m, max_block_idx + 1)``
|
630 |
+
- a list of integers corresponds to the list of transformer block ids, e.g., [2] means the second transformer block will be activation checkpointed. [2, 3] means the second and third transformer blocks will be activation checkpointed
|
631 |
+
- a list of mixed integers and strings of first-n, middle-m, last-k, range-i-j
|
632 |
+
|
633 |
+
An example in yaml config file:
|
634 |
+
fsdp_config:
|
635 |
+
activation_checkpointing: true
|
636 |
+
model:
|
637 |
+
activation_checkpointing_target:
|
638 |
+
{
|
639 |
+
"mptblock": 'first-5',
|
640 |
+
"grouped_query_attention": 'last-35'
|
641 |
+
}
|
642 |
+
"""
|
643 |
+
if not hasattr(module, 'block_idx'):
|
644 |
+
log.debug(f'{module.__class__.__name__} cannot be activation checkpointed. Only transformer block or its submodules are eligible for activation checkpointing.')
|
645 |
+
return False
|
646 |
+
act_ckpt_target = getattr(self.config, 'activation_checkpointing_target', None)
|
647 |
+
act_ckpt_mod_to_blocks = build_act_ckpt_mod_to_blocks(act_ckpt_target, MPTBlock, module.max_block_idx)
|
648 |
+
check_mapping_blocks_overlap(act_ckpt_mod_to_blocks, module.max_block_idx)
|
649 |
+
for k in act_ckpt_mod_to_blocks.keys():
|
650 |
+
if isinstance(module, k):
|
651 |
+
blocks = act_ckpt_mod_to_blocks[k]
|
652 |
+
return True if blocks == -1 else module.block_idx in blocks
|
653 |
+
return False
|
654 |
+
|
655 |
+
def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[list[tuple[torch.Tensor, torch.Tensor]]]=None, inputs_embeds: Optional[torch.Tensor]=None, **kwargs: Any) -> dict[str, Any]:
|
656 |
+
attention_mask = kwargs['attention_mask'].bool()
|
657 |
+
if attention_mask[:, -1].sum() != attention_mask.shape[0]:
|
658 |
+
raise NotImplementedError('MPT does not support generation with right padding.')
|
659 |
+
if self.transformer.attn_uses_sequence_id and self.training:
|
660 |
+
sequence_id = torch.zeros_like(input_ids[:1])
|
661 |
+
else:
|
662 |
+
sequence_id = None
|
663 |
+
if past_key_values is not None:
|
664 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
665 |
+
if inputs_embeds is not None and past_key_values is None:
|
666 |
+
model_inputs = {'inputs_embeds': inputs_embeds}
|
667 |
+
else:
|
668 |
+
model_inputs = {'input_ids': input_ids}
|
669 |
+
model_inputs.update({'attention_mask': attention_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)})
|
670 |
+
return model_inputs
|
671 |
+
|
672 |
+
@staticmethod
|
673 |
+
def _reorder_cache(past_key_values: list[tuple[torch.Tensor, torch.Tensor]], beam_idx: torch.LongTensor) -> list[tuple[torch.Tensor, ...]]:
|
674 |
+
"""Used by HuggingFace generate when using beam search with kv-caching.
|
675 |
+
|
676 |
+
See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
|
677 |
+
for an example in transformers.
|
678 |
+
"""
|
679 |
+
reordered_past = []
|
680 |
+
for layer_past in past_key_values:
|
681 |
+
reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
|
682 |
+
return reordered_past
|
683 |
+
|
684 |
+
def get_targets(labels: torch.Tensor) -> torch.Tensor:
|
685 |
+
targets = torch.roll(labels, shifts=-1)
|
686 |
+
targets[:, -1] = CROSS_ENTROPY_IGNORE_INDEX
|
687 |
+
return targets
|
688 |
+
|
689 |
+
def compute_loss_from_logits(outputs: CausalLMOutputWithPast, shift_labels: bool, labels: torch.Tensor, loss_fn: nn.Module) -> torch.Tensor:
|
690 |
+
targets = get_targets(labels) if shift_labels else labels
|
691 |
+
losses = loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)), targets.view(-1))
|
692 |
+
if torch.all(targets == loss_fn.ignore_index):
|
693 |
+
loss = losses.sum()
|
694 |
+
else:
|
695 |
+
loss = losses.sum() / (targets != loss_fn.ignore_index).sum()
|
696 |
+
return loss
|
mpt_param_count.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Helper functions for computing parameter counts for MPT model.
|
2 |
+
|
3 |
+
Use if generic `sum(p.numel() for p in self.parameters())`
|
4 |
+
style computation does not account for MoE parameter sharding.
|
5 |
+
The helper functions in this file account for MoE parameter
|
6 |
+
sharding in the parameter count calculation. The functions below
|
7 |
+
calculate the total parameter count and the active parameter count.
|
8 |
+
Note: MPT has both n_total_params and n_active_params methods.
|
9 |
+
"""
|
10 |
+
from typing import Union
|
11 |
+
from torch import Tensor, nn
|
12 |
+
from torch.distributed._tensor import DTensor
|
13 |
+
from .layers_registry import ffns_with_megablocks
|
14 |
+
|
15 |
+
def module_n_params(module: nn.Module) -> int:
|
16 |
+
"""Gets the number of parameters in this module excluding child modules.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
module (nn.Module): Module of which we get the number of parameters.
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
An int for the number of parameters in this module.
|
23 |
+
"""
|
24 |
+
n_params = 0
|
25 |
+
for p in module.parameters(recurse=False):
|
26 |
+
n_params += p.numel()
|
27 |
+
return n_params
|
28 |
+
|
29 |
+
def _dtensor_safe_check_numel(tensor: Union[Tensor, DTensor]) -> int:
|
30 |
+
if isinstance(tensor, DTensor):
|
31 |
+
tensor = tensor._local_tensor
|
32 |
+
return tensor.numel()
|
33 |
+
|
34 |
+
def megablocks_n_total_params(mpt_model) -> int:
|
35 |
+
"""Calculates the number of parameters in a MegaBlocks enabled MPT model.
|
36 |
+
|
37 |
+
MoE experts are sharded across workers. This function scans for MegaBlocks
|
38 |
+
modules then multiplies expert params count by MoE world size.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
mpt_model (ComposerMPTCausalLM): MPT model of which the number of
|
42 |
+
parameters is calculated.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
An int for the total number of parameters in this MPT model.
|
46 |
+
"""
|
47 |
+
import megablocks
|
48 |
+
moe_world_size = mpt_model.config.ffn_config.get('moe_world_size')
|
49 |
+
n_total_params = 0
|
50 |
+
for module in mpt_model.modules():
|
51 |
+
if isinstance(module, (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP)):
|
52 |
+
n_w1 = _dtensor_safe_check_numel(module.w1)
|
53 |
+
n_total_params += n_w1 * moe_world_size
|
54 |
+
n_w2 = _dtensor_safe_check_numel(module.w2)
|
55 |
+
n_total_params += n_w2 * moe_world_size
|
56 |
+
if hasattr(module, 'v1'):
|
57 |
+
n_v1 = _dtensor_safe_check_numel(module.v1)
|
58 |
+
n_total_params += n_v1 * moe_world_size
|
59 |
+
else:
|
60 |
+
n_total_params += module_n_params(module)
|
61 |
+
return n_total_params
|
62 |
+
|
63 |
+
def megablocks_n_active_params(mpt_model) -> int:
|
64 |
+
"""Calculates the number of active parameters in a MegaBlocks enabled MPT.
|
65 |
+
|
66 |
+
This requires we calculate the number of elements per expert and
|
67 |
+
multiply this by top k.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
mpt_model (ComposerMPTCausalLM): MPT model of which the number of
|
71 |
+
active parameters is calculated.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
An int for the active number of parameters in this MPT model.
|
75 |
+
"""
|
76 |
+
import megablocks
|
77 |
+
moe_num_experts = mpt_model.config.ffn_config.get('moe_num_experts', 1)
|
78 |
+
moe_world_size = mpt_model.config.ffn_config.get('moe_world_size')
|
79 |
+
local_experts = moe_num_experts / moe_world_size
|
80 |
+
moe_top_k = mpt_model.config.ffn_config.get('moe_top_k', 1)
|
81 |
+
n_active_params = 0
|
82 |
+
for module in mpt_model.modules():
|
83 |
+
if isinstance(module, (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP)):
|
84 |
+
n_w1 = _dtensor_safe_check_numel(module.w1)
|
85 |
+
n_active_params += int(n_w1 / local_experts * moe_top_k)
|
86 |
+
n_w2 = _dtensor_safe_check_numel(module.w2)
|
87 |
+
n_active_params += int(n_w2 / local_experts * moe_top_k)
|
88 |
+
if hasattr(module, 'v1'):
|
89 |
+
n_v1 = _dtensor_safe_check_numel(module.v1)
|
90 |
+
n_active_params += int(n_v1 / local_experts * moe_top_k)
|
91 |
+
else:
|
92 |
+
n_active_params += module_n_params(module)
|
93 |
+
return n_active_params
|
94 |
+
|
95 |
+
def mpt_get_total_params(mpt_model) -> int:
|
96 |
+
"""Calculates the total parameter count of an MPT model.
|
97 |
+
|
98 |
+
Note: Must be called before model parameters are sharded by FSDP.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
mpt_model (ComposerMPTCausalLM): MPT model of which the number of
|
102 |
+
active parameters is calculated.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
An int for the total number of parameters in this MPT model.
|
106 |
+
"""
|
107 |
+
if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks:
|
108 |
+
return megablocks_n_total_params(mpt_model)
|
109 |
+
else:
|
110 |
+
return sum((p.numel() for p in mpt_model.parameters()))
|
111 |
+
|
112 |
+
def mpt_get_active_params(mpt_model) -> int:
|
113 |
+
"""Calculates the total parameter count of an MPT model.
|
114 |
+
|
115 |
+
Note: Must be called before model parameters are sharded by FSDP.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
mpt_model (ComposerMPTCausalLM): MPT model of which the number of
|
119 |
+
active parameters is calculated.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
An int for the active number of parameters in this MPT model.
|
123 |
+
"""
|
124 |
+
if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks:
|
125 |
+
params = megablocks_n_active_params(mpt_model)
|
126 |
+
else:
|
127 |
+
params = sum((p.numel() for p in mpt_model.parameters()))
|
128 |
+
if not mpt_model.model.transformer.config.tie_word_embeddings:
|
129 |
+
params -= _dtensor_safe_check_numel(mpt_model.model.transformer.wte.weight)
|
130 |
+
return params
|
norm.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union
|
2 |
+
import torch
|
3 |
+
from .layers_registry import norms
|
4 |
+
norms.register(name='layernorm', func=torch.nn.LayerNorm)
|
5 |
+
|
6 |
+
def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
|
7 |
+
if torch.is_autocast_enabled():
|
8 |
+
if tensor.device.type == 'cuda':
|
9 |
+
dtype = torch.get_autocast_gpu_dtype()
|
10 |
+
elif tensor.device.type == 'cpu':
|
11 |
+
dtype = torch.get_autocast_cpu_dtype()
|
12 |
+
else:
|
13 |
+
raise NotImplementedError()
|
14 |
+
return tensor.to(dtype=dtype)
|
15 |
+
return tensor
|
16 |
+
|
17 |
+
@norms.register_class('low_precision_layernorm')
|
18 |
+
class LPLayerNorm(torch.nn.LayerNorm):
|
19 |
+
|
20 |
+
def __init__(self, normalized_shape: Union[int, list[int], torch.Size], eps: float=1e-05, elementwise_affine: bool=True, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None):
|
21 |
+
super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
|
22 |
+
|
23 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
24 |
+
module_device = x.device
|
25 |
+
downcast_x = _cast_if_autocast_enabled(x)
|
26 |
+
downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
27 |
+
downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
|
28 |
+
with torch.autocast(enabled=False, device_type=module_device.type):
|
29 |
+
return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
|
30 |
+
|
31 |
+
def rms_norm(x: torch.Tensor, weight: Optional[torch.Tensor]=None, eps: float=1e-05) -> torch.Tensor:
|
32 |
+
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
33 |
+
if weight is not None:
|
34 |
+
return output * weight
|
35 |
+
return output
|
36 |
+
|
37 |
+
@norms.register_class('rmsnorm')
|
38 |
+
class RMSNorm(torch.nn.Module):
|
39 |
+
|
40 |
+
def __init__(self, normalized_shape: Union[int, list[int], torch.Size], eps: float=1e-05, weight: bool=True, dtype: Optional[torch.dtype]=None, device: Optional[torch.device]=None):
|
41 |
+
super().__init__()
|
42 |
+
self.eps = eps
|
43 |
+
if weight:
|
44 |
+
self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
|
45 |
+
else:
|
46 |
+
self.register_parameter('weight', None)
|
47 |
+
|
48 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
49 |
+
return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
|
50 |
+
|
51 |
+
@norms.register_class('low_precision_rmsnorm')
|
52 |
+
class LPRMSNorm(RMSNorm):
|
53 |
+
|
54 |
+
def __init__(self, normalized_shape: Union[int, list[int], torch.Size], eps: float=1e-05, weight: bool=True, dtype: Optional[torch.dtype]=None, device: Optional[torch.device]=None):
|
55 |
+
super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
|
56 |
+
|
57 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
58 |
+
downcast_x = _cast_if_autocast_enabled(x)
|
59 |
+
downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
60 |
+
with torch.autocast(enabled=False, device_type=x.device.type):
|
61 |
+
return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
|
62 |
+
|
63 |
+
@norms.register_class('triton_rmsnorm')
|
64 |
+
class TritonRMSNorm(torch.nn.Module):
|
65 |
+
|
66 |
+
def __init__(self, normalized_shape: Union[int, list[int], torch.Size], eps: float=1e-05, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None):
|
67 |
+
super().__init__()
|
68 |
+
self.eps = eps
|
69 |
+
try:
|
70 |
+
from flash_attn.ops.triton.layer_norm import rms_norm_fn
|
71 |
+
except ImportError:
|
72 |
+
raise ImportError('triton_rms_norm requires Flash Attention to be installed. ' + 'Please pip install flash-attn.')
|
73 |
+
if not isinstance(normalized_shape, int):
|
74 |
+
raise ValueError('TritonRMSNorm only supports 1D tensors')
|
75 |
+
self.rms_norm_fn = rms_norm_fn
|
76 |
+
self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
|
77 |
+
|
78 |
+
def forward(self, x: torch.Tensor):
|
79 |
+
return self.rms_norm_fn(x, self.weight, None, residual=None, eps=self.eps, dropout_p=0.0, prenorm=False, residual_in_fp32=False)
|
param_init_fns.py
ADDED
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from collections.abc import Sequence
|
4 |
+
from copy import deepcopy
|
5 |
+
from functools import partial
|
6 |
+
from typing import Any, Callable, Optional, Union
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from torch.distributed._tensor import DTensor
|
10 |
+
from .layers_registry import fcs, module_init_fns, norms, param_init_fns
|
11 |
+
from .dmoe import GLU, MLP
|
12 |
+
try:
|
13 |
+
import transformer_engine.pytorch as te
|
14 |
+
except:
|
15 |
+
te = None
|
16 |
+
try:
|
17 |
+
import megablocks
|
18 |
+
except:
|
19 |
+
megablocks = None
|
20 |
+
|
21 |
+
def torch_default_param_init_fn_(module: nn.Module, **kwargs: Any) -> None:
|
22 |
+
del kwargs
|
23 |
+
if hasattr(module, 'reset_parameters') and isinstance(module.reset_parameters, Callable):
|
24 |
+
module.reset_parameters()
|
25 |
+
|
26 |
+
def fused_init_helper_(module: nn.Module, init_fn_: Callable, name_param: str='weight'):
|
27 |
+
"""Initializes parameters which have been fused for efficiency purposes.
|
28 |
+
|
29 |
+
Parameter initialization is often based on the parameters shape. If a layer is fused,
|
30 |
+
initialization should be based on the shapes of the original tensor instead of the
|
31 |
+
shape of the fused tensor. Layers which are fused should have the _fused
|
32 |
+
attribute. First element of _fused is the dimension along which the tensor is fused.
|
33 |
+
Second element is a an iterable of split indices.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
module (nn.Module): The module to initialize.
|
37 |
+
init_fn_ (Callable): Initialization method.
|
38 |
+
name_param (str): Name of parameter to initialize within the module.
|
39 |
+
"""
|
40 |
+
_fused = getattr(module, '_fused', None)
|
41 |
+
if _fused is None:
|
42 |
+
raise RuntimeError(f'Internal logic error')
|
43 |
+
fused_param_init_helper(getattr(module, name_param), init_fn_, _fused)
|
44 |
+
|
45 |
+
def fused_param_init_helper(param: torch.Tensor, init_fn_: Callable, fused_parameters: tuple[int, list[int]]):
|
46 |
+
"""Initializes parameters that are fused together.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
param (torch.Tensor): Tensor to initialize.
|
50 |
+
init_fn_ (Callable): Initialization method.
|
51 |
+
fused_parameters (tuple[int, list[int]]): First element of _fused is the dimension
|
52 |
+
along which the tensor is fused. Second element is a an iterable of split indices.
|
53 |
+
"""
|
54 |
+
p_ndims = param.ndim
|
55 |
+
dim, splits = fused_parameters
|
56 |
+
splits = (0, *splits, param.size(dim))
|
57 |
+
for s, e in zip(splits[:-1], splits[1:]):
|
58 |
+
slice_indices = [slice(None)] * p_ndims
|
59 |
+
slice_indices[dim] = slice(s, e)
|
60 |
+
init_fn_(param[slice_indices])
|
61 |
+
|
62 |
+
def stacked_init_helper_(module: nn.Module, init_fn_: Callable, name_param: str='weight'):
|
63 |
+
"""Initializes parameters stacked along a new dimension.
|
64 |
+
|
65 |
+
Parameter initialization is often based on the parameters shape. If a layer is stacked,
|
66 |
+
initialization should be based on the shapes of the original tensor instead of the
|
67 |
+
shape of the stacked tensor. Layers which are fused should have the _stacked_dim
|
68 |
+
attribute defining the new dimension along which they are stacked.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
module (nn.Module): The module to initialize.
|
72 |
+
init_fn_ (Callable): Initialization method.
|
73 |
+
name_param (str): Name of parameter to initialize within the module.
|
74 |
+
"""
|
75 |
+
stack_dim = getattr(module, '_stack_dim', None)
|
76 |
+
if stack_dim is None:
|
77 |
+
raise RuntimeError(f'Internal logic error')
|
78 |
+
stacked_param_init_helper(getattr(module, name_param), init_fn_, stack_dim)
|
79 |
+
|
80 |
+
def stacked_param_init_helper(param: torch.Tensor, init_fn_: Callable, stack_dim: int):
|
81 |
+
"""Initialize parameters stacked along a new dimension.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
param (torch.Tensor): Tensor to initialize.
|
85 |
+
init_fn_ (Callable): Initialization method.
|
86 |
+
stack_dim (int): Dimension along with parameters are stacked
|
87 |
+
"""
|
88 |
+
p_ndims = param.ndim
|
89 |
+
for idx in range(param.size(stack_dim)):
|
90 |
+
slice_indices = [slice(None)] * p_ndims
|
91 |
+
slice_indices[stack_dim] = idx
|
92 |
+
init_fn_(param[slice_indices])
|
93 |
+
|
94 |
+
def _flip_fan_mode(init_fn_: Callable):
|
95 |
+
"""Changes the mode of an init_fn_.
|
96 |
+
|
97 |
+
init_fn_'s "mode" is set to operate on standard torch modules eg torch.nn.Linear.
|
98 |
+
If a custom layer transposes its weights before they are allied such that it is
|
99 |
+
opposite pytorch's conventions, we must flip the fan mode, from fan_in to fan_out.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
init_fn_ (Callable): Initialization method.
|
103 |
+
"""
|
104 |
+
_init_fn_ = deepcopy(init_fn_)
|
105 |
+
if 'mode' in _init_fn_.keywords:
|
106 |
+
if _init_fn_.keywords['mode'] == 'fan_in':
|
107 |
+
_init_fn_.keywords['mode'] = 'fan_out'
|
108 |
+
elif _init_fn_.keywords['mode'] == 'fan_out':
|
109 |
+
_init_fn_.keywords['mode'] = 'fan_in'
|
110 |
+
return _init_fn_
|
111 |
+
|
112 |
+
def fc_init(module: nn.Module, init_fn_: Callable, init_div_is_residual: Union[int, float, str, bool], div_is_residual: Optional[float], **kwargs: Any) -> bool:
|
113 |
+
del kwargs
|
114 |
+
if isinstance(module, tuple({fcs.get(n) for n in fcs.get_all()})):
|
115 |
+
if hasattr(module, '_fused'):
|
116 |
+
fused_init_helper_(module, init_fn_)
|
117 |
+
else:
|
118 |
+
init_fn_(module.weight)
|
119 |
+
if module.bias is not None:
|
120 |
+
assert isinstance(module.bias, torch.Tensor)
|
121 |
+
torch.nn.init.zeros_(module.bias)
|
122 |
+
if init_div_is_residual is not False and getattr(module, '_is_residual', False):
|
123 |
+
with torch.no_grad():
|
124 |
+
module.weight.div_(div_is_residual)
|
125 |
+
return True
|
126 |
+
return False
|
127 |
+
|
128 |
+
def embedding_init(module: nn.Module, init_fn_: Callable, emb_init_std: Optional[float], emb_init_uniform_lim: Optional[Union[tuple[float, float], float]], **kwargs: Any) -> bool:
|
129 |
+
del kwargs
|
130 |
+
if isinstance(module, nn.Embedding):
|
131 |
+
if emb_init_std is not None:
|
132 |
+
std = emb_init_std
|
133 |
+
if std == 0:
|
134 |
+
warnings.warn(f'Embedding layer initialized to 0.')
|
135 |
+
emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
|
136 |
+
elif emb_init_uniform_lim is not None:
|
137 |
+
lim = emb_init_uniform_lim
|
138 |
+
if isinstance(lim, Sequence):
|
139 |
+
if len(lim) > 2:
|
140 |
+
raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
|
141 |
+
if lim[0] == lim[1]:
|
142 |
+
warnings.warn(f'Embedding layer initialized to {lim[0]}.')
|
143 |
+
else:
|
144 |
+
if lim == 0:
|
145 |
+
warnings.warn(f'Embedding layer initialized to 0.')
|
146 |
+
lim = [-lim, lim]
|
147 |
+
a, b = lim
|
148 |
+
emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
|
149 |
+
else:
|
150 |
+
emb_init_fn_ = init_fn_
|
151 |
+
emb_init_fn_(module.weight)
|
152 |
+
if module.padding_idx is not None:
|
153 |
+
with torch.no_grad():
|
154 |
+
module.weight[module.padding_idx].fill_(0)
|
155 |
+
return True
|
156 |
+
return False
|
157 |
+
|
158 |
+
def norm_init(module: nn.Module, **kwargs: Any) -> bool:
|
159 |
+
del kwargs
|
160 |
+
if isinstance(module, tuple({norms.get(name) for name in norms.get_all()})):
|
161 |
+
if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor):
|
162 |
+
torch.nn.init.ones_(module.weight)
|
163 |
+
if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor):
|
164 |
+
torch.nn.init.zeros_(module.bias)
|
165 |
+
return True
|
166 |
+
return False
|
167 |
+
|
168 |
+
def multihead_attention_init(module: nn.Module, init_fn_: Callable, d_model: Optional[int], init_div_is_residual: Union[int, float, str, bool], div_is_residual: float, **kwargs: Any) -> bool:
|
169 |
+
del kwargs
|
170 |
+
if isinstance(module, nn.MultiheadAttention):
|
171 |
+
if module._qkv_same_embed_dim:
|
172 |
+
assert module.in_proj_weight is not None
|
173 |
+
assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
|
174 |
+
assert d_model is not None
|
175 |
+
_d = d_model
|
176 |
+
splits = (0, _d, 2 * _d, 3 * _d)
|
177 |
+
for s, e in zip(splits[:-1], splits[1:]):
|
178 |
+
init_fn_(module.in_proj_weight[s:e])
|
179 |
+
else:
|
180 |
+
assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
|
181 |
+
assert module.in_proj_weight is None
|
182 |
+
init_fn_(module.q_proj_weight)
|
183 |
+
init_fn_(module.k_proj_weight)
|
184 |
+
init_fn_(module.v_proj_weight)
|
185 |
+
if module.in_proj_bias is not None:
|
186 |
+
torch.nn.init.zeros_(module.in_proj_bias)
|
187 |
+
if module.bias_k is not None:
|
188 |
+
torch.nn.init.zeros_(module.bias_k)
|
189 |
+
if module.bias_v is not None:
|
190 |
+
torch.nn.init.zeros_(module.bias_v)
|
191 |
+
init_fn_(module.out_proj.weight)
|
192 |
+
if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
|
193 |
+
with torch.no_grad():
|
194 |
+
module.out_proj.weight.div_(div_is_residual)
|
195 |
+
if module.out_proj.bias is not None:
|
196 |
+
torch.nn.init.zeros_(module.out_proj.bias)
|
197 |
+
return True
|
198 |
+
return False
|
199 |
+
|
200 |
+
def te_layernorm_mlp_init(module: nn.Module, init_fn_: Callable, **kwargs: Any) -> bool:
|
201 |
+
del kwargs
|
202 |
+
if te is not None and isinstance(module, te.LayerNormMLP):
|
203 |
+
if isinstance(module.layer_norm_weight, torch.Tensor):
|
204 |
+
torch.nn.init.ones_(module.layer_norm_weight)
|
205 |
+
if isinstance(module.layer_norm_bias, torch.Tensor):
|
206 |
+
torch.nn.init.zeros_(module.layer_norm_bias)
|
207 |
+
init_fn_(module.fc1_weight)
|
208 |
+
if module.fc1_bias is not None:
|
209 |
+
assert isinstance(module.fc1_bias, torch.Tensor)
|
210 |
+
torch.nn.init.zeros_(module.fc1_bias)
|
211 |
+
init_fn_(module.fc2_weight)
|
212 |
+
if module.fc2_bias is not None:
|
213 |
+
assert isinstance(module.fc2_bias, torch.Tensor)
|
214 |
+
torch.nn.init.zeros_(module.fc2_bias)
|
215 |
+
with torch.no_grad():
|
216 |
+
module.fc2_weight.div_(div_is_residual)
|
217 |
+
return True
|
218 |
+
return False
|
219 |
+
|
220 |
+
def moe_init(module: nn.Module, init_fn_: Callable, init_div_is_residual: Union[int, float, str, bool], div_is_residual: float, **kwargs: Any) -> bool:
|
221 |
+
if megablocks is not None and isinstance(module, (megablocks.layers.moe.MoE, megablocks.layers.dmoe.dMoE, megablocks.layers.moe.ParallelMLP, megablocks.layers.dmoe.ParallelDroplessMLP)):
|
222 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
223 |
+
torch.nn.init.zeros_(module.bias)
|
224 |
+
return True
|
225 |
+
elif megablocks is not None and isinstance(module, megablocks.layers.glu.SparseGLU):
|
226 |
+
_megablocks_sparse_glu_generic_param_init_fn_(module, init_fn_, bool(init_div_is_residual), div_is_residual)
|
227 |
+
return True
|
228 |
+
elif megablocks is not None and isinstance(module, megablocks.layers.mlp.SparseMLP):
|
229 |
+
_megablocks_sparse_mlp_generic_param_init_fn_(module, init_fn_, bool(init_div_is_residual), div_is_residual)
|
230 |
+
return True
|
231 |
+
elif megablocks is not None and isinstance(module, megablocks.layers.mlp.MLP):
|
232 |
+
_megablocks_mlp_generic_param_init_fn_(module, init_fn_, bool(init_div_is_residual), div_is_residual)
|
233 |
+
return True
|
234 |
+
elif isinstance(module, GLU):
|
235 |
+
init_fn_(module.w1)
|
236 |
+
init_fn_(module.v1)
|
237 |
+
init_fn_(module.w2)
|
238 |
+
return True
|
239 |
+
elif isinstance(module, MLP):
|
240 |
+
init_fn_(module.w1)
|
241 |
+
init_fn_(module.w2)
|
242 |
+
return True
|
243 |
+
return False
|
244 |
+
|
245 |
+
def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[tuple[float, float], float]]=None, **kwargs: Any) -> None:
|
246 |
+
del kwargs
|
247 |
+
init_div_is_residual = init_div_is_residual
|
248 |
+
if init_div_is_residual is False:
|
249 |
+
div_is_residual = 1.0
|
250 |
+
elif init_div_is_residual is True:
|
251 |
+
div_is_residual = math.sqrt(2 * n_layers)
|
252 |
+
elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
|
253 |
+
div_is_residual = init_div_is_residual
|
254 |
+
elif init_div_is_residual.isnumeric():
|
255 |
+
div_is_residual = float(init_div_is_residual)
|
256 |
+
else:
|
257 |
+
div_is_residual = 1.0
|
258 |
+
raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
|
259 |
+
all_module_init_fns = [module_init_fns.get(name) for name in module_init_fns.get_all()]
|
260 |
+
did_init = False
|
261 |
+
for module_init_fn in all_module_init_fns:
|
262 |
+
did_init = module_init_fn(module=module, init_fn_=init_fn_, d_model=d_model, init_div_is_residual=init_div_is_residual, div_is_residual=div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
|
263 |
+
if did_init:
|
264 |
+
break
|
265 |
+
if not did_init:
|
266 |
+
for _ in module.parameters(recurse=False):
|
267 |
+
raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by any of the registered module_init_fns. ' + 'Please add an appropriate module_init_fn to the registry. Currently registered module_init_fns are: ' + ', '.join(module_init_fns.get_all()))
|
268 |
+
|
269 |
+
def _megablocks_sparse_mlp_generic_param_init_fn_(module: nn.Module, init_fn_: Callable, init_div_is_residual: bool=False, div_is_residual: float=1.0):
|
270 |
+
"""Initializes MegaBlocks MLP.
|
271 |
+
|
272 |
+
To enable elastic deterministic initialization, this method creates the entire
|
273 |
+
weight matrix then slice into the weight tensors such that the sampled weights
|
274 |
+
should not vary between moe world size for the same random seed.
|
275 |
+
|
276 |
+
Args:
|
277 |
+
module (nn.Module): The module to initialize.
|
278 |
+
init_fn_ (Callable): Initialization method.
|
279 |
+
init_div_is_residual (bool): Flag enabling parameters tagged with _is_residual
|
280 |
+
flag to be divided by div_is_residual.
|
281 |
+
div_is_residual (float): The value by which parameter initialization is divided
|
282 |
+
if init_div_is_residual flag is enabled.
|
283 |
+
"""
|
284 |
+
expert_process_group_size, rank = (1, 0)
|
285 |
+
if module.expert_parallel_group is not None:
|
286 |
+
expert_process_group_size = int(module.expert_parallel_group.size())
|
287 |
+
rank = int(module.expert_parallel_group.rank())
|
288 |
+
hidden_size = int(module.hidden_size)
|
289 |
+
w1 = module.w1
|
290 |
+
if isinstance(w1, DTensor):
|
291 |
+
w1 = w1._local_tensor
|
292 |
+
w1_size = list(w1.shape)
|
293 |
+
w1_size[0] = w1_size[0] * expert_process_group_size
|
294 |
+
n_exp = w1_size[0] // hidden_size
|
295 |
+
_fused = (0, [(n + 1) * hidden_size for n in range(n_exp - 1)])
|
296 |
+
_w1 = w1.new_empty(w1_size)
|
297 |
+
fused_param_init_helper(_w1, init_fn_, _fused)
|
298 |
+
_w1_local = _w1.chunk(expert_process_group_size, dim=0)[rank]
|
299 |
+
with torch.no_grad():
|
300 |
+
w1.copy_(_w1_local)
|
301 |
+
w2 = module.w2
|
302 |
+
if isinstance(w2, DTensor):
|
303 |
+
w2 = w2._local_tensor
|
304 |
+
w2_size = list(w2.shape)
|
305 |
+
w2_size[0] = w2_size[0] * expert_process_group_size
|
306 |
+
_w2 = w2.new_empty(w2_size)
|
307 |
+
fused_param_init_helper(_w2, _flip_fan_mode(init_fn_), _fused)
|
308 |
+
_w2_local = _w2.chunk(expert_process_group_size, dim=0)[rank]
|
309 |
+
with torch.no_grad():
|
310 |
+
w2.copy_(_w2_local)
|
311 |
+
if init_div_is_residual is not False:
|
312 |
+
with torch.no_grad():
|
313 |
+
w2.div_(div_is_residual)
|
314 |
+
|
315 |
+
def _megablocks_sparse_glu_generic_param_init_fn_(module: nn.Module, init_fn_: Callable, init_div_is_residual: bool=False, div_is_residual: float=1.0):
|
316 |
+
"""Initializes MegaBlocks Sparse GLU.
|
317 |
+
|
318 |
+
Extends the Megablocks Sparse MLP case to an additional weight v1 for GLUs.
|
319 |
+
This additional weight v1 has the same initialization procedure as w1 for MLPs.
|
320 |
+
|
321 |
+
Args:
|
322 |
+
module (nn.Module): The module to initialize.
|
323 |
+
init_fn_ (Callable): Initialization method.
|
324 |
+
init_div_is_residual (bool): Flag enabling parameters tagged with _is_residual
|
325 |
+
flag to be divided by div_is_residual.
|
326 |
+
div_is_residual (float): The value by which parameter initialization is divided
|
327 |
+
if init_div_is_residual flag is enabled.
|
328 |
+
"""
|
329 |
+
_megablocks_sparse_mlp_generic_param_init_fn_(module=module, init_fn_=init_fn_, init_div_is_residual=init_div_is_residual, div_is_residual=div_is_residual)
|
330 |
+
expert_process_group_size, rank = (1, 0)
|
331 |
+
if module.expert_parallel_group is not None:
|
332 |
+
expert_process_group_size = int(module.expert_parallel_group.size())
|
333 |
+
rank = int(module.expert_parallel_group.rank())
|
334 |
+
hidden_size = int(module.hidden_size)
|
335 |
+
v1 = module.v1
|
336 |
+
if isinstance(v1, DTensor):
|
337 |
+
v1 = v1._local_tensor
|
338 |
+
v1_size = list(v1.shape)
|
339 |
+
v1_size[0] = v1_size[0] * expert_process_group_size
|
340 |
+
n_exp = v1_size[0] // hidden_size
|
341 |
+
_fused = (0, [(n + 1) * hidden_size for n in range(n_exp - 1)])
|
342 |
+
_v1 = v1.new_empty(v1_size)
|
343 |
+
fused_param_init_helper(_v1, init_fn_, _fused)
|
344 |
+
_v1_local = _v1.chunk(expert_process_group_size, dim=0)[rank]
|
345 |
+
with torch.no_grad():
|
346 |
+
v1.copy_(_v1_local)
|
347 |
+
|
348 |
+
def _megablocks_mlp_generic_param_init_fn_(module: nn.Module, init_fn_: Callable, init_div_is_residual: bool=False, div_is_residual: float=1.0):
|
349 |
+
"""Initializes MegaBlocks' MLP.
|
350 |
+
|
351 |
+
To enable elastic deterministic initialization, this method creates the entire
|
352 |
+
weight matrix then slice into the weight tensors such that the sampled weights
|
353 |
+
should not vary between moe world size for the same random seed.
|
354 |
+
|
355 |
+
Args:
|
356 |
+
module (nn.Module): The module to initialize.
|
357 |
+
init_fn_ (Callable): Initialization method.
|
358 |
+
init_div_is_residual (bool): Flag enabling parameters tagged with _is_residual
|
359 |
+
flag to be divided by div_is_residual.
|
360 |
+
div_is_residual (float): The value by which parameter initialization is divided
|
361 |
+
if init_div_is_residual flag is enabled.
|
362 |
+
"""
|
363 |
+
expert_process_group_size, rank = (1, 0)
|
364 |
+
if module.expert_parallel_group is not None:
|
365 |
+
expert_process_group_size = int(module.expert_parallel_group.size())
|
366 |
+
rank = int(module.expert_parallel_group.rank())
|
367 |
+
_init_fn_ = _flip_fan_mode(init_fn_)
|
368 |
+
w1_size = list(module.w1.shape)
|
369 |
+
w1_size[0] = w1_size[0] * expert_process_group_size
|
370 |
+
_w1 = module.w1.new_empty(w1_size)
|
371 |
+
stacked_param_init_helper(_w1, _init_fn_, module._stack_dim)
|
372 |
+
_w1_local = _w1.chunk(expert_process_group_size, dim=0)[rank]
|
373 |
+
with torch.no_grad():
|
374 |
+
module.w1.copy_(_w1_local)
|
375 |
+
w2_size = list(module.w2.shape)
|
376 |
+
w2_size[0] = w2_size[0] * expert_process_group_size
|
377 |
+
_w2 = module.w2.new_empty(w2_size)
|
378 |
+
stacked_param_init_helper(_w2, _init_fn_, module._stack_dim)
|
379 |
+
_w2_local = _w2.chunk(expert_process_group_size, dim=0)[rank]
|
380 |
+
with torch.no_grad():
|
381 |
+
module.w2.copy_(_w2_local)
|
382 |
+
if init_div_is_residual is not False:
|
383 |
+
with torch.no_grad():
|
384 |
+
module.w2.div_(div_is_residual)
|
385 |
+
|
386 |
+
def _normal_init_(std: float, mean: float=0.0):
|
387 |
+
return partial(torch.nn.init.normal_, mean=mean, std=std)
|
388 |
+
|
389 |
+
def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[tuple[float, float], float]]=None, **kwargs: Any) -> None:
|
390 |
+
del kwargs
|
391 |
+
init_fn_ = _normal_init_(std=std)
|
392 |
+
generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
|
393 |
+
|
394 |
+
def baseline_param_init_fn_(module: nn.Module, init_std: Optional[float], n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[tuple[float, float], float]]=None, **kwargs: Any) -> None:
|
395 |
+
del kwargs
|
396 |
+
if init_std is None:
|
397 |
+
raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
|
398 |
+
_normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
|
399 |
+
|
400 |
+
def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[tuple[float, float], float]]=None, **kwargs: Any) -> None:
|
401 |
+
del kwargs
|
402 |
+
std = math.sqrt(2 / (5 * d_model))
|
403 |
+
_normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
|
404 |
+
|
405 |
+
def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[tuple[float, float], float]]=None, **kwargs: Any) -> None:
|
406 |
+
"""From section 2.3.1 of GPT-NeoX-20B:
|
407 |
+
|
408 |
+
An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
|
409 |
+
see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
|
410 |
+
and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
|
411 |
+
"""
|
412 |
+
del kwargs
|
413 |
+
residual_div = n_layers / math.sqrt(10)
|
414 |
+
small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
|
415 |
+
|
416 |
+
def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
|
417 |
+
del kwargs
|
418 |
+
kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
|
419 |
+
generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
|
420 |
+
|
421 |
+
def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
|
422 |
+
del kwargs
|
423 |
+
kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
|
424 |
+
generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
|
425 |
+
|
426 |
+
def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
|
427 |
+
del kwargs
|
428 |
+
xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
|
429 |
+
generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
|
430 |
+
|
431 |
+
def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
|
432 |
+
del kwargs
|
433 |
+
xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
|
434 |
+
generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
|
435 |
+
param_init_fns.register('default_', func=torch_default_param_init_fn_)
|
436 |
+
param_init_fns.register('baseline_', func=baseline_param_init_fn_)
|
437 |
+
param_init_fns.register('kaiming_uniform_', func=kaiming_uniform_param_init_fn_)
|
438 |
+
param_init_fns.register('kaiming_normal_', func=kaiming_normal_param_init_fn_)
|
439 |
+
param_init_fns.register('neox_init_', func=neox_param_init_fn_)
|
440 |
+
param_init_fns.register('small_init_', func=small_param_init_fn_)
|
441 |
+
param_init_fns.register('xavier_uniform_', func=xavier_uniform_param_init_fn_)
|
442 |
+
param_init_fns.register('xavier_normal_', func=xavier_normal_param_init_fn_)
|
443 |
+
module_init_fns.register('fc', func=fc_init)
|
444 |
+
module_init_fns.register('embedding', func=embedding_init)
|
445 |
+
module_init_fns.register('norm', func=norm_init)
|
446 |
+
module_init_fns.register('multihead_attention', func=multihead_attention_init)
|
447 |
+
module_init_fns.register('te_layernorm_mlp', func=te_layernorm_mlp_init)
|
448 |
+
module_init_fns.register('moe', func=moe_init)
|
registry_utils.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import functools
|
3 |
+
import importlib.util
|
4 |
+
import os
|
5 |
+
from contextlib import contextmanager
|
6 |
+
from pathlib import Path
|
7 |
+
from types import ModuleType
|
8 |
+
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, Union
|
9 |
+
import catalogue
|
10 |
+
T = TypeVar('T')
|
11 |
+
TypeBoundT = TypeVar('TypeBoundT', bound=type)
|
12 |
+
CallableBoundT = TypeVar('CallableBoundT', bound=Callable[..., Any])
|
13 |
+
|
14 |
+
class TypedRegistry(catalogue.Registry, Generic[T]):
|
15 |
+
"""A thin wrapper around catalogue.Registry to add static typing and.
|
16 |
+
|
17 |
+
descriptions.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, namespace: Sequence[str], entry_points: bool=False, description: str='') -> None:
|
21 |
+
super().__init__(namespace, entry_points=entry_points)
|
22 |
+
self.description = description
|
23 |
+
|
24 |
+
def __call__(self, name: str, func: Optional[T]=None) -> Callable[[T], T]:
|
25 |
+
return super().__call__(name, func)
|
26 |
+
|
27 |
+
def register(self, name: str, *, func: Optional[T]=None) -> T:
|
28 |
+
return super().register(name, func=func)
|
29 |
+
|
30 |
+
def register_class(self, name: str, *, func: Optional[TypeBoundT]=None) -> TypeBoundT:
|
31 |
+
return super().register(name, func=func)
|
32 |
+
|
33 |
+
def get(self, name: str) -> T:
|
34 |
+
return super().get(name)
|
35 |
+
|
36 |
+
def get_all(self) -> dict[str, T]:
|
37 |
+
return super().get_all()
|
38 |
+
|
39 |
+
def get_entry_point(self, name: str, default: Optional[T]=None) -> T:
|
40 |
+
return super().get_entry_point(name, default=default)
|
41 |
+
|
42 |
+
def get_entry_points(self) -> dict[str, T]:
|
43 |
+
return super().get_entry_points()
|
44 |
+
S = TypeVar('S')
|
45 |
+
|
46 |
+
def create_registry(*namespace: str, generic_type: type[S], entry_points: bool=False, description: str='') -> 'TypedRegistry[S]':
|
47 |
+
"""Create a new registry.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
namespace (str): The namespace, e.g. "llmfoundry.loggers"
|
51 |
+
generic_type (Type[S]): The type of the registry.
|
52 |
+
entry_points (bool): Accept registered functions from entry points.
|
53 |
+
description (str): A description of the registry.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
The TypedRegistry object.
|
57 |
+
"""
|
58 |
+
if catalogue.check_exists(*namespace):
|
59 |
+
raise catalogue.RegistryError(f'Namespace already exists: {namespace}')
|
60 |
+
return TypedRegistry[generic_type](namespace, entry_points=entry_points, description=description)
|
61 |
+
|
62 |
+
def construct_from_registry(name: str, registry: TypedRegistry, partial_function: bool=True, pre_validation_function: Optional[Union[Callable[[Any], None], type]]=None, post_validation_function: Optional[Callable[[Any], None]]=None, kwargs: Optional[dict[str, Any]]=None) -> Any:
|
63 |
+
"""Helper function to build an item from the registry.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
name (str): The name of the registered item
|
67 |
+
registry (catalogue.Registry): The registry to fetch the item from
|
68 |
+
partial_function (bool, optional): Whether to return a partial function for registered callables. Defaults to True.
|
69 |
+
pre_validation_function (Optional[Union[Callable[[Any], None], type]], optional): An optional validation function called
|
70 |
+
before constructing the item to return. This should throw an exception if validation fails. Defaults to None.
|
71 |
+
post_validation_function (Optional[Callable[[Any], None]], optional): An optional validation function called after
|
72 |
+
constructing the item to return. This should throw an exception if validation fails. Defaults to None.
|
73 |
+
kwargs (Optional[Dict[str, Any]]): Other relevant keyword arguments.
|
74 |
+
|
75 |
+
Raises:
|
76 |
+
ValueError: If the validation functions failed or the registered item is invalid
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
Any: The constructed item from the registry
|
80 |
+
"""
|
81 |
+
if kwargs is None:
|
82 |
+
kwargs = {}
|
83 |
+
registered_constructor = registry.get(name)
|
84 |
+
if pre_validation_function is not None:
|
85 |
+
if isinstance(pre_validation_function, type):
|
86 |
+
if not issubclass(registered_constructor, pre_validation_function):
|
87 |
+
raise ValueError(f'Expected {name} to be of type {pre_validation_function}, but got {type(registered_constructor)}')
|
88 |
+
elif isinstance(pre_validation_function, Callable):
|
89 |
+
pre_validation_function(registered_constructor)
|
90 |
+
else:
|
91 |
+
raise ValueError(f'Expected pre_validation_function to be a callable or a type, but got {type(pre_validation_function)}')
|
92 |
+
if isinstance(registered_constructor, type) or (callable(registered_constructor) and (not partial_function)):
|
93 |
+
constructed_item = registered_constructor(**kwargs)
|
94 |
+
elif callable(registered_constructor):
|
95 |
+
constructed_item = functools.partial(registered_constructor, **kwargs)
|
96 |
+
else:
|
97 |
+
raise ValueError(f'Expected {name} to be a class or function, but got {type(registered_constructor)}')
|
98 |
+
if post_validation_function is not None:
|
99 |
+
post_validation_function(constructed_item)
|
100 |
+
return constructed_item
|
101 |
+
|
102 |
+
def import_file(loc: Union[str, Path]) -> ModuleType:
|
103 |
+
"""Import module from a file.
|
104 |
+
|
105 |
+
Used to run arbitrary python code.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
name (str): Name of module to load.
|
109 |
+
loc (str / Path): Path to the file.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
ModuleType: The module object.
|
113 |
+
"""
|
114 |
+
if not os.path.exists(loc):
|
115 |
+
raise FileNotFoundError(f'File {loc} does not exist.')
|
116 |
+
spec = importlib.util.spec_from_file_location('python_code', str(loc))
|
117 |
+
assert spec is not None
|
118 |
+
assert spec.loader is not None
|
119 |
+
module = importlib.util.module_from_spec(spec)
|
120 |
+
try:
|
121 |
+
spec.loader.exec_module(module)
|
122 |
+
except Exception as e:
|
123 |
+
raise RuntimeError(f'Error executing {loc}') from e
|
124 |
+
return module
|
125 |
+
|
126 |
+
@contextmanager
|
127 |
+
def save_registry():
|
128 |
+
"""Save the registry state and restore after the context manager exits."""
|
129 |
+
saved_registry_state = copy.deepcopy(catalogue.REGISTRY)
|
130 |
+
yield
|
131 |
+
catalogue.REGISTRY = saved_registry_state
|
special_tokens_map.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"additional_special_tokens": [
|
3 |
+
"<|im_start|>",
|
4 |
+
"<|im_end|>",
|
5 |
+
"<|object_ref_start|>",
|
6 |
+
"<|object_ref_end|>",
|
7 |
+
"<|box_start|>",
|
8 |
+
"<|box_end|>",
|
9 |
+
"<|quad_start|>",
|
10 |
+
"<|quad_end|>",
|
11 |
+
"<|vision_start|>",
|
12 |
+
"<|vision_end|>",
|
13 |
+
"<|vision_pad|>",
|
14 |
+
"<|image_pad|>",
|
15 |
+
"<|video_pad|>"
|
16 |
+
],
|
17 |
+
"eos_token": {
|
18 |
+
"content": "<|im_end|>",
|
19 |
+
"lstrip": false,
|
20 |
+
"normalized": false,
|
21 |
+
"rstrip": false,
|
22 |
+
"single_word": false
|
23 |
+
},
|
24 |
+
"pad_token": {
|
25 |
+
"content": "<|endoftext|>",
|
26 |
+
"lstrip": false,
|
27 |
+
"normalized": false,
|
28 |
+
"rstrip": false,
|
29 |
+
"single_word": false
|
30 |
+
}
|
31 |
+
}
|
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
|
3 |
+
size 11422654
|
tokenizer_config.json
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": false,
|
3 |
+
"add_prefix_space": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"151643": {
|
6 |
+
"content": "<|endoftext|>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"151644": {
|
14 |
+
"content": "<|im_start|>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"151645": {
|
22 |
+
"content": "<|im_end|>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": false,
|
25 |
+
"rstrip": false,
|
26 |
+
"single_word": false,
|
27 |
+
"special": true
|
28 |
+
},
|
29 |
+
"151646": {
|
30 |
+
"content": "<|object_ref_start|>",
|
31 |
+
"lstrip": false,
|
32 |
+
"normalized": false,
|
33 |
+
"rstrip": false,
|
34 |
+
"single_word": false,
|
35 |
+
"special": true
|
36 |
+
},
|
37 |
+
"151647": {
|
38 |
+
"content": "<|object_ref_end|>",
|
39 |
+
"lstrip": false,
|
40 |
+
"normalized": false,
|
41 |
+
"rstrip": false,
|
42 |
+
"single_word": false,
|
43 |
+
"special": true
|
44 |
+
},
|
45 |
+
"151648": {
|
46 |
+
"content": "<|box_start|>",
|
47 |
+
"lstrip": false,
|
48 |
+
"normalized": false,
|
49 |
+
"rstrip": false,
|
50 |
+
"single_word": false,
|
51 |
+
"special": true
|
52 |
+
},
|
53 |
+
"151649": {
|
54 |
+
"content": "<|box_end|>",
|
55 |
+
"lstrip": false,
|
56 |
+
"normalized": false,
|
57 |
+
"rstrip": false,
|
58 |
+
"single_word": false,
|
59 |
+
"special": true
|
60 |
+
},
|
61 |
+
"151650": {
|
62 |
+
"content": "<|quad_start|>",
|
63 |
+
"lstrip": false,
|
64 |
+
"normalized": false,
|
65 |
+
"rstrip": false,
|
66 |
+
"single_word": false,
|
67 |
+
"special": true
|
68 |
+
},
|
69 |
+
"151651": {
|
70 |
+
"content": "<|quad_end|>",
|
71 |
+
"lstrip": false,
|
72 |
+
"normalized": false,
|
73 |
+
"rstrip": false,
|
74 |
+
"single_word": false,
|
75 |
+
"special": true
|
76 |
+
},
|
77 |
+
"151652": {
|
78 |
+
"content": "<|vision_start|>",
|
79 |
+
"lstrip": false,
|
80 |
+
"normalized": false,
|
81 |
+
"rstrip": false,
|
82 |
+
"single_word": false,
|
83 |
+
"special": true
|
84 |
+
},
|
85 |
+
"151653": {
|
86 |
+
"content": "<|vision_end|>",
|
87 |
+
"lstrip": false,
|
88 |
+
"normalized": false,
|
89 |
+
"rstrip": false,
|
90 |
+
"single_word": false,
|
91 |
+
"special": true
|
92 |
+
},
|
93 |
+
"151654": {
|
94 |
+
"content": "<|vision_pad|>",
|
95 |
+
"lstrip": false,
|
96 |
+
"normalized": false,
|
97 |
+
"rstrip": false,
|
98 |
+
"single_word": false,
|
99 |
+
"special": true
|
100 |
+
},
|
101 |
+
"151655": {
|
102 |
+
"content": "<|image_pad|>",
|
103 |
+
"lstrip": false,
|
104 |
+
"normalized": false,
|
105 |
+
"rstrip": false,
|
106 |
+
"single_word": false,
|
107 |
+
"special": true
|
108 |
+
},
|
109 |
+
"151656": {
|
110 |
+
"content": "<|video_pad|>",
|
111 |
+
"lstrip": false,
|
112 |
+
"normalized": false,
|
113 |
+
"rstrip": false,
|
114 |
+
"single_word": false,
|
115 |
+
"special": true
|
116 |
+
},
|
117 |
+
"151657": {
|
118 |
+
"content": "<tool_call>",
|
119 |
+
"lstrip": false,
|
120 |
+
"normalized": false,
|
121 |
+
"rstrip": false,
|
122 |
+
"single_word": false,
|
123 |
+
"special": false
|
124 |
+
},
|
125 |
+
"151658": {
|
126 |
+
"content": "</tool_call>",
|
127 |
+
"lstrip": false,
|
128 |
+
"normalized": false,
|
129 |
+
"rstrip": false,
|
130 |
+
"single_word": false,
|
131 |
+
"special": false
|
132 |
+
},
|
133 |
+
"151659": {
|
134 |
+
"content": "<|fim_prefix|>",
|
135 |
+
"lstrip": false,
|
136 |
+
"normalized": false,
|
137 |
+
"rstrip": false,
|
138 |
+
"single_word": false,
|
139 |
+
"special": false
|
140 |
+
},
|
141 |
+
"151660": {
|
142 |
+
"content": "<|fim_middle|>",
|
143 |
+
"lstrip": false,
|
144 |
+
"normalized": false,
|
145 |
+
"rstrip": false,
|
146 |
+
"single_word": false,
|
147 |
+
"special": false
|
148 |
+
},
|
149 |
+
"151661": {
|
150 |
+
"content": "<|fim_suffix|>",
|
151 |
+
"lstrip": false,
|
152 |
+
"normalized": false,
|
153 |
+
"rstrip": false,
|
154 |
+
"single_word": false,
|
155 |
+
"special": false
|
156 |
+
},
|
157 |
+
"151662": {
|
158 |
+
"content": "<|fim_pad|>",
|
159 |
+
"lstrip": false,
|
160 |
+
"normalized": false,
|
161 |
+
"rstrip": false,
|
162 |
+
"single_word": false,
|
163 |
+
"special": false
|
164 |
+
},
|
165 |
+
"151663": {
|
166 |
+
"content": "<|repo_name|>",
|
167 |
+
"lstrip": false,
|
168 |
+
"normalized": false,
|
169 |
+
"rstrip": false,
|
170 |
+
"single_word": false,
|
171 |
+
"special": false
|
172 |
+
},
|
173 |
+
"151664": {
|
174 |
+
"content": "<|file_sep|>",
|
175 |
+
"lstrip": false,
|
176 |
+
"normalized": false,
|
177 |
+
"rstrip": false,
|
178 |
+
"single_word": false,
|
179 |
+
"special": false
|
180 |
+
},
|
181 |
+
"151665": {
|
182 |
+
"content": "<tool_response>",
|
183 |
+
"lstrip": false,
|
184 |
+
"normalized": false,
|
185 |
+
"rstrip": false,
|
186 |
+
"single_word": false,
|
187 |
+
"special": false
|
188 |
+
},
|
189 |
+
"151666": {
|
190 |
+
"content": "</tool_response>",
|
191 |
+
"lstrip": false,
|
192 |
+
"normalized": false,
|
193 |
+
"rstrip": false,
|
194 |
+
"single_word": false,
|
195 |
+
"special": false
|
196 |
+
},
|
197 |
+
"151667": {
|
198 |
+
"content": "<think>",
|
199 |
+
"lstrip": false,
|
200 |
+
"normalized": false,
|
201 |
+
"rstrip": false,
|
202 |
+
"single_word": false,
|
203 |
+
"special": false
|
204 |
+
},
|
205 |
+
"151668": {
|
206 |
+
"content": "</think>",
|
207 |
+
"lstrip": false,
|
208 |
+
"normalized": false,
|
209 |
+
"rstrip": false,
|
210 |
+
"single_word": false,
|
211 |
+
"special": false
|
212 |
+
}
|
213 |
+
},
|
214 |
+
"additional_special_tokens": [
|
215 |
+
"<|im_start|>",
|
216 |
+
"<|im_end|>",
|
217 |
+
"<|object_ref_start|>",
|
218 |
+
"<|object_ref_end|>",
|
219 |
+
"<|box_start|>",
|
220 |
+
"<|box_end|>",
|
221 |
+
"<|quad_start|>",
|
222 |
+
"<|quad_end|>",
|
223 |
+
"<|vision_start|>",
|
224 |
+
"<|vision_end|>",
|
225 |
+
"<|vision_pad|>",
|
226 |
+
"<|image_pad|>",
|
227 |
+
"<|video_pad|>"
|
228 |
+
],
|
229 |
+
"bos_token": null,
|
230 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
|
231 |
+
"clean_up_tokenization_spaces": false,
|
232 |
+
"eos_token": "<|im_end|>",
|
233 |
+
"errors": "replace",
|
234 |
+
"extra_special_tokens": {},
|
235 |
+
"model_max_length": 4096,
|
236 |
+
"pad_token": "<|endoftext|>",
|
237 |
+
"split_special_tokens": false,
|
238 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
239 |
+
"unk_token": null
|
240 |
+
}
|
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
warnings.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import warnings
|
3 |
+
from typing import Any, Callable, Type, TypeVar, cast
|
4 |
+
|
5 |
+
class VersionedDeprecationWarning(UserWarning):
|
6 |
+
"""A custom deprecation warning class that includes version information.
|
7 |
+
|
8 |
+
Attributes:
|
9 |
+
message (str): The deprecation message describing why the feature is deprecated.
|
10 |
+
remove_version (str): The version in which the feature will be removed.
|
11 |
+
|
12 |
+
Example:
|
13 |
+
>>> def deprecated_function():
|
14 |
+
... warnings.warn(
|
15 |
+
... VersionedDeprecationWarning(
|
16 |
+
... "Function XYZ is deprecated.",
|
17 |
+
... remove_version="2.0.0"
|
18 |
+
... )
|
19 |
+
... )
|
20 |
+
...
|
21 |
+
>>> deprecated_function()
|
22 |
+
DeprecationWarning: Function XYZ is deprecated. It will be removed in version 2.0.0.
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, message: str, remove_version: str) -> None:
|
26 |
+
super().__init__(message + f' It will be removed in version {remove_version}.')
|
27 |
+
|
28 |
+
class ExperimentalWarning(Warning):
|
29 |
+
"""A warning for experimental features.
|
30 |
+
|
31 |
+
Attributes:
|
32 |
+
feature_name (str): The name of the experimental feature.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, feature_name: str) -> None:
|
36 |
+
super().__init__(f'{feature_name} is experimental and may change with future versions.')
|
37 |
+
F = TypeVar('F', bound=Callable[..., Any])
|
38 |
+
|
39 |
+
def experimental_function(feature_name: str) -> Callable[[F], F]:
|
40 |
+
"""Decorator to mark a function as experimental.
|
41 |
+
|
42 |
+
The message displayed will be {feature_name} is experimental and may change with future versions.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
feature_name (str): The name of the experimental feature.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
The decorated function.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def decorator(func: Callable):
|
52 |
+
|
53 |
+
@functools.wraps(func)
|
54 |
+
def wrapper(*args: Any, **kwargs: Any):
|
55 |
+
warnings.warn(ExperimentalWarning(feature_name))
|
56 |
+
return func(*args, **kwargs)
|
57 |
+
return cast(F, wrapper)
|
58 |
+
return decorator
|
59 |
+
|
60 |
+
def experimental_class(feature_name: str) -> Callable[[Type], Type]:
|
61 |
+
"""Class decorator to mark a class as experimental."""
|
62 |
+
|
63 |
+
def class_decorator(cls: Type):
|
64 |
+
original_init = cls.__init__
|
65 |
+
cls.is_experimental = True
|
66 |
+
|
67 |
+
def new_init(self: Any, *args: Any, **kwargs: Any):
|
68 |
+
warnings.warn(ExperimentalWarning(feature_name))
|
69 |
+
original_init(self, *args, **kwargs)
|
70 |
+
cls.__init__ = new_init
|
71 |
+
return cls
|
72 |
+
return class_decorator
|