File size: 2,416 Bytes
713773e
 
 
 
4760193
713773e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4760193
 
713773e
4760193
713773e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4760193
713773e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from transformers import (
    PretrainedConfig,
    PreTrainedModel
)
import torch
class StarVectorConfig(PretrainedConfig):
    """Configuration for StarVector models.

    Holds the StarCoder backbone name, image-encoder / adapter options,
    sequence-length limits and transformer dimensions used by the StarVector
    model code. Any extra keyword arguments are forwarded unchanged to
    ``PretrainedConfig``.

    Args:
        starcoder_model_name: Identifier of the StarCoder backbone. Names that
            contain ``"starcoder2"`` select the StarCoder2-based backbone in
            ``StarVectorForCausalLM``; any other name selects the original one.
        image_encoder_type: Which image encoder to use (default ``"clip"``).
        adapter_norm: Normalization used in the vision-to-text adapter
            (default ``"layer_norm"``).
        image_size: Input image size (presumably pixels per side — confirm
            against the image encoder).
        max_length: Maximum sequence length used at inference time.
        max_length_train: Maximum sequence length used during training.
        use_flash_attn: Whether the backbone should use flash attention.
        use_cache: Whether key/value caching is enabled.
        num_attention_heads: Number of attention heads in the backbone.
        num_hidden_layers: Number of transformer layers in the backbone.
        vocab_size: Tokenizer vocabulary size.
        hidden_size: Hidden dimension of the backbone.
        num_kv_heads: Number of key/value heads.
        torch_dtype: Default parameter dtype, given as a string such as
            ``"bfloat16"``. It is handed to ``PretrainedConfig``, which owns the
            ``torch_dtype`` attribute.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "starvector"

    def __init__(
        self,
        starcoder_model_name: str = "bigcode/starcoderbase-1b",
        image_encoder_type: str = "clip",
        adapter_norm: str = "layer_norm",
        image_size: int = 224,
        max_length: int = 8192,
        max_length_train: int = 8192,
        use_flash_attn: bool = True,
        use_cache: bool = True,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 24,
        vocab_size: int = 49152,
        hidden_size: int = 2048,
        num_kv_heads: int = 4,
        torch_dtype: str = "bfloat16",
        **kwargs,
    ):
        self.starcoder_model_name = starcoder_model_name
        self.image_encoder_type = image_encoder_type
        self.adapter_norm = adapter_norm
        self.image_size = image_size
        self.max_length = max_length
        self.max_length_train = max_length_train
        self.use_flash_attn = use_flash_attn
        self.use_cache = use_cache
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_kv_heads = num_kv_heads
        # PretrainedConfig.__init__ assigns its own ``torch_dtype`` attribute
        # from its keyword arguments. Because the named parameter above
        # captures the value, assigning it here and calling
        # ``super().__init__(**kwargs)`` would reset it to None. Forward it
        # explicitly so the base class stores (and normalizes) it instead.
        super().__init__(torch_dtype=torch_dtype, **kwargs)

class StarVectorForCausalLM(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates to a StarVector backbone.

    The backbone class is picked from ``config.starcoder_model_name``. Names
    containing ``"starcoder2"`` use ``StarVectorStarCoder2``; every other name
    uses ``StarVectorStarCoder``. The backbone instance lives in ``self.model``,
    and each public method below forwards its call to it.
    """

    config_class = StarVectorConfig
    _no_split_modules = []

    def __init__(self, config: StarVectorConfig, **kwargs):
        """Build the wrapper and instantiate the backbone chosen by ``config``.

        Extra keyword arguments are passed straight to the backbone constructor.
        """
        super().__init__(config)
        # Imported inside the branch so only the selected backbone's module is
        # loaded.
        if 'starcoder2' not in config.starcoder_model_name:
            from starvector.model.models.starvector_v1 import (
                StarVectorStarCoder as backbone_cls,
            )
        else:
            from starvector.model.models.starvector_v2 import (
                StarVectorStarCoder2 as backbone_cls,
            )
        self.model = backbone_cls(config=config, **kwargs)

    def forward(self, batch):
        """Run the backbone on ``batch`` and return whatever it returns."""
        backbone = self.model
        return backbone(batch)

    def generate_im2svg(self, batch, **kwargs):
        """Forward image-to-SVG generation to the backbone's ``generate_im2svg``."""
        generate = self.model.generate_im2svg
        return generate(batch, **kwargs)

    def generate_im2text(self, batch, **kwargs):
        """Forward image-to-text generation to the backbone's ``generate_im2text``."""
        generate = self.model.generate_im2text
        return generate(batch, **kwargs)

    def process_images(self, images):
        """Preprocess ``images`` with the backbone's image encoder."""
        encoder = self.model.image_encoder
        return encoder.process_images(images)