appledora committed on
Commit 7090542 (verified) · 1 parent: 8627f46

Upload configuration_recast_llama.py with huggingface_hub

Files changed (1)
1. configuration_recast_llama.py +79 -0
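For context, an upload like the one described by this commit message is typically done with the huggingface_hub client. The snippet below is a minimal sketch, not the author's actual command; the repo_id is a placeholder assumption, and the commit message mirrors the one shown above.

from huggingface_hub import upload_file

# Sketch of uploading the configuration file with huggingface_hub.
# "appledora/RECAST8b-llama" is a placeholder repo id (assumption), not taken from this page.
upload_file(
    path_or_fileobj="configuration_recast_llama.py",
    path_in_repo="configuration_recast_llama.py",
    repo_id="appledora/RECAST8b-llama",
    repo_type="model",
    commit_message="Upload configuration_recast_llama.py with huggingface_hub",
)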
configuration_recast_llama.py ADDED
@@ -0,0 +1,79 @@
from transformers import PretrainedConfig


class RECAST8b_llama(PretrainedConfig):
    model_type = "recast8b_llama"
    attribute_map = {
        "hidden_size": "hidden_size",
        "num_attention_heads": "num_attention_heads",
    }

    def __init__(
        self,
        vocab_size=128256,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=16,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=131072,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=128000,
        eos_token_id=128001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=500000.0,
        rope_scaling={
            "factor": 32.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        },
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        # Template-specific configs
        num_templates=2,
        num_groups=8,
        coef_height=4,
        num_cf=1,
        torch_dtype="bfloat16",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.mlp_bias = mlp_bias
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.torch_dtype = torch_dtype

        # Template-specific configs
        self.num_templates = num_templates
        self.num_groups = num_groups
        self.coef_height = coef_height
        self.num_cf = num_cf

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
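As a quick sanity check of the configuration class above, the sketch below shows one way it could be instantiated, saved, and re-loaded through AutoConfig. The AutoConfig.register step, the parameter overrides, and the local directory name are illustrative assumptions and are not part of the uploaded file.

from transformers import AutoConfig
from configuration_recast_llama import RECAST8b_llama

# Assumption: registering the custom model_type lets AutoConfig resolve it in this session.
AutoConfig.register("recast8b_llama", RECAST8b_llama)

# Instantiate with a couple of the template-specific overrides (example values).
config = RECAST8b_llama(num_templates=4, num_groups=16)
print(config.model_type)     # "recast8b_llama"
print(config.num_templates)  # 4

# Round-trip through disk: save_pretrained writes config.json,
# which AutoConfig can re-load via the registered model_type.
config.save_pretrained("./recast8b_llama-config")  # illustrative path
reloaded = AutoConfig.from_pretrained("./recast8b_llama-config")
assert isinstance(reloaded, RECAST8b_llama)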