from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode, pad
from transformers import AutoProcessor
from transformers.feature_extraction_sequence_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin


class SimpleStarVectorProcessor(ProcessorMixin):
    """Processor that pairs a tokenizer with a fixed torchvision image pipeline.

    Images are converted to RGB, padded to a square with white, resized to
    ``size`` with bicubic interpolation, converted to a tensor and normalized
    with ``mean``/``std``. Text is forwarded to the wrapped tokenizer.
    """

    # Only the tokenizer is registered as a ProcessorMixin attribute; the image
    # pipeline is built in ``__init__`` from plain parameters.
    attributes = ["tokenizer"]
    valid_kwargs = ["size", "mean", "std"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer=None,  # Kept as the first argument
        size=224,
        mean=None,
        std=None,
        **kwargs,
    ):
        """Build the image transform and register the tokenizer.

        Args:
            tokenizer: Tokenizer used for text inputs.
            size: Target size passed to ``transforms.Resize``.
            mean: Per-channel normalization mean; a 3-channel default is used
                when ``None``.
            std: Per-channel normalization std; a 3-channel default is used
                when ``None``.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        # NOTE(review): these defaults look like the CLIP normalization
        # statistics -- confirm against the model's training setup.
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        # Store these as instance variables
        self.mean = mean
        self.std = std
        self.size = size

        self.normalize = transforms.Normalize(mean=mean, std=std)
        self.transform = transforms.Compose([
            # Convert every non-RGB mode (RGBA, L, P, CMYK, ...) so the tensor
            # always has the three channels that ``mean``/``std`` expect.
            transforms.Lambda(
                lambda img: img.convert("RGB") if img.mode != "RGB" else img
            ),
            transforms.Lambda(lambda img: self._pad_to_square(img)),
            transforms.Resize(size, interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ])

        # Initialize parent class with tokenizer
        super().__init__(tokenizer=tokenizer)

    def __call__(self, images=None, text=None, **kwargs) -> BatchFeature:
        """Process images and/or text inputs.

        Args:
            images: Optional image or list/tuple of images. A single image
                yields one tensor; a list/tuple yields a list of tensors
                (one per image, not stacked).
            text: Optional text input(s), forwarded to the tokenizer.
            **kwargs: Additional arguments forwarded to the tokenizer call.

        Returns:
            A ``BatchFeature`` holding the tokenizer outputs (if ``text`` was
            given) and ``pixel_values`` (if ``images`` was given).

        Raises:
            ValueError: If both ``images`` and ``text`` are ``None``.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        image_inputs = {}
        if images is not None:
            if isinstance(images, (list, tuple)):
                images_ = [self.transform(img) for img in images]
            else:
                images_ = self.transform(images)
            image_inputs = {"pixel_values": images_}

        text_inputs = {}
        if text is not None:
            text_inputs = self.tokenizer(text, **kwargs)

        # Image entries are merged last, so they win on a key collision.
        return BatchFeature(data={**text_inputs, **image_inputs})

    def _pad_to_square(self, img):
        """Pad ``img`` with white so that width equals height.

        The padding is split between both sides of the shorter dimension; when
        the difference is odd, the extra pixel goes to the right/bottom.
        """
        width, height = img.size
        max_dim = max(width, height)
        # Order expected by ``pad``: left, top, right, bottom.
        padding = [(max_dim - width) // 2, (max_dim - height) // 2]
        padding += [max_dim - width - padding[0], max_dim - height - padding[1]]
        return pad(img, padding, fill=255)  # Assuming white padding


# NOTE(review): ``AutoProcessor.register`` documents its first argument as a
# config class; the processor class is passed for both arguments here --
# confirm this is intended.
AutoProcessor.register(SimpleStarVectorProcessor, SimpleStarVectorProcessor)