from typing import Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode, pad
from transformers import AutoProcessor, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_sequence_utils import BatchFeature
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import CausalLMOutputWithCrossAttentions
from transformers.processing_utils import ProcessorMixin


class SimpleStarVectorProcessor(ProcessorMixin):
    attributes = ["tokenizer"]  # Only include tokenizer in attributes
    valid_kwargs = ["size", "mean", "std"]  # Add other parameters as valid kwargs
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self,
                 tokenizer=None,  # Make tokenizer the first argument
                 size=224,
                 mean=None,
                 std=None,
                 **kwargs,
                 ):
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        # Store these as instance variables
        self.mean = mean
        self.std = std
        self.size = size
        self.normalize = transforms.Normalize(mean=mean, std=std)

        self.transform = transforms.Compose([
            transforms.Lambda(lambda img: img.convert("RGB") if img.mode == "RGBA" else img),
            transforms.Lambda(lambda img: self._pad_to_square(img)),
            transforms.Resize(size, interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ])

        # Initialize parent class with tokenizer
        super().__init__(tokenizer=tokenizer)

    def _pad_to_square(self, img):
        # Assumed implementation (the original only references this helper):
        # pad the shorter side with white so the image is square before resizing.
        width, height = img.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2
        pad_right = max_dim - width - pad_left
        pad_bottom = max_dim - height - pad_top
        return pad(img, [pad_left, pad_top, pad_right, pad_bottom], fill=255)

    def __call__(self, images=None, text=None, **kwargs) -> BatchFeature:
        """
        Process images and/or text inputs.

        Args:
            images: Optional image input(s)
            text: Optional text input(s)
            **kwargs: Additional arguments passed to the tokenizer
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        image_inputs = {}
        if images is not None:
            if isinstance(images, (list, tuple)):
                images_ = [self.transform(img) for img in images]
            else:
                images_ = self.transform(images)
            image_inputs = {"pixel_values": images_}

        text_inputs = {}
        if text is not None:
            text_inputs = self.tokenizer(text, **kwargs)

        return BatchFeature(data={**text_inputs, **image_inputs})


AutoProcessor.register(SimpleStarVectorProcessor, SimpleStarVectorProcessor)
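
# Illustrative usage sketch (not executed at import time). The tokenizer checkpoint and
# prompt below are placeholders for illustration only; they are not mandated by this module.
def _example_processor_usage():
    from PIL import Image
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-1b")  # assumed checkpoint
    processor = SimpleStarVectorProcessor(tokenizer=tokenizer, size=224)
    image = Image.new("RGB", (640, 480), "white")  # stand-in image
    batch = processor(images=image, text="<svg", return_tensors="pt")
    # batch["pixel_values"]: normalized 3x224x224 tensor
    # batch["input_ids"], batch["attention_mask"]: tokenizer outputs
    return batch
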
class StarVectorConfig(PretrainedConfig):
    model_type = "starvector"

    def __init__(
        self,
        starcoder_model_name: str = "bigcode/starcoderbase-1b",
        image_encoder_type: str = "clip",
        adapter_norm: str = "layer_norm",
        image_size: int = 224,
        max_length: int = 8192,
        max_length_train: int = 8192,
        use_flash_attn: bool = True,
        use_cache: bool = True,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 24,
        vocab_size: int = 49152,
        hidden_size: int = 2048,
        num_kv_heads: int = 4,
        torch_dtype: str = "bfloat16",
        **kwargs,
    ):
        kwargs["torch_dtype"] = torch_dtype
        self.starcoder_model_name = starcoder_model_name
        self.image_encoder_type = image_encoder_type
        self.adapter_norm = adapter_norm
        self.image_size = image_size
        self.max_length = max_length
        self.max_length_train = max_length_train
        self.use_flash_attn = use_flash_attn
        self.use_cache = use_cache
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_kv_heads = num_kv_heads
        super().__init__(**kwargs)


class StarVectorForCausalLM(PreTrainedModel):
    config_class = StarVectorConfig
    _no_split_modules = []

    def __init__(self, config: StarVectorConfig, **kwargs):
        super().__init__(config)
        starcoder_model_name = config.starcoder_model_name
        if "starcoder2" in starcoder_model_name:
            from starvector.model.models.starvector_v2 import StarVectorStarCoder2
            self.model = StarVectorStarCoder2(config=config, **kwargs)
        else:
            from starvector.model.models.starvector_v1 import StarVectorStarCoder
            self.model = StarVectorStarCoder(config=config, **kwargs)

    @property
    def supports_gradient_checkpointing(self):
        # If the underlying transformer (e.g., the one in StarCoderModel)
        # supports gradient checkpointing, delegate to it.
        if hasattr(self.model, "svg_transformer"):
            return getattr(self.model.svg_transformer, "supports_gradient_checkpointing", False)
        return False

    def gradient_checkpointing_enable(self):
        # Optionally, forward this call to the internal transformer.
        if hasattr(self.model, "svg_transformer") and hasattr(self.model.svg_transformer, "gradient_checkpointing_enable"):
            self.model.svg_transformer.gradient_checkpointing_enable()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model,
            i.e. you can set `labels = input_ids`. Indices are selected in
            `[-100, 0, ..., config.vocab_size]`. All labels set to `-100` are ignored (masked);
            the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model.svg_transformer.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # If GRPO requested only the last tokens, slice accordingly.
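        # Keeping only the last `num_logits_to_keep` positions avoids materializing the
        # full (batch, seq_len, vocab_size) logits tensor when only the final tokens need
        # to be scored, which bounds memory during GRPO-style scoring.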
        if num_logits_to_keep > 0:
            lm_logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        else:
            lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    # def forward(self, batch):
    #     return self.model(batch)

    def generate_im2svg(self, batch, **kwargs):
        return self.model.generate_im2svg(batch, **kwargs)

    def generate_im2text(self, batch, **kwargs):
        return self.model.generate_im2text(batch, **kwargs)

    def process_images(self, images):
        return self.model.image_encoder.process_images(images)
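
# Illustrative sketch (not part of the model API): building `labels` for the forward pass
# above. Padding positions are set to -100 so CrossEntropyLoss ignores them; the shift
# happens inside `forward`, so labels can simply mirror `input_ids`.
def _example_labels_from_inputs(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100  # ignored by the loss (see docstring of `forward`)
    return labels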