from transformers import (
    PretrainedConfig,
    PreTrainedModel
)
import torch


class StarVectorConfig(PretrainedConfig):
    """Configuration object for StarVector models.

    Every argument except ``torch_dtype`` is stored unchanged as an attribute
    of the same name; ``torch_dtype`` and any extra keyword arguments are
    handed to ``PretrainedConfig.__init__``.

    Args:
        starcoder_model_name: Identifier of the StarCoder checkpoint. A name
            containing ``"starcoder2"`` makes ``StarVectorForCausalLM`` build
            the StarCoder2-based implementation; any other name selects the
            original StarCoder-based one.
        image_encoder_type: Image encoder selector, read by the model
            implementation (defined outside this file).
        adapter_norm: Adapter normalisation selector, read by the model
            implementation (defined outside this file).
        image_size: Stored as ``image_size``.
        max_length: Stored as ``max_length``.
        max_length_train: Stored as ``max_length_train``.
        use_flash_attn: Stored as ``use_flash_attn``.
        use_cache: Stored as ``use_cache``.
        num_attention_heads: Stored as ``num_attention_heads``.
        num_hidden_layers: Stored as ``num_hidden_layers``.
        vocab_size: Stored as ``vocab_size``.
        hidden_size: Stored as ``hidden_size``.
        num_kv_heads: Stored as ``num_kv_heads``.
        torch_dtype: Dtype name forwarded to the base configuration class.
        **kwargs: Additional options forwarded to ``PretrainedConfig``.
    """

    model_type = "starvector"

    def __init__(
        self,
        starcoder_model_name: str = "bigcode/starcoderbase-1b",
        image_encoder_type: str = "clip",
        adapter_norm: str = "layer_norm",
        image_size: int = 224,
        max_length: int = 8192,
        max_length_train: int = 8192,
        use_flash_attn: bool = True,
        use_cache: bool = True,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 24,
        vocab_size: int = 49152,
        hidden_size: int = 2048,
        num_kv_heads: int = 4,
        torch_dtype: str = "bfloat16",
        **kwargs,
    ):
        self.starcoder_model_name = starcoder_model_name
        self.image_encoder_type = image_encoder_type
        self.adapter_norm = adapter_norm
        self.image_size = image_size
        # NOTE(review): some transformers releases treat `max_length` as a
        # generation default and may reset it inside the base constructor;
        # confirm the effective value against the installed version.
        self.max_length = max_length
        self.max_length_train = max_length_train
        self.use_flash_attn = use_flash_attn
        self.use_cache = use_cache
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_kv_heads = num_kv_heads
        # `torch_dtype` is a named parameter, so it never appears in `kwargs`.
        # Assigning it here and then calling `super().__init__(**kwargs)` lets
        # the base constructor replace it with its own default, dropping the
        # requested dtype. Forward it explicitly so the base class stores it.
        super().__init__(torch_dtype=torch_dtype, **kwargs)


class StarVectorForCausalLM(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates to a StarVector implementation.

    The concrete implementation is chosen from ``config.starcoder_model_name``
    and stored on ``self.model``; the public methods forward to it.
    """

    config_class = StarVectorConfig
    _no_split_modules = []

    def __init__(self, config: StarVectorConfig, **kwargs):
        """Build the wrapped model selected by ``config.starcoder_model_name``.

        Args:
            config: StarVector configuration.
            **kwargs: Forwarded unchanged to the selected implementation's
                constructor.
        """
        super().__init__(config)
        starcoder_model_name = config.starcoder_model_name
        # Implementations are imported lazily so that only the selected
        # variant's module is loaded.
        if 'starcoder2' in starcoder_model_name:
            from starvector.model.models.starvector_v2 import StarVectorStarCoder2
            self.model = StarVectorStarCoder2(config=config, **kwargs)
        else:
            from starvector.model.models.starvector_v1 import StarVectorStarCoder
            self.model = StarVectorStarCoder(config=config, **kwargs)

    def forward(self, batch):
        """Call the wrapped model on ``batch`` and return its result."""
        return self.model(batch)

    def generate_im2svg(self, batch, **kwargs):
        """Forward to the wrapped model's ``generate_im2svg``."""
        return self.model.generate_im2svg(batch, **kwargs)

    def generate_im2text(self, batch, **kwargs):
        """Forward to the wrapped model's ``generate_im2text``."""
        return self.model.generate_im2text(batch, **kwargs)

    def process_images(self, images):
        """Forward ``images`` to the wrapped model's image encoder ``process_images``."""
        return self.model.image_encoder.process_images(images)