File size: 3,059 Bytes
40d80f1
 
 
41ad712
40d80f1
 
 
 
 
 
 
41ad712
 
40d80f1
 
41ad712
 
40d80f1
 
 
 
 
 
 
 
41ad712
40d80f1
 
 
 
 
 
 
 
 
 
 
 
 
 
b6039bb
 
 
 
40d80f1
0412630
 
 
 
 
 
 
 
 
 
 
 
 
 
40d80f1
 
 
 
41ad712
40d80f1
 
 
 
41ad712
40d80f1
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from transformers import ProcessorMixin, AutoProcessor
from transformers.models.auto.processing_auto import AutoProcessor
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import BatchEncoding
import json
import os

class FlamingoProcessor(ProcessorMixin):
    """
    Processor that pairs an image processor with a tokenizer for Flamingo-style inputs.

    Every text input is prefixed with ``"<image> "`` before tokenization. When the
    tokenizer returns an ``offset_mapping`` (``return_offsets_mapping=True``), the
    offsets are re-based so they index into the caller's original text rather than
    into the prefixed string.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    # Placeholder prepended to every text sample before tokenization.
    _IMAGE_PREFIX = "<image> "

    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)

    def __call__(self, text=None, images=None, **kwargs):
        """
        Tokenize text and/or preprocess images into a single ``BatchEncoding``.

        Args:
            text: A string, a list of strings, or a list of pre-split word lists
                (each word list is joined with single spaces). Every sample is
                prefixed with ``"<image> "`` before tokenization.
            images: Image input(s) forwarded to the image processor.
            **kwargs: Forwarded unchanged to BOTH the tokenizer and the image
                processor (e.g. ``return_tensors``).

        Returns:
            BatchEncoding with the tokenizer outputs plus the image-processor
            outputs. ``pixel_values`` keeps its name. Any other image key ``k``
            is stored as ``image_k``. If an ``offset_mapping`` is present, it is
            re-based onto the un-prefixed text: a flat list of ``(start, end)``
            tuples for a single sequence, or a list of such lists for a batch.

        Raises:
            ValueError: If both ``text`` and ``images`` are None.
        """
        if text is None and images is None:
            raise ValueError("You need to specify either text or images")

        encoding = {}

        if text is not None:
            text_encoding = self.tokenizer(self._add_image_prefix(text), **kwargs)
            if "offset_mapping" in text_encoding:
                text_encoding["offset_mapping"] = self._strip_prefix_offsets(
                    text_encoding["offset_mapping"]
                )
            encoding.update(text_encoding)

        if images is not None:
            image_encoding = self.image_processor(images, **kwargs)
            # Prefix non-pixel keys so they cannot collide with tokenizer keys.
            for key, value in image_encoding.items():
                encoding[key if key == "pixel_values" else f"image_{key}"] = value

        return BatchEncoding(encoding)

    def _add_image_prefix(self, text):
        """Prepend the image placeholder to a string, a list of strings, or a list of word lists."""
        if isinstance(text, str):
            return self._IMAGE_PREFIX + text
        # An empty batch has no first element to inspect; forward it as empty.
        if len(text) == 0:
            return []
        if isinstance(text[0], str):
            return [self._IMAGE_PREFIX + sample for sample in text]
        # Pre-split input: re-join each sample's words with single spaces.
        return [self._IMAGE_PREFIX + " ".join(words) for words in text]

    @staticmethod
    def _is_offset_pair(item):
        """Return True if *item* looks like one ``(start, end)`` integer pair."""
        return (
            isinstance(item, (tuple, list))
            and len(item) == 2
            and all(isinstance(value, int) for value in item)
        )

    @classmethod
    def _strip_prefix_offsets(cls, offset_mapping):
        """
        Re-base a tokenizer ``offset_mapping`` (single sequence or batch) onto the un-prefixed text.

        A flat list of pairs is treated as one sequence. A list of lists is
        treated as a batch and each row is re-based independently. Array/tensor
        input (anything with ``tolist``) is converted to nested Python lists. A
        single-row batch is returned flat, matching the previous behavior.
        Larger batches keep one list per row instead of dropping all but the
        first.
        """
        if not isinstance(offset_mapping, list):
            # Produced when return_tensors is set; assumed batch-first,
            # i.e. (batch, seq_len, 2) -- TODO confirm for every backend.
            rows = offset_mapping.tolist()
            if len(rows) == 1:
                return cls._shift_offsets(rows[0])
            return [cls._shift_offsets(row) for row in rows]
        if len(offset_mapping) == 0 or cls._is_offset_pair(offset_mapping[0]):
            return cls._shift_offsets(offset_mapping)
        return [cls._shift_offsets(row) for row in offset_mapping]

    @staticmethod
    def _shift_offsets(pairs):
        """
        Re-base one sequence's ``(start, end)`` pairs onto the un-prefixed text.

        The shift is the end offset of the first token that covers any
        characters, normally the ``<image>`` token. Leading ``(0, 0)`` entries,
        such as a BOS token or left padding, are skipped when locating it, so
        they no longer zero out the shift. Every token starting at position 0
        (special tokens, padding, and the ``<image>`` token itself) maps to
        ``(0, 0)``.

        NOTE(review): assumes ``<image>`` is tokenized as a single token; if it
        is split into pieces the shift will be wrong.
        """
        pairs = [tuple(pair) for pair in pairs]
        true_offset = next((end for _, end in pairs if end > 0), 0)
        return [
            (0, 0) if start == 0 else (start - true_offset, end - true_offset)
            for start, end in pairs
        ]

    def batch_decode(self, *args, **kwargs):
        """
        Delegate batch decoding to the tokenizer.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        Delegate decoding to the tokenizer.
        """
        return self.tokenizer.decode(*args, **kwargs)