File size: 2,938 Bytes
9d9ac6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dce64a
 
 
 
 
 
 
 
 
 
81663e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode, pad
from transformers import AutoProcessor
from transformers.feature_extraction_sequence_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin

class SimpleStarVectorProcessor(ProcessorMixin):
    """Processor combining a tokenizer with a fixed torchvision image pipeline.

    Images go through: RGB conversion -> white padding to a square ->
    bicubic resize -> ``ToTensor`` -> channel-wise normalization.
    Text is forwarded unchanged to the wrapped tokenizer.
    """

    attributes = ["tokenizer"]  # Only include tokenizer in attributes
    valid_kwargs = ["size", "mean", "std"]  # Add other parameters as valid kwargs
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self,
                 tokenizer=None,  # Make tokenizer the first argument
                 size=224,
                 mean=None,
                 std=None,
                 **kwargs,
                 ):
        """Build the image transform and register the tokenizer.

        Args:
            tokenizer: Tokenizer instance handed to ``ProcessorMixin``.
            size: Target size passed to ``transforms.Resize``.
            mean: Per-channel normalization mean; defaults to a 3-channel
                preset when ``None``.
            std: Per-channel normalization std; defaults to a 3-channel
                preset when ``None``.
            **kwargs: Accepted for call compatibility; not used here.
        """
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        # Store these as instance variables
        self.mean = mean
        self.std = std
        self.size = size

        self.normalize = transforms.Normalize(mean=mean, std=std)

        self.transform = transforms.Compose([
            # Convert every non-RGB mode (RGBA, L, LA, P, CMYK, ...) so that
            # the tensor always has the 3 channels the normalization
            # statistics expect. Previously only RGBA was converted and other
            # modes reached Normalize with a mismatched channel count.
            transforms.Lambda(lambda img: img if img.mode == "RGB" else img.convert("RGB")),
            transforms.Lambda(lambda img: self._pad_to_square(img)),
            transforms.Resize(size, interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ])

        # Initialize parent class with tokenizer
        super().__init__(tokenizer=tokenizer)

    def __call__(self, images=None, text=None, **kwargs) -> BatchFeature:
        """
        Process images and/or text inputs.

        Args:
            images: Optional image input(s). A list or tuple is transformed
                element by element; any other value is treated as one image.
            text: Optional text input(s), passed to the tokenizer.
            **kwargs: Additional arguments forwarded to the tokenizer call.

        Returns:
            A ``BatchFeature`` holding the tokenizer outputs (if ``text`` was
            given) and ``"pixel_values"`` (if ``images`` was given). For a
            list/tuple input ``"pixel_values"`` is a list of per-image
            tensors; for a single image it is that image's tensor.

        Raises:
            ValueError: If both ``images`` and ``text`` are ``None``.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        image_inputs = {}
        if images is not None:
            if isinstance(images, (list, tuple)):
                pixel_values = [self.transform(img) for img in images]
            else:
                pixel_values = self.transform(images)
            image_inputs = {"pixel_values": pixel_values}

        text_inputs = {}
        if text is not None:
            text_inputs = self.tokenizer(text, **kwargs)

        # Image entries are merged last, so they win on a key collision.
        return BatchFeature(data={**text_inputs, **image_inputs})

    def _pad_to_square(self, img):
        """Pad ``img`` with white so that width equals height.

        The image is centred; when the missing amount is odd, the extra
        pixel goes to the right/bottom edge.
        """
        width, height = img.size
        max_dim = max(width, height)
        extra_w = max_dim - width
        extra_h = max_dim - height
        left, top = extra_w // 2, extra_h // 2
        # torchvision's 4-element padding order: left, top, right, bottom
        padding = [left, top, extra_w - left, extra_h - top]
        return pad(img, padding, fill=255)  # Assuming white padding


# Make the processor discoverable through AutoProcessor as soon as this
# module is imported. `AutoProcessor` must be imported at the top of the
# file (`from transformers import AutoProcessor`); without it this line
# raises NameError at import time.
# NOTE(review): AutoProcessor.register documents its first argument as a
# config class; here the processor class is used as its own key -- confirm
# this is intended.
AutoProcessor.register(SimpleStarVectorProcessor, SimpleStarVectorProcessor)