from transformers import PretrainedConfig
from typing import List

# STARCODER2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class ModularStarEncoderConfig(PretrainedConfig):
    model_type = "ModularStarEncoder"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        attention_dropout=0.1,
        residual_dropout=0.1,
        embedding_dropout=0.1,
        bos_token_id=0,
        eos_token_id=0,
        hidden_act="gelu_pytorch_tanh",
        _attn_implementation="flash_attention_2",
        hidden_size=1024,
        conditional_size=4,
        initializer_range=0.018042,
        intermediate_size=12288,
        max_position_embeddings=2048,
        mlp_type="default",
        model_type="starcoder2",
        torch_dtype="bfloat16",
        layer_matryoshka_loss=True,
        matryoshka_layers=[4, 9, 18, 27, 36],
        norm_epsilon=1e-05,
        layer_norm_eps=1e-05,
        norm_type="layer_norm",
        num_attention_heads=16,
        num_hidden_layers=36,
        num_key_value_heads=4,
        rope_theta=999999.4420358813,
        sliding_window=None,
        transformers_version="4.39.3",
        use_bias=True,
        use_cache=False,
        vocab_size=49156,
        pad_token_id=0,
        **kwargs,
    ):
        if _attn_implementation not in ["flash_attention_2", "sdpa"]:
            raise ValueError(
                f"`_attn_implementation` must be 'flash_attention_2' or 'sdpa', got {_attn_implementation}."
            )

        # Plain assignments (no trailing commas, which would turn every attribute into a one-element tuple).
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.embedding_dropout = embedding_dropout
        self.hidden_act = hidden_act
        self._attn_implementation = _attn_implementation
        self.hidden_size = hidden_size
        self.conditional_size = conditional_size
        self.initializer_range = initializer_range
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.mlp_type = mlp_type
        self.model_type = model_type
        self.torch_dtype = torch_dtype
        self.layer_matryoshka_loss = layer_matryoshka_loss
        self.matryoshka_layers = matryoshka_layers
        self.norm_epsilon = norm_epsilon
        self.layer_norm_eps = layer_norm_eps
        self.norm_type = norm_type
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window
        self.transformers_version = transformers_version
        self.use_bias = use_bias
        self.use_cache = use_cache
        self.vocab_size = vocab_size

        # Special token ids are forwarded to PretrainedConfig so the parent
        # constructor does not silently reset them.
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            **kwargs,
        )
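

# Minimal usage sketch (illustrative addition, not part of the original config file):
# it instantiates the config with the defaults above, overriding only the attention
# backend, and prints the JSON serialization provided by PretrainedConfig. Nothing
# here is required by the class itself.
if __name__ == "__main__":
    config = ModularStarEncoderConfig(_attn_implementation="sdpa")  # e.g. when flash-attention 2 is unavailable
    print(config.matryoshka_layers)  # [4, 9, 18, 27, 36]
    print(config.to_json_string())   # JSON view of the serialized configuration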