Update starvector_arch.py
starvector_arch.py  CHANGED  (+164, -3)
@@ -2,6 +2,81 @@ from transformers import (
     PretrainedConfig,
     PreTrainedModel
 )
+from torch.nn import CrossEntropyLoss
+from transformers.models.gpt_bigcode.modeling_gpt_bigcode import CausalLMOutputWithCrossAttentions
+from typing import Optional, Tuple, Union
+import torch
+
+from transformers.processing_utils import ProcessorMixin
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode, pad
+from transformers.feature_extraction_sequence_utils import BatchFeature
+from transformers import AutoProcessor
+
+class SimpleStarVectorProcessor(ProcessorMixin):
+    attributes = ["tokenizer"]  # Only include tokenizer in attributes
+    valid_kwargs = ["size", "mean", "std"]  # Add other parameters as valid kwargs
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(self,
+                 tokenizer=None,  # Make tokenizer the first argument
+                 size=224,
+                 mean=None,
+                 std=None,
+                 **kwargs,
+                 ):
+        if mean is None:
+            mean = (0.48145466, 0.4578275, 0.40821073)
+        if std is None:
+            std = (0.26862954, 0.26130258, 0.27577711)
+
+        # Store these as instance variables
+        self.mean = mean
+        self.std = std
+        self.size = size
+
+        self.normalize = transforms.Normalize(mean=mean, std=std)
+
+        self.transform = transforms.Compose([
+            transforms.Lambda(lambda img: img.convert("RGB") if img.mode == "RGBA" else img),
+            transforms.Lambda(lambda img: self._pad_to_square(img)),
+            transforms.Resize(size, interpolation=InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            self.normalize
+        ])
+
+        # Initialize parent class with tokenizer
+        super().__init__(tokenizer=tokenizer)
+
+
+    def __call__(self, images=None, text=None, **kwargs) -> BatchFeature:
+        """
+        Process images and/or text inputs.
+
+        Args:
+            images: Optional image input(s)
+            text: Optional text input(s)
+            **kwargs: Additional arguments
+        """
+        if images is None and text is None:
+            raise ValueError("You have to specify at least one of `images` or `text`.")
+
+        image_inputs = {}
+        if images is not None:
+            if isinstance(images, (list, tuple)):
+                images_ = [self.transform(img) for img in images]
+            else:
+                images_ = self.transform(images)
+            image_inputs = {"pixel_values": images_}
+
+        text_inputs = {}
+        if text is not None:
+            text_inputs = self.tokenizer(text, **kwargs)
+        return BatchFeature(data={**text_inputs, **image_inputs})
+
+AutoProcessor.register(SimpleStarVectorProcessor, SimpleStarVectorProcessor)
+
 
 class StarVectorConfig(PretrainedConfig):
     model_type = "starvector"
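
Note: the Compose pipeline above calls self._pad_to_square, but no such method is defined anywhere in this diff; it presumably relies on the `pad` helper imported from torchvision.transforms.functional. A minimal sketch of what it could look like (the even-split padding and the white fill are assumptions, not part of this commit):

    # Hypothetical helper, not in this commit: pad a PIL image to a square
    # canvas so the later Resize step does not distort the aspect ratio.
    def _pad_to_square(self, img):
        w, h = img.size
        max_side = max(w, h)
        # Split the deficit evenly between the two sides of each axis.
        left = (max_side - w) // 2
        top = (max_side - h) // 2
        right = max_side - w - left
        bottom = max_side - h - top
        # fill=255 assumes a white background, which suits rasterized SVGs.
        return pad(img, (left, top, right, bottom), fill=255)

Once registered, the processor can be used like any other Hugging Face processor; a hedged example (the tokenizer and image variables are placeholders):

    processor = SimpleStarVectorProcessor(tokenizer=tokenizer)
    inputs = processor(images=pil_image, text="<svg", return_tensors="pt")
    # inputs["pixel_values"] holds the normalized 3x224x224 image tensor.
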
@@ -24,6 +99,7 @@ class StarVectorConfig(PretrainedConfig):
         torch_dtype: str = "bfloat16",
         **kwargs,
     ):
+        kwargs["torch_dtype"] = torch_dtype
         self.starcoder_model_name = starcoder_model_name
         self.image_encoder_type = image_encoder_type
         self.adapter_norm = adapter_norm
@@ -37,7 +113,6 @@ class StarVectorConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.num_kv_heads = num_kv_heads
-        self.torch_dtype = torch_dtype
         super().__init__(**kwargs)
 
 class StarVectorForCausalLM(PreTrainedModel):
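
Note: these two hunks route `torch_dtype` through `kwargs` so that `PretrainedConfig.__init__` consumes it, instead of setting the attribute by hand. The base class pops `torch_dtype` from its kwargs and stores it on the config (recent transformers versions also convert a string like "bfloat16" into the matching torch.dtype), which keeps the value consistent when the config is saved and reloaded. A small sketch of the resulting behavior (version-dependent, hence hedged):

    config = StarVectorConfig(torch_dtype="bfloat16")
    # The attribute now comes from PretrainedConfig's own handling:
    print(config.torch_dtype)  # "bfloat16" or torch.bfloat16, depending on version
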
@@ -54,8 +129,94 @@ class StarVectorForCausalLM(PreTrainedModel):
         from starvector.model.models.starvector_v1 import StarVectorStarCoder
         self.model = StarVectorStarCoder(config=config, **kwargs)
 
-
-
+    @property
+    def supports_gradient_checkpointing(self):
+        # If the underlying transformer (e.g., the one in StarCoderModel)
+        # supports gradient checkpointing, delegate to it.
+        if hasattr(self.model, 'svg_transformer'):
+            return getattr(self.model.svg_transformer, 'supports_gradient_checkpointing', False)
+        return False
+
+    def gradient_checkpointing_enable(self):
+        # Optionally, forward this call to the internal transformer.
+        if hasattr(self.model, 'svg_transformer') and hasattr(self.model.svg_transformer, 'gradient_checkpointing_enable'):
+            self.model.svg_transformer.gradient_checkpointing_enable()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        num_logits_to_keep: int = 0,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+        r"""
+        labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.model.svg_transformer.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+
+        # If GRPO requested only the last tokens, slice accordingly.
+        if num_logits_to_keep > 0:
+            lm_logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        else:
+            lm_logits = self.lm_head(hidden_states)
+
+        # lm_logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+            cross_attentions=transformer_outputs.cross_attentions,
+        )
+
+    # def forward(self, batch):
+    #     return self.model(batch)
 
     def generate_im2svg(self, batch, **kwargs):
         return self.model.generate_im2svg(batch, **kwargs)
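
Note: the new forward assumes `self.lm_head` exists on the wrapper (it is not defined in this diff) and reaches through `self.model.svg_transformer.transformer` for the backbone, exposing the model as a standard Hugging Face causal LM for trainers such as GRPO. A hedged usage sketch (all variables are placeholders):

    # Assumes `model` is a loaded StarVectorForCausalLM and `input_ids`
    # comes from the processor above.
    outputs = model(input_ids=input_ids, labels=input_ids)  # labels shift internally
    print(outputs.loss, outputs.logits.shape)

    # num_logits_to_keep > 0 runs the LM head only over the trailing
    # positions, trimming memory during RL-style scoring:
    tail_logits = model(input_ids=input_ids, num_logits_to_keep=16).logits
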