|
from transformers import ( |
|
PretrainedConfig, |
|
PreTrainedModel |
|
) |
|
from torch.nn import CrossEntropyLoss |
|
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import CausalLMOutputWithCrossAttentions |
|
from typing import Optional, Tuple, Union |
|
import torch |
|
|
|
from transformers.processing_utils import ProcessorMixin |
|
from torchvision import transforms |
|
from torchvision.transforms.functional import InterpolationMode, pad |
|
from transformers.feature_extraction_sequence_utils import BatchFeature |
|
from transformers import AutoProcessor |
|
|
|
class SimpleStarVectorProcessor(ProcessorMixin):
    """Minimal processor pairing a tokenizer with a torchvision image pipeline.

    Images are converted to RGB, padded to a square canvas, resized to ``size``
    with bicubic interpolation, converted to a tensor and normalized with
    ``mean``/``std``. Text is forwarded to the wrapped tokenizer.
    """

    attributes = ["tokenizer"]
    valid_kwargs = ["size", "mean", "std"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer=None,
        size=224,
        mean=None,
        std=None,
        **kwargs,
    ):
        """Build the image transform and register ``tokenizer`` with the mixin.

        Args:
            tokenizer: Tokenizer used for text inputs.
            size: Value passed to ``transforms.Resize``. Images are padded to a
                square first, so an int yields a ``size x size`` image.
            mean: Per-channel normalization mean; defaults to the OpenAI CLIP
                normalization constants.
            std: Per-channel normalization std; defaults to the OpenAI CLIP
                normalization constants.
            **kwargs: Ignored.
        """
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.mean = mean
        self.std = std
        self.size = size

        self.normalize = transforms.Normalize(mean=mean, std=std)

        self.transform = transforms.Compose([
            # Convert every non-RGB mode (RGBA, L, LA, P, CMYK, ...) so the
            # tensor always has 3 channels, matching the 3-element default
            # mean/std; otherwise e.g. grayscale inputs fail inside Normalize.
            transforms.Lambda(lambda img: img if img.mode == "RGB" else img.convert("RGB")),
            # Pad before resizing: Resize(int) only matches the shorter edge,
            # so padding is what makes the final image square.
            transforms.Lambda(lambda img: self._pad_to_square(img)),
            transforms.Resize(size, interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ])

        super().__init__(tokenizer=tokenizer)

    def __call__(self, images=None, text=None, **kwargs) -> BatchFeature:
        """Process images and/or text inputs.

        Args:
            images: A single image or a list/tuple of images. Images must be
                PIL-like (the pipeline uses ``.mode``, ``.convert`` and ``.size``).
            text: Text input(s) forwarded to the tokenizer.
            **kwargs: Extra keyword arguments forwarded to the tokenizer only.

        Returns:
            A ``BatchFeature`` holding the tokenizer outputs (when ``text`` is
            given) and ``pixel_values`` (when ``images`` is given): one tensor
            for a single image, or a list of per-image tensors for a list/tuple.

        Raises:
            ValueError: If neither ``images`` nor ``text`` is provided.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        image_inputs = {}
        if images is not None:
            if isinstance(images, (list, tuple)):
                images_ = [self.transform(img) for img in images]
            else:
                images_ = self.transform(images)
            image_inputs = {"pixel_values": images_}

        text_inputs = {}
        if text is not None:
            text_inputs = self.tokenizer(text, **kwargs)

        return BatchFeature(data={**text_inputs, **image_inputs})

    def _pad_to_square(self, img):
        """Center ``img`` on a square canvas whose side is its longer edge.

        Padding is filled with 255 (white for 8-bit RGB), presumably to match
        the white background of rasterized SVGs (confirm). When the edge
        difference is odd, the extra pixel goes on the right/bottom.
        """
        width, height = img.size
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        right = side - width - left
        bottom = side - height - top
        # torchvision's 4-element padding order is [left, top, right, bottom].
        return pad(img, [left, top, right, bottom], fill=255)
|
|
|
# Register SimpleStarVectorProcessor with AutoProcessor when this module is imported.
# NOTE(review): per the transformers API, AutoProcessor.register takes
# (config_class, processor_class). Here the processor class is also used as the
# config key, so a lookup keyed on a model config type presumably will not find
# it. Registering against StarVectorConfig would require moving this call below
# that class definition; confirm the intended lookup path before changing.
AutoProcessor.register(SimpleStarVectorProcessor, SimpleStarVectorProcessor)
|
|
|
|
|
class StarVectorConfig(PretrainedConfig):
    """Configuration for StarVector models.

    Holds the StarCoder backbone name (``starcoder_model_name``, which selects
    the implementation in ``StarVectorForCausalLM``), image-encoder/adapter
    settings, sequence-length limits and transformer dimensions. Any extra
    keyword arguments, plus ``torch_dtype``, are passed to ``PretrainedConfig``.
    """

    model_type = "starvector"

    def __init__(
        self,
        starcoder_model_name: str = "bigcode/starcoderbase-1b",
        image_encoder_type: str = "clip",
        adapter_norm: str = "layer_norm",
        image_size: int = 224,
        max_length: int = 8192,
        max_length_train: int = 8192,
        use_flash_attn: bool = True,
        use_cache: bool = True,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 24,
        vocab_size: int = 49152,
        hidden_size: int = 2048,
        num_kv_heads: int = 4,
        torch_dtype: str = "bfloat16",
        **kwargs,
    ):
        """Store the StarVector-specific settings, then initialize the base config.

        ``torch_dtype`` is not assigned here; it is injected into ``kwargs``
        (overriding any ``torch_dtype`` already present there) and handled by
        ``PretrainedConfig.__init__``.
        """
        kwargs["torch_dtype"] = torch_dtype

        # Assigned in this exact order via setattr, so any __setattr__ logic on
        # the base class still runs for each field.
        model_settings = {
            "starcoder_model_name": starcoder_model_name,
            "image_encoder_type": image_encoder_type,
            "adapter_norm": adapter_norm,
            "image_size": image_size,
            "max_length": max_length,
            "max_length_train": max_length_train,
            "use_flash_attn": use_flash_attn,
            "use_cache": use_cache,
            "num_attention_heads": num_attention_heads,
            "num_hidden_layers": num_hidden_layers,
            "vocab_size": vocab_size,
            "hidden_size": hidden_size,
            "num_kv_heads": num_kv_heads,
        }
        for field_name, field_value in model_settings.items():
            setattr(self, field_name, field_value)

        super().__init__(**kwargs)
|
|
|
class StarVectorForCausalLM(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a StarVector implementation.

    The concrete model is chosen from ``config.starcoder_model_name`` and
    stored as ``self.model``. Generation and image preprocessing are
    delegated to it, and ``forward`` runs its ``svg_transformer``.
    """

    config_class = StarVectorConfig
    _no_split_modules = []

    def __init__(self, config: StarVectorConfig, **kwargs):
        """Instantiate the StarCoder2- or StarCoder-based StarVector model.

        Args:
            config: Model configuration. A ``starcoder_model_name`` containing
                ``"starcoder2"`` selects ``StarVectorStarCoder2``; anything else
                selects ``StarVectorStarCoder``.
            **kwargs: Forwarded to the selected model's constructor.
        """
        super().__init__(config)
        starcoder_model_name = config.starcoder_model_name
        # Imported lazily, so only the selected implementation's module is loaded.
        if 'starcoder2' in starcoder_model_name:
            from starvector.model.models.starvector_v2 import StarVectorStarCoder2
            self.model = StarVectorStarCoder2(config=config, **kwargs)
        else:
            from starvector.model.models.starvector_v1 import StarVectorStarCoder
            self.model = StarVectorStarCoder(config=config, **kwargs)

    @property
    def supports_gradient_checkpointing(self):
        """Report the inner SVG transformer's ``supports_gradient_checkpointing``.

        Returns False when ``self.model`` has no ``svg_transformer`` or the
        transformer does not define the attribute.
        """
        if hasattr(self.model, 'svg_transformer'):
            return getattr(self.model.svg_transformer, 'supports_gradient_checkpointing', False)
        return False

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing on the inner SVG transformer, if it supports it.

        Args:
            gradient_checkpointing_kwargs: Optional dict, mirroring the parameter
                of ``PreTrainedModel.gradient_checkpointing_enable`` that callers
                such as ``transformers.Trainer`` pass by keyword. Without it, such
                calls raise TypeError. It is forwarded to the inner transformer
                only when not None, so inner implementations whose
                ``gradient_checkpointing_enable`` takes no arguments keep working.

        Does nothing when ``self.model`` has no ``svg_transformer`` or that
        object has no ``gradient_checkpointing_enable`` method.
        """
        if hasattr(self.model, 'svg_transformer') and hasattr(self.model.svg_transformer, 'gradient_checkpointing_enable'):
            if gradient_checkpointing_kwargs is None:
                self.model.svg_transformer.gradient_checkpointing_enable()
            else:
                self.model.svg_transformer.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
                )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        Run the inner SVG transformer and project its first output to logits.

        Args:
            labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
                `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
                `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
            num_logits_to_keep (`int`, defaults to 0):
                If > 0, only the last `num_logits_to_keep` positions are projected to logits; 0 keeps all positions.
            All other arguments are forwarded unchanged to `self.model.svg_transformer.transformer`.

        Returns:
            A `CausalLMOutputWithCrossAttentions` when `return_dict` resolves to True (it falls back to
            `config.use_return_dict`); otherwise a tuple `(logits, *transformer_outputs[1:])`, prefixed with the
            loss when `labels` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model.svg_transformer.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # NOTE(review): `self.lm_head` is never assigned in this class; presumably it is
        # attached elsewhere before forward() is called. Confirm, otherwise this raises AttributeError.
        if num_logits_to_keep > 0:
            lm_logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        else:
            lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Next-token objective: position t predicts label t+1.
            # NOTE(review): labels are not truncated with num_logits_to_keep, so combining
            # labels with num_logits_to_keep > 0 yields mismatched lengths. Confirm callers never do this.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)

            # CrossEntropyLoss's default ignore_index (-100) implements the label masking described above.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    def generate_im2svg(self, batch, **kwargs):
        """Delegate image-to-SVG generation to ``self.model.generate_im2svg``."""
        return self.model.generate_im2svg(batch, **kwargs)

    def generate_im2text(self, batch, **kwargs):
        """Delegate image-to-text generation to ``self.model.generate_im2text``."""
        return self.model.generate_im2text(batch, **kwargs)

    def process_images(self, images):
        """Preprocess ``images`` with ``self.model.image_encoder.process_images``."""
        return self.model.image_encoder.process_images(images)
|
|
|
|