Update starvector architecture file
Browse files
- starvector_arch.py +5 -4
starvector_arch.py
CHANGED
@@ -2,7 +2,7 @@ from transformers import (
|
|
2 |
PretrainedConfig,
|
3 |
PreTrainedModel
|
4 |
)
|
5 |
-
|
6 |
class StarVectorConfig(PretrainedConfig):
|
7 |
model_type = "starvector"
|
8 |
|
@@ -18,9 +18,10 @@ class StarVectorConfig(PretrainedConfig):
|
|
18 |
use_cache: bool = True,
|
19 |
num_attention_heads: int = 16,
|
20 |
num_hidden_layers: int = 24,
|
21 |
-
vocab_size: int =
|
22 |
-
hidden_size: int =
|
23 |
num_kv_heads: int = 4,
|
|
|
24 |
**kwargs,
|
25 |
):
|
26 |
self.starcoder_model_name = starcoder_model_name
|
@@ -36,7 +37,7 @@ class StarVectorConfig(PretrainedConfig):
|
|
36 |
self.vocab_size = vocab_size
|
37 |
self.hidden_size = hidden_size
|
38 |
self.num_kv_heads = num_kv_heads
|
39 |
-
|
40 |
super().__init__(**kwargs)
|
41 |
|
42 |
class StarVectorForCausalLM(PreTrainedModel):
|
|
|
2 |
PretrainedConfig,
|
3 |
PreTrainedModel
|
4 |
)
|
5 |
+
import torch
|
6 |
class StarVectorConfig(PretrainedConfig):
|
7 |
model_type = "starvector"
|
8 |
|
|
|
18 |
use_cache: bool = True,
|
19 |
num_attention_heads: int = 16,
|
20 |
num_hidden_layers: int = 24,
|
21 |
+
vocab_size: int = 49152,
|
22 |
+
hidden_size: int = 2048,
|
23 |
num_kv_heads: int = 4,
|
24 |
+
torch_dtype: str = "bfloat16",
|
25 |
**kwargs,
|
26 |
):
|
27 |
self.starcoder_model_name = starcoder_model_name
|
|
|
37 |
self.vocab_size = vocab_size
|
38 |
self.hidden_size = hidden_size
|
39 |
self.num_kv_heads = num_kv_heads
|
40 |
+
self.torch_dtype = torch_dtype
|
41 |
super().__init__(**kwargs)
|
42 |
|
43 |
class StarVectorForCausalLM(PreTrainedModel):
|