from typing import Union

import einops
import torch
import torch.nn as nn
from transformers import BatchEncoding, PreTrainedModel

# VLMConfig, projection_layers, and LABEL_MASK are referenced below but defined
# elsewhere in this project; their imports are assumed.


class Model(PreTrainedModel):
    config_class = VLMConfig
    
    def __init__(
        self,
        config: VLMConfig,
        image_model,
        language_model,
        num_projections: int,
        tokenizer,
        prepend_text: str,
        image_tokens: int,
    ):
        super().__init__(config)
        self.image_model = image_model
        self.language_model = language_model
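        # Projects image features into the language model's embedding space.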
        self.projector = nn.Sequential(
            *projection_layers(image_model.num_features, language_model.config.hidden_size, num_projections)
        )
        
        self.tokenizer = tokenizer
        self.eos_token = tokenizer.eos_token
        self.prepend_text = prepend_text
        
        self.image_tokens = image_tokens
        
        # prepend_text is expected to contain exactly one EOS token marking where the
        # image embeddings will be spliced in. Pre-compute the prompt embeddings once
        # and split them at that marker into a prefix and a suffix.
        input_ids = tokenizer(prepend_text, return_tensors="pt").input_ids
        eos_token_index = (input_ids[0] == tokenizer.eos_token_id).nonzero(as_tuple=True)[0].item()
        text_embeddings = self.language_model.get_input_embeddings()(input_ids).detach()
        self.prepend_embeddings = text_embeddings[:, :eos_token_index]
        self.postpend_embeddings = text_embeddings[:, eos_token_index:]
        # Attention mask and fully masked labels covering the prompt and image positions.
        self.attention_mask = torch.ones(1, text_embeddings.shape[1] + image_tokens)
        self.labels = torch.full((1, self.attention_mask.shape[1]), LABEL_MASK)
        
    def project_image_features(self, images: torch.Tensor):
        # Encode the images, flatten the spatial grid into a token sequence, and
        # project each token to the language model's hidden size.
        image_features = self.image_model.forward_features(images)
        image_features = einops.rearrange(image_features, "bs dim w h -> bs (w h) dim")
        encoder_outputs = self.projector(image_features)
        return encoder_outputs
        
    def forward(self, images: torch.Tensor, tokenized_captions: BatchEncoding):
        image_outputs = self.project_image_features(images)
        caption_embeddings = self.language_model.get_input_embeddings()(tokenized_captions.input_ids).detach()
        device = images.device
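        # Sequence layout per sample: [prompt prefix] [image tokens] [prompt suffix] [caption].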
        embeddings = torch.cat(
            [
                self.prepend_embeddings.to(device).expand(len(images), -1, -1),
                image_outputs,
                self.postpend_embeddings.to(device).expand(len(images), -1, -1),
                caption_embeddings,
            ],
            dim=1,
        )
        attention_mask = torch.cat(
            [
                self.attention_mask.to(device).expand(len(images), -1), 
                tokenized_captions.attention_mask
            ], 
            dim=1
        )
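        # Train only on the caption: prompt and image positions are pre-masked, and
        # padded caption positions are masked below via the attention mask.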
        labels = torch.cat(
            [
                self.labels.to(device).expand(len(images), -1), 
                tokenized_captions.input_ids.clone()
            ],
            dim=1,
        )
        labels[attention_mask == 0] = LABEL_MASK
        
        return self.language_model(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            labels=labels,
        )
    
    def generate(self, images: torch.Tensor, generator_kwargs: dict[str, Union[int, float]]):
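        # Build the multimodal prompt (prefix + image tokens + suffix) with no caption
        # and let the language model generate the continuation.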
        image_outputs = self.project_image_features(images)
        device = images.device
        embeddings = torch.cat(
            [
                self.prepend_embeddings.to(device).expand(len(images), -1, -1),
                image_outputs,
                self.postpend_embeddings.to(device).expand(len(images), -1, -1),
            ],
            dim=1,
        )
        attention_mask = self.attention_mask.to(device).expand(len(images), -1)
        return self.language_model.generate(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            eos_token_id=self.tokenizer.eos_token_id,
            **generator_kwargs
        )
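

# Hypothetical usage sketch (assumptions, not part of the original file): the image
# encoder is assumed to be a timm model, since it must expose `forward_features` and
# `num_features`, and the language model a Hugging Face causal LM. The VLMConfig
# arguments, the prompt text, and the 49 image tokens (a 7x7 feature grid) are
# illustrative guesses.
#
#   import timm
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   image_model = timm.create_model("resnet50", pretrained=True, num_classes=0)
#   language_model = AutoModelForCausalLM.from_pretrained("gpt2")
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   model = Model(
#       VLMConfig(),
#       image_model,
#       language_model,
#       num_projections=2,
#       tokenizer=tokenizer,
#       prepend_text=f"A picture of {tokenizer.eos_token}",
#       image_tokens=49,
#   )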