|
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode, pad
from transformers import AutoProcessor
from transformers.feature_extraction_sequence_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
|
|
|
class SimpleStarVectorProcessor(ProcessorMixin):
    """Minimal processor pairing a tokenizer with a torchvision image pipeline.

    Images are converted to RGB, padded with white to a square canvas,
    resized with bicubic interpolation, turned into tensors and normalized
    with ``mean`` / ``std``. Text is forwarded to the wrapped tokenizer.
    """

    attributes = ["tokenizer"]
    valid_kwargs = ["size", "mean", "std"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer=None,
        size=224,
        mean=None,
        std=None,
        **kwargs,
    ):
        """Build the image transform and register the tokenizer.

        Args:
            tokenizer: Tokenizer used by ``__call__`` for text inputs.
            size: Target size handed to ``transforms.Resize``.
            mean: Per-channel normalization mean; a built-in 3-channel
                default is used when ``None``.
            std: Per-channel normalization std; a built-in 3-channel
                default is used when ``None``.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.mean = mean
        self.std = std
        self.size = size

        self.normalize = transforms.Normalize(mean=mean, std=std)

        # Plain callables (static/bound methods) are used instead of lambdas
        # so the composed transform stays picklable.
        self.transform = transforms.Compose([
            transforms.Lambda(self._to_rgb),
            transforms.Lambda(self._pad_to_square),
            transforms.Resize(size, interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ])

        super().__init__(tokenizer=tokenizer)

    def __call__(self, images=None, text=None, **kwargs) -> BatchFeature:
        """Process images and/or text inputs.

        Args:
            images: A single image, or a list/tuple of images. A sequence
                yields a list of per-image tensors (not a stacked batch).
            text: Text input(s) forwarded to the tokenizer.
            **kwargs: Extra keyword arguments passed to the tokenizer.

        Returns:
            A ``BatchFeature`` holding the tokenizer outputs (when ``text``
            is given) and ``pixel_values`` (when ``images`` is given).

        Raises:
            ValueError: If neither ``images`` nor ``text`` is provided.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        image_inputs = {}
        if images is not None:
            if isinstance(images, (list, tuple)):
                pixel_values = [self.transform(img) for img in images]
            else:
                pixel_values = self.transform(images)
            image_inputs = {"pixel_values": pixel_values}

        text_inputs = {}
        if text is not None:
            text_inputs = self.tokenizer(text, **kwargs)

        # Image entries come last so `pixel_values` wins on a key clash.
        return BatchFeature(data={**text_inputs, **image_inputs})

    @staticmethod
    def _to_rgb(img):
        """Return ``img`` as an RGB image.

        Every non-RGB mode is converted (not only RGBA) because the
        normalization step is configured with three channels.
        NOTE(review): for RGBA inputs the alpha channel is dropped rather
        than composited onto a background - confirm this is intended.
        """
        if img.mode == "RGB":
            return img
        return img.convert("RGB")

    def _pad_to_square(self, img):
        """Pad ``img`` with white (255) so width and height become equal.

        The image is centred; when the missing amount is odd, the extra
        pixel goes to the right/bottom edge.
        """
        width, height = img.size
        side = max(width, height)

        left = (side - width) // 2
        top = (side - height) // 2
        right = side - width - left
        bottom = side - height - top

        # torchvision's 4-element padding order is [left, top, right, bottom].
        return pad(img, [left, top, right, bottom], fill=255)
|
|
|
|
|
# Make the processor discoverable through the AutoProcessor registry.
# NOTE(review): the first argument of `AutoProcessor.register` is documented
# as a config class; the processor class is passed for both slots here -
# confirm this is the intended registration key.
AutoProcessor.register(SimpleStarVectorProcessor, SimpleStarVectorProcessor)
|
|