"""Tokenization, image processing, and processor classes for Ernie_45T_VL."""
|
|
|
import copy |
|
import io |
|
import os |
|
import re |
|
import math |
|
import random |
|
import requests |
|
import base64 |
|
import datetime |
|
import hashlib |
|
import threading |
|
import uuid |
|
import decord |
|
from shutil import copyfile |
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
import numpy as np |
|
import torch |
|
from PIL import Image, ImageDraw, ImageFont |
|
from PIL.ExifTags import TAGS |
|
from collections import defaultdict |
|
|
from pathlib import Path |
|
from tempfile import NamedTemporaryFile as ntf |
|
|
|
# moviepy >= 2.0 removed the "editor" submodule; fall back to the top-level package.
try:
    import moviepy.editor as mp
except ImportError:
    import moviepy as mp
|
|
|
import sentencepiece as spm |
|
from transformers.tokenization_utils import PreTrainedTokenizer |
|
from transformers.tokenization_utils_base import ( |
|
PaddingStrategy, |
|
TextInput, |
|
) |
|
from transformers.utils import TensorType, logging
|
from transformers.video_utils import VideoInput |
|
from transformers.processing_utils import ProcessorMixin |
|
from transformers.feature_extraction_utils import BatchFeature |
|
from transformers.image_processing_utils import BaseImageProcessor
|
from transformers.image_transforms import ( |
|
convert_to_rgb, |
|
normalize, |
|
rescale, |
|
resize, |
|
to_channel_dimension_format, |
|
) |
|
from transformers.image_utils import ( |
|
OPENAI_CLIP_MEAN, |
|
OPENAI_CLIP_STD, |
|
ChannelDimension, |
|
ImageInput, |
|
PILImageResampling, |
|
get_image_size, |
|
infer_channel_dimension_format, |
|
is_valid_image, |
|
make_list_of_images, |
|
to_numpy_array, |
|
valid_images, |
|
) |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
def round_by_factor(number: int, factor: int) -> int: |
|
"""Returns the closest integer to 'number' that is divisible by 'factor'.""" |
|
return round(number / factor) * factor |
|
|
|
|
|
def ceil_by_factor(number: int, factor: int) -> int: |
|
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" |
|
return math.ceil(number / factor) * factor |
|
|
|
|
|
def floor_by_factor(number: int, factor: int) -> int: |
|
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" |
|
return math.floor(number / factor) * factor |
|
|
|
|
|
def smart_resize( |
|
height: int, |
|
width: int, |
|
factor: int = 28, |
|
min_pixels: int = 4 * 28 * 28, |
|
max_pixels: int = 16384 * 28 * 28, |
|
): |
|
""" |
|
Rescales the image so that the following conditions are met: |
|
|
|
1. Both dimensions (height and width) are divisible by 'factor'. |
|
|
|
2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. |
|
|
|
3. The aspect ratio of the image is maintained as closely as possible. |
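
    Example (illustrative): smart_resize(480, 640) -> (476, 644) with the default factor of 28.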
|
""" |
|
MAX_RATIO = 200 |
|
if max(height, width) / min(height, width) > MAX_RATIO: |
|
if height > width: |
|
new_width = max(factor, round_by_factor(width, factor)) |
|
new_height = floor_by_factor(new_width * MAX_RATIO, factor) |
|
else: |
|
new_height = max(factor, round_by_factor(height, factor)) |
|
new_width = floor_by_factor(new_height * MAX_RATIO, factor) |
|
|
|
        logger.info(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, "
            f"got {max(height, width) / min(height, width)}, "
            f"resize to {max(new_height, new_width) / min(new_height, new_width)}"
        )
|
|
|
height = new_height |
|
width = new_width |
|
|
|
h_bar = max(factor, round_by_factor(height, factor)) |
|
w_bar = max(factor, round_by_factor(width, factor)) |
|
if h_bar * w_bar > max_pixels: |
|
beta = math.sqrt((height * width) / max_pixels) |
|
h_bar = floor_by_factor(height / beta, factor) |
|
w_bar = floor_by_factor(width / beta, factor) |
|
elif h_bar * w_bar < min_pixels: |
|
beta = math.sqrt(min_pixels / (height * width)) |
|
h_bar = ceil_by_factor(height * beta, factor) |
|
w_bar = ceil_by_factor(width * beta, factor) |
|
|
|
if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels: |
|
raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}") |
|
|
|
return h_bar, w_bar |
|
|
|
|
|
def is_scaled_image(image: np.ndarray) -> bool: |
|
""" |
|
Checks to see whether the pixel values have already been rescaled to [0, 1]. |
|
""" |
|
if image.dtype == np.uint8: |
|
return False |
|
|
|
|
|
return np.min(image) >= 0 and np.max(image) <= 1 |
|
|
|
|
|
def make_batched_images(images) -> List[List[ImageInput]]: |
|
""" |
|
Accepts images in list or nested list format, and makes a list of images for preprocessing. |
|
|
|
Args: |
|
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): |
|
The input image. |
|
|
|
Returns: |
|
list: A list of images. |
|
""" |
|
if ( |
|
isinstance(images, (list, tuple)) |
|
and isinstance(images[0], (list, tuple)) |
|
and is_valid_image(images[0][0]) |
|
): |
|
return [img for img_list in images for img in img_list] |
|
|
|
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): |
|
return images |
|
|
|
elif is_valid_image(images): |
|
return [images] |
|
|
|
raise ValueError(f"Could not make batched images from {images}") |
|
|
|
|
|
|
|
def make_batched_videos(videos) -> List[VideoInput]:
    """Accepts videos as a nested list of frames, a list of frames, or 4D arrays, and returns a list of videos (each a list of frames)."""
|
if ( |
|
isinstance(videos, (list, tuple)) |
|
and isinstance(videos[0], (list, tuple)) |
|
and is_valid_image(videos[0][0]) |
|
): |
|
return videos |
|
|
|
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): |
|
if isinstance(videos[0], Image.Image): |
|
return [videos] |
|
elif len(videos[0].shape) == 4: |
|
return [list(video) for video in videos] |
|
|
|
elif is_valid_image(videos) and len(videos.shape) == 4: |
|
return [list(videos)] |
|
|
|
raise ValueError(f"Could not make batched video from {videos}") |
|
|
|
|
|
class Ernie_45T_VLImageProcessor(BaseImageProcessor): |
|
r""" |
|
    Constructs an adaptive image processor that dynamically resizes images based on the original images.
|
|
|
Args: |
|
do_resize (`bool`, *optional*, defaults to `True`): |
|
Whether to resize the image's (height, width) dimensions. |
|
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): |
|
Resampling filter to use when resizing the image. |
|
do_rescale (`bool`, *optional*, defaults to `True`): |
|
Whether to rescale the image by the specified scale `rescale_factor`. |
|
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): |
|
Scale factor to use if rescaling the image. |
|
do_normalize (`bool`, *optional*, defaults to `True`): |
|
Whether to normalize the image. |
|
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): |
|
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. |
|
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): |
|
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel |
|
in the image. |
|
do_convert_rgb (`bool`, *optional*, defaults to `True`): |
|
Whether to convert the image to RGB. |
|
min_pixels (`int`, *optional*, defaults to `56 * 56`): |
|
The min pixels of the image to resize the image. |
|
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`): |
|
The max pixels of the image to resize the image. |
|
patch_size (`int`, *optional*, defaults to 14): |
|
            The spatial patch size of the vision encoder.
|
temporal_conv_size (`int`, *optional*, defaults to 2): |
|
The temporal conv size in resampler. |
|
merge_size (`int`, *optional*, defaults to 2): |
|
The merge size of the vision encoder to llm encoder. |
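
    Example (illustrative, assumes a PIL image `img`):
        feats = Ernie_45T_VLImageProcessor().preprocess(images=img, return_tensors="np")
        # feats["pixel_values"]: (num_patches, channel * patch_size**2)
        # feats["image_grid_thw"]: one (t, h, w) grid per image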
|
""" |
|
|
|
model_input_names = [ |
|
"pixel_values", |
|
"image_grid_thw", |
|
"pixel_values_videos", |
|
"video_grid_thw", |
|
] |
|
|
|
def __init__( |
|
self, |
|
do_resize: bool = True, |
|
resample: PILImageResampling = PILImageResampling.BICUBIC, |
|
do_rescale: bool = True, |
|
rescale_factor: Union[float, List[float]] = 1 / 255, |
|
do_normalize: bool = True, |
|
image_mean: Optional[Union[float, List[float]]] = None, |
|
image_std: Optional[Union[float, List[float]]] = None, |
|
do_convert_rgb: bool = True, |
|
min_pixels: int = 56 * 56, |
|
max_pixels: int = 28 * 28 * 1280, |
|
patch_size: int = 14, |
|
temporal_conv_size: int = 2, |
|
merge_size: int = 2, |
|
**kwargs, |
|
) -> None: |
|
"""init""" |
|
super().__init__(**kwargs) |
|
self.do_resize = do_resize |
|
self.resample = resample |
|
self.do_rescale = do_rescale |
|
self.rescale_factor = rescale_factor |
|
self.do_normalize = do_normalize |
|
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN |
|
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD |
|
self.min_pixels = min_pixels |
|
self.max_pixels = max_pixels |
|
self.patch_size = patch_size |
|
self.temporal_conv_size = temporal_conv_size |
|
self.merge_size = merge_size |
|
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels} |
|
self.do_convert_rgb = do_convert_rgb |
|
|
|
def set_pixels(self, min_pixels=None, max_pixels=None, msg=""): |
|
"""set_pixels""" |
|
if min_pixels is not None: |
|
assert ( |
|
isinstance(min_pixels, int) and min_pixels >= 0 |
|
            ), "min_pixels must be a non-negative int"
|
logger.info( |
|
f"{msg} Ernie_45T_VLImageProcessor set min_pixels = {min_pixels}" |
|
) |
|
self.min_pixels = min_pixels |
|
self.size["min_pixels"] = int(min_pixels) |
|
if max_pixels is not None: |
|
assert ( |
|
isinstance(max_pixels, int) and max_pixels > 0 |
|
            ), "max_pixels must be a positive int"
|
logger.info( |
|
f"{msg} Ernie_45T_VLImageProcessor set max_pixels = {max_pixels}" |
|
) |
|
self.max_pixels = max_pixels |
|
self.size["max_pixels"] = int(max_pixels) |
|
|
|
    def get_smarted_resize(self, height, width, min_pixels=None, max_pixels=None):
        """Return the smart-resized (height, width) in pixels and the corresponding (h, w) grid in patches."""
|
actual_min_pixels = min_pixels if min_pixels is not None else self.min_pixels |
|
actual_max_pixels = max_pixels if max_pixels is not None else self.max_pixels |
|
resized_height, resized_width = smart_resize( |
|
height, |
|
width, |
|
factor=self.patch_size * self.merge_size, |
|
min_pixels=actual_min_pixels, |
|
max_pixels=actual_max_pixels, |
|
) |
|
return (resized_height, resized_width), ( |
|
resized_height // self.patch_size, |
|
resized_width // self.patch_size, |
|
) |
|
|
|
def _preprocess( |
|
self, |
|
images: Union[ImageInput, VideoInput], |
|
do_resize: bool = True, |
|
resample: PILImageResampling = None, |
|
do_rescale: bool = True, |
|
rescale_factor: float = 1 / 255, |
|
do_normalize: bool = True, |
|
image_mean: Optional[Union[float, List[float]]] = None, |
|
image_std: Optional[Union[float, List[float]]] = None, |
|
do_convert_rgb: bool = False, |
|
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, |
|
input_data_format: Optional[Union[str, ChannelDimension]] = None, |
|
predetermined_grid_thw=None, |
|
): |
|
""" |
|
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`. |
|
|
|
Args: |
|
images (`ImageInput` or `VideoInput`): |
|
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. |
|
If pixel values range from 0 to 1, set `do_rescale=False`. |
|
do_resize (`bool`, *optional*, defaults to `self.do_resize`): |
|
Whether to resize the image. |
|
resample (`PILImageResampling`, *optional*, defaults to `self.resample`): |
|
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums. |
|
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): |
|
Whether to rescale the image. |
|
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): |
|
Scale factor to use if rescaling the image. |
|
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): |
|
Whether to normalize the image. |
|
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): |
|
Mean to use if normalizing the image. |
|
Can be a float or a list of floats corresponding to the number of channels in the image. |
|
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): |
|
Standard deviation to use if normalizing the image. |
|
Can be a float or a list of floats corresponding to the number of channels in the image. |
|
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): |
|
Whether to convert the image to RGB. |
|
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): |
|
The channel dimension format for the output image. Can be one of: |
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. |
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. |
|
- Unset: Use the channel dimension format of the input image. |
|
input_data_format (`ChannelDimension` or `str`, *optional*): |
|
The channel dimension format for the input image. Can be one of: |
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. |
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. |
|
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. |
|
|
""" |
|
images = make_list_of_images(images) |
|
|
|
if do_convert_rgb: |
|
images = [convert_to_rgb(image) for image in images] |
|
|
|
|
|
images = [to_numpy_array(image) for image in images] |
|
|
|
if is_scaled_image(images[0]) and do_rescale: |
|
logger.warning_once( |
|
"It looks like you are trying to rescale already rescaled images. If the input" |
|
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." |
|
) |
|
if input_data_format is None: |
|
|
|
input_data_format = infer_channel_dimension_format(images[0]) |
|
|
|
height, width = get_image_size(images[0], channel_dim=input_data_format) |
|
resized_height, resized_width = height, width |
|
processed_images = [] |
|
|
|
if predetermined_grid_thw is not None: |
|
assert len(predetermined_grid_thw) == len( |
|
images |
|
), f"len(predetermined_grid_thw) {len(predetermined_grid_thw)} == len(images) {len(images)}" |
|
|
|
for img_idx, image in enumerate(images): |
|
if do_resize: |
|
if predetermined_grid_thw is not None: |
|
(resized_height, resized_width) = predetermined_grid_thw[img_idx] |
|
resized_height *= self.patch_size |
|
resized_width *= self.patch_size |
|
else: |
|
resized_height, resized_width = smart_resize( |
|
height, |
|
width, |
|
factor=self.patch_size * self.merge_size, |
|
min_pixels=self.min_pixels, |
|
max_pixels=self.max_pixels, |
|
) |
|
|
|
image = resize( |
|
image, |
|
size=(resized_height, resized_width), |
|
resample=resample, |
|
data_format=input_data_format, |
|
) |
|
if do_rescale: |
|
image = rescale( |
|
image, scale=rescale_factor, data_format=input_data_format |
|
) |
|
|
|
if do_normalize: |
|
image = normalize( |
|
image=image, |
|
mean=image_mean, |
|
std=image_std, |
|
data_format=input_data_format, |
|
) |
|
|
|
image = to_channel_dimension_format( |
|
image, data_format, input_channel_dim=input_data_format |
|
) |
|
|
|
processed_images.append(image) |
|
patches = np.array(processed_images) |
|
if data_format == ChannelDimension.LAST: |
|
patches = patches.transpose([0, 3, 1, 2]) |
|
|
|
channel = patches.shape[1] |
|
grid_t = patches.shape[0] |
|
grid_h, grid_w = ( |
|
resized_height // self.patch_size, |
|
resized_width // self.patch_size, |
|
) |
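        # Rearrange the stacked frames into flattened patches: split the spatial grid into
        # merge_size x merge_size blocks of patch_size pixels so that patches merged by the
        # vision-to-LLM adapter stay contiguous after flattening.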
|
patches = patches.reshape( |
|
[ |
|
grid_t, |
|
channel, |
|
grid_h // self.merge_size, |
|
self.merge_size, |
|
self.patch_size, |
|
grid_w // self.merge_size, |
|
self.merge_size, |
|
self.patch_size, |
|
] |
|
) |
|
|
|
patches = patches.transpose([0, 2, 5, 3, 6, 1, 4, 7]) |
|
|
|
flatten_patches = patches.reshape( |
|
[grid_t * grid_h * grid_w, channel * self.patch_size * self.patch_size] |
|
) |
|
|
|
return flatten_patches, (grid_t, grid_h, grid_w) |
|
|
|
def preprocess( |
|
self, |
|
images: ImageInput, |
|
videos: VideoInput = None, |
|
do_resize: bool = True, |
|
size: Optional[Union[int, List[int]]] = None, |
|
resample: PILImageResampling = None, |
|
do_rescale: bool = True, |
|
rescale_factor: float = 1 / 255, |
|
do_normalize: bool = True, |
|
image_mean: Optional[Union[float, List[float]]] = None, |
|
image_std: Optional[Union[float, List[float]]] = None, |
|
do_convert_rgb: bool = False, |
|
return_tensors: Optional[Union[str, TensorType]] = None, |
|
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, |
|
input_data_format: Optional[Union[str, ChannelDimension]] = None, |
|
predetermined_grid_thw=None, |
|
): |
|
""" |
|
Args: |
|
images (`ImageInput`): |
|
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If |
|
passing in images with pixel values between 0 and 1, set `do_rescale=False`. |
|
videos (`VideoInput`): |
|
Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If |
|
passing in videos with pixel values between 0 and 1, set `do_rescale=False`. |
|
do_resize (`bool`, *optional*, defaults to `self.do_resize`): |
|
Whether to resize the image. |
|
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Pixel-count bounds used for resizing, given as `{"min_pixels": ..., "max_pixels": ...}`.
|
resample (`int`, *optional*, defaults to `self.resample`): |
|
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only |
|
has an effect if `do_resize` is set to `True`. |
|
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): |
|
Whether to rescale the image. |
|
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): |
|
Rescale factor to rescale the image by if `do_rescale` is set to `True`. |
|
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): |
|
Whether to normalize the image. |
|
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): |
|
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. |
|
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): |
|
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to |
|
`True`. |
|
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): |
|
Whether to convert the image to RGB. |
|
return_tensors (`str` or `TensorType`, *optional*): |
|
The type of tensors to return. Can be one of: |
|
- Unset: Return a list of `np.ndarray`. |
|
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. |
|
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. |
|
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): |
|
The channel dimension format for the output image. Can be one of: |
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. |
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. |
|
- Unset: Use the channel dimension format of the input image. |
|
input_data_format (`ChannelDimension` or `str`, *optional*): |
|
The channel dimension format for the input image. If unset, the channel dimension format is inferred |
|
from the input image. Can be one of: |
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. |
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. |
|
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. |
|
|
|
""" |
|
do_resize = do_resize if do_resize is not None else self.do_resize |
|
size = size if size is not None else self.size |
|
resample = resample if resample is not None else self.resample |
|
do_rescale = do_rescale if do_rescale is not None else self.do_rescale |
|
rescale_factor = ( |
|
rescale_factor if rescale_factor is not None else self.rescale_factor |
|
) |
|
do_normalize = do_normalize if do_normalize is not None else self.do_normalize |
|
image_mean = image_mean if image_mean is not None else self.image_mean |
|
image_std = image_std if image_std is not None else self.image_std |
|
do_convert_rgb = ( |
|
do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb |
|
) |
|
|
|
if images is not None: |
|
images = make_batched_images(images) |
|
|
|
if images is not None and not valid_images(images): |
|
raise ValueError( |
|
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " |
|
"torch.Tensor." |
|
) |
|
|
|
data = {} |
|
if images is not None: |
|
pixel_values, vision_grid_thws = [], [] |
|
for img_idx, image in enumerate(images): |
|
if predetermined_grid_thw is not None: |
|
predetermined_grid_thw_one = [predetermined_grid_thw[img_idx]] |
|
else: |
|
predetermined_grid_thw_one = None |
|
patches, image_grid_thw = self._preprocess( |
|
image, |
|
do_resize=do_resize, |
|
resample=resample, |
|
do_rescale=do_rescale, |
|
rescale_factor=rescale_factor, |
|
do_normalize=do_normalize, |
|
image_mean=image_mean, |
|
image_std=image_std, |
|
data_format=data_format, |
|
do_convert_rgb=do_convert_rgb, |
|
input_data_format=input_data_format, |
|
predetermined_grid_thw=predetermined_grid_thw_one, |
|
) |
|
pixel_values.extend(patches) |
|
vision_grid_thws.append(image_grid_thw) |
|
pixel_values = np.array(pixel_values) |
|
vision_grid_thws = np.array(vision_grid_thws) |
|
data.update( |
|
{"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws} |
|
) |
|
|
|
if videos is not None: |
|
videos = make_batched_videos(videos) |
|
pixel_values, vision_grid_thws = [], [] |
|
for images in videos: |
|
patches, video_grid_thw = self._preprocess( |
|
images, |
|
do_resize=do_resize, |
|
resample=resample, |
|
do_rescale=do_rescale, |
|
rescale_factor=rescale_factor, |
|
do_normalize=do_normalize, |
|
image_mean=image_mean, |
|
image_std=image_std, |
|
data_format=data_format, |
|
do_convert_rgb=do_convert_rgb, |
|
input_data_format=input_data_format, |
|
predetermined_grid_thw=predetermined_grid_thw, |
|
) |
|
pixel_values.extend(patches) |
|
vision_grid_thws.append(video_grid_thw) |
|
pixel_values = np.array(pixel_values) |
|
vision_grid_thws = np.array(vision_grid_thws) |
|
|
|
data.update( |
|
{ |
|
"pixel_values_videos": pixel_values, |
|
"video_grid_thw": vision_grid_thws, |
|
} |
|
) |
|
|
|
return BatchFeature(data=data, tensor_type=return_tensors) |
|
|
|
|
|
class Ernie4_5_VLTokenizer(PreTrainedTokenizer): |
|
""" |
|
    SentencePiece-based tokenizer used by the ERNIE 4.5 VL models.
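
    Example (illustrative; assumes a local SentencePiece model at "tokenizer.model"):
        tok = Ernie4_5_VLTokenizer("tokenizer.model")
        ids = tok("a photo of a cat")["input_ids"]
        text = tok.decode(ids)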
|
""" |
|
|
|
vocab_files_names = { |
|
"vocab_file": "tokenizer.model", |
|
} |
|
|
|
model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"] |
|
|
|
padding_side = "right" |
|
|
|
def __init__( |
|
self, |
|
vocab_file, |
|
bos_token="<s>", |
|
cls_token="<cls>", |
|
eos_token="</s>", |
|
mask_token="<mask:0>", |
|
pad_token="<pad>", |
|
sep_token="<sep>", |
|
unk_token="<unk>", |
|
additional_special_tokens=None, |
|
**kwargs, |
|
): |
|
""" |
|
Initialize the Ernie4_5_VLTokenizer |
|
|
|
Args: |
|
vocab_file (str): Path to the tokenizer vocabulary model. |
|
bos_token (str, optional): The beginning of sequence token. Defaults to `"<s>"`. |
|
cls_token (str, optional): The classifier token. Defaults to `"<cls>"`. |
|
eos_token (str, optional): The end of sequence token. Defaults to `"</s>"`. |
|
mask_token (str, optional): The masking token. Defaults to `"<mask:0>"`. |
|
pad_token (str, optional): The padding token. Defaults to `"<pad>"`. |
|
sep_token (str, optional): The separation token. Defaults to `"<sep>"`. |
|
unk_token (str, optional): The unknown tokens symbol. Defaults to `"<unk>"`. |
|
additional_special_tokens (List[str], optional): Additional special tokens to use. |
|
Defaults to `["<mask:1>", "<mask:7>"]`. |
|
**kwargs (dict): Additional keyword arguments passed along to the superclass. |
|
""" |
|
|
|
|
|
self.vocab_file = vocab_file |
|
|
|
self.sp_model = spm.SentencePieceProcessor() |
|
|
|
self.sp_model.Load(vocab_file) |
|
|
|
|
|
if additional_special_tokens is None: |
|
additional_special_tokens = ["<mask:1>", "<mask:7>"] |
|
super().__init__( |
|
bos_token=bos_token, |
|
cls_token=cls_token, |
|
eos_token=eos_token, |
|
mask_token=mask_token, |
|
pad_token=pad_token, |
|
sep_token=sep_token, |
|
unk_token=unk_token, |
|
additional_special_tokens=additional_special_tokens, |
|
**kwargs, |
|
) |
|
|
|
@property |
|
def space_token(self): |
|
"""Return the space token""" |
|
return "<mask:1>" |
|
|
|
@property |
|
def space_token_id(self): |
|
"""Return the ID of the space token""" |
|
return self.sp_model.piece_to_id("<mask:1>") |
|
|
|
@property |
|
def gend_token(self): |
|
"""Return the gender token""" |
|
return "<mask:7>" |
|
|
|
@property |
|
def gend_token_id(self): |
|
"""Return the ID of the gender token""" |
|
return self.sp_model.piece_to_id("<mask:7>") |
|
|
|
@property |
|
def im_start_id(self): |
|
"""Return the ID of the image start token""" |
|
return self.sp_model.piece_to_id("<|im_start|>") |
|
|
|
@property |
|
def im_end_id(self): |
|
"""Return the ID of the image end token""" |
|
return self.sp_model.piece_to_id("<|im_end|>") |
|
|
|
@property |
|
def vocab_size(self): |
|
"""Return the size of the vocabulary""" |
|
return self.sp_model.vocab_size() |
|
|
|
def get_vocab(self): |
|
"""Return the vocabulary as a dictionary mapping tokens to IDs""" |
|
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} |
|
vocab.update(self.added_tokens_encoder) |
|
return vocab |
|
|
|
def _tokenize(self, text): |
|
"""Tokenize the input text into pieces""" |
|
return self.sp_model.encode_as_pieces(text) |
|
|
|
def _convert_token_to_id(self, token): |
|
"""Convert a token to its corresponding ID""" |
|
return self.sp_model.piece_to_id(token) |
|
|
|
def _convert_id_to_token(self, id): |
|
"""Convert an ID to its corresponding token""" |
|
return self.sp_model.id_to_piece(id) |
|
|
|
def convert_tokens_to_string(self, tokens): |
|
"""Convert a sequence of tokens back to a string""" |
|
current_sub_tokens = [] |
|
out_string = "" |
|
|
|
for token in tokens: |
|
|
|
if token in self.all_special_tokens: |
|
out_string += self.sp_model.decode(current_sub_tokens) + token |
|
current_sub_tokens = [] |
|
else: |
|
current_sub_tokens.append(token) |
|
|
|
|
|
out_string += self.sp_model.decode(current_sub_tokens) |
|
return out_string |
|
|
|
def prepare_for_model(self, *args, **kwargs): |
|
"""Prepare the tokenized inputs for the model""" |
|
|
|
if "add_special_tokens" in kwargs: |
|
kwargs.pop("add_special_tokens") |
|
return super().prepare_for_model(*args, **kwargs) |
|
|
|
def save_vocabulary( |
|
self, save_directory, filename_prefix: Optional[str] = None |
|
) -> Tuple[str]: |
|
""" |
|
Save the vocabulary and special tokens file to a directory. |
|
|
|
Args: |
|
save_directory (`str`): The directory to save the vocabulary to |
|
filename_prefix (`str`, optional): Prefix to add to the filename |
|
|
|
Returns: |
|
`Tuple(str)`: Paths to the saved files |
|
""" |
|
if not os.path.isdir(save_directory): |
|
logger.error(f"Vocabulary path ({save_directory}) should be a directory") |
|
return |
|
|
|
|
|
out_vocab_file = os.path.join( |
|
save_directory, |
|
(filename_prefix + "-" if filename_prefix else "") |
|
+ self.vocab_files_names["vocab_file"], |
|
) |
|
|
|
|
|
if os.path.abspath(self.vocab_file) != os.path.abspath( |
|
out_vocab_file |
|
) and os.path.isfile(self.vocab_file): |
|
copyfile(self.vocab_file, out_vocab_file) |
|
elif not os.path.isfile(self.vocab_file): |
|
with open(out_vocab_file, "wb") as fi: |
|
content_spiece_model = self.sp_model.serialized_model_proto() |
|
fi.write(content_spiece_model) |
|
|
|
return (out_vocab_file,) |
|
|
|
def _decode(self, *args, **kwargs): |
|
"""Decode token_id back to text""" |
|
|
|
kwargs.pop("clean_up_tokenization_spaces", None) |
|
kwargs.pop("spaces_between_special_tokens", None) |
|
|
|
|
|
return super()._decode( |
|
*args, |
|
**kwargs, |
|
clean_up_tokenization_spaces=False, |
|
spaces_between_special_tokens=False, |
|
) |
|
|
|
def _pad( |
|
self, |
|
encoded_inputs: Dict, |
|
max_length: Optional[int] = None, |
|
padding_strategy=PaddingStrategy.DO_NOT_PAD, |
|
pad_to_multiple_of: Optional[int] = None, |
|
return_attention_mask: Optional[bool] = None, |
|
) -> dict: |
|
"""Pad the encoded inputs to the specified length""" |
|
if return_attention_mask is None: |
|
return_attention_mask = "attention_mask" in self.model_input_names |
|
if return_attention_mask: |
|
required_input = encoded_inputs[self.model_input_names[0]] |
|
if padding_strategy == PaddingStrategy.LONGEST: |
|
max_length = len(required_input) |
|
|
|
|
|
if ( |
|
max_length is not None |
|
and pad_to_multiple_of is not None |
|
and (max_length % pad_to_multiple_of != 0) |
|
): |
|
max_length = ( |
|
(max_length // pad_to_multiple_of) + 1 |
|
) * pad_to_multiple_of |
|
|
|
|
|
needs_to_be_padded = ( |
|
padding_strategy != PaddingStrategy.DO_NOT_PAD |
|
and len(required_input) != max_length |
|
) |
|
|
|
|
|
if ( |
|
"attention_mask" in encoded_inputs |
|
and encoded_inputs["attention_mask"] is not None |
|
): |
|
attention_mask = encoded_inputs.pop("attention_mask") |
|
if isinstance(attention_mask, torch.Tensor): |
|
attention_mask = attention_mask.numpy() |
|
elif isinstance(attention_mask, list): |
|
attention_mask = np.array(attention_mask) |
|
elif not isinstance(attention_mask, np.ndarray): |
|
                    raise ValueError(
                        f"Unexpected type {type(attention_mask)} of attention_mask"
|
) |
|
else: |
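                # No attention mask provided: build a causal (lower-triangular) mask over the
                # sequence and add a leading batch dimension.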
|
|
|
attention_mask = np.tril( |
|
np.ones((len(required_input), len(required_input)), dtype=np.int64) |
|
) |
|
attention_mask = np.expand_dims(attention_mask, axis=0) |
|
|
|
|
|
if needs_to_be_padded: |
|
difference = max_length - len(required_input) |
|
if self.padding_side == "right": |
|
if attention_mask.ndim == 1: |
|
pad_width = [(0, difference)] |
|
else: |
|
pad_width = [(0, 0), (0, difference), (0, difference)] |
|
elif self.padding_side == "left": |
|
if attention_mask.ndim == 1: |
|
pad_width = [(difference, 0)] |
|
else: |
|
pad_width = [(0, 0), (difference, 0), (difference, 0)] |
|
else: |
|
                    raise ValueError(
                        "Invalid padding side: " + str(self.padding_side)
|
) |
|
|
|
attention_mask = np.pad( |
|
attention_mask, |
|
pad_width=pad_width, |
|
mode="constant", |
|
constant_values=0, |
|
) |
|
|
|
|
|
encoded_inputs = super()._pad( |
|
encoded_inputs, |
|
max_length, |
|
padding_strategy=padding_strategy, |
|
pad_to_multiple_of=pad_to_multiple_of, |
|
return_attention_mask=False, |
|
) |
|
|
|
|
|
if return_attention_mask: |
|
encoded_inputs["attention_mask"] = attention_mask.tolist() |
|
|
|
return encoded_inputs |
|
|
|
|
|
RAW_VIDEO_DIR = "./download_tmp/raw_video/" |
|
RAW_IMAGE_DIR = "./download_tmp/raw_images/" |
|
EXTRACTED_FRAME_DIR = "./download_tmp/extracted_frames/" |
|
TMP_DIR = "./download_tmp/upload_tmp/" |
|
|
|
FONT_PATH = os.path.join(Path(__file__).parent.absolute(), "Roboto-Regular.ttf") |
|
|
|
|
|
def is_gif(data: bytes) -> bool: |
|
""" |
|
    Check whether a byte string is a GIF based on its magic header.
|
""" |
|
return data[:6] in (b"GIF87a", b"GIF89a") |
|
|
|
|
|
class VideoReaderWrapper(decord.VideoReader): |
|
""" |
|
    Wraps decord.VideoReader to work around a memory leak:
|
|
|
https://github.com/dmlc/decord/issues/208 |
|
""" |
|
|
|
def __init__(self, video_path, *args, **kwargs): |
|
with ntf(delete=True, suffix=".gif") as gif_file: |
|
gif_input = None |
|
self.original_file = None |
|
if isinstance(video_path, str): |
|
self.original_file = video_path |
|
if video_path.lower().endswith(".gif"): |
|
gif_input = video_path |
|
elif isinstance(video_path, bytes): |
|
if is_gif(video_path): |
|
gif_file.write(video_path) |
|
gif_input = gif_file.name |
|
elif isinstance(video_path, io.BytesIO): |
|
video_path.seek(0) |
|
tmp_bytes = video_path.read() |
|
video_path.seek(0) |
|
if is_gif(tmp_bytes): |
|
gif_file.write(tmp_bytes) |
|
gif_input = gif_file.name |
|
|
|
if gif_input is not None: |
|
clip = mp.VideoFileClip(gif_input) |
|
mp4_file = ntf(delete=False, suffix=".mp4") |
|
clip.write_videofile(mp4_file.name, verbose=False, logger=None) |
|
clip.close() |
|
video_path = mp4_file.name |
|
self.original_file = video_path |
|
|
|
super().__init__(video_path, *args, **kwargs) |
|
self.seek(0) |
|
|
|
def __getitem__(self, key): |
|
frames = super().__getitem__(key) |
|
self.seek(0) |
|
return frames |
|
|
|
def __del__(self): |
|
if self.original_file and os.path.exists(self.original_file): |
|
os.remove(self.original_file) |
|
|
|
|
|
def get_filename(url=None): |
|
""" |
|
Get Filename |
|
""" |
|
if url is None: |
|
return str(uuid.uuid4()).replace("-", "") |
|
t = datetime.datetime.now() |
|
if not isinstance(url, bytes): |
|
url = url.encode("utf-8") |
|
|
|
md5_hash = hashlib.md5(url).hexdigest() |
|
pid = os.getpid() |
|
tid = threading.get_ident() |
|
|
|
|
|
    image_filename = f"{t.year}-{t.month:02d}-{t.day:02d}-{pid}-{tid}-{md5_hash}"
    return image_filename
|
|
|
|
|
def file_download(url, download_dir, save_to_disk=False, retry=0, retry_interval=3): |
|
""" |
|
    Description: download `url`; if it is already a PIL image, return it directly.
    Args:
        url (str, bytes, PIL.Image): an http(s) URL, local path, base64 string, or raw bytes.
        download_dir: directory in which to store the file when save_to_disk=True.
        save_to_disk: whether to save the content to disk and return the saved path.
|
""" |
|
|
|
if isinstance(url, Image.Image): |
|
return url |
|
elif isinstance(url, VideoReaderWrapper): |
|
return url |
|
elif url.startswith("http"): |
|
response = requests.get(url) |
|
bytes_data = response.content |
|
elif os.path.isfile(url): |
|
if save_to_disk: |
|
return url |
|
        with open(url, "rb") as f:
            bytes_data = f.read()
|
else: |
|
bytes_data = base64.b64decode(url) |
|
if not save_to_disk: |
|
return bytes_data |
|
|
|
download_path = os.path.join(download_dir, get_filename(url)) |
|
Path(download_path).parent.mkdir(parents=True, exist_ok=True) |
|
with open(download_path, "wb") as f: |
|
f.write(bytes_data) |
|
return download_path |
|
|
|
|
|
def get_downloadable( |
|
url, download_dir=RAW_VIDEO_DIR, save_to_disk=False, retry=0, retry_interval=3 |
|
):
    """Download a resource (image or video) and optionally store it on disk.
|
|
|
return downloaded **path** if save_to_disk is set to true |
|
return downloaded **bytes** if save_to_disk is set to false |
|
""" |
|
|
|
if not os.path.exists(download_dir): |
|
os.makedirs(download_dir) |
|
downloaded_path = file_download( |
|
url, |
|
download_dir, |
|
save_to_disk=save_to_disk, |
|
retry=retry, |
|
retry_interval=retry_interval, |
|
) |
|
return downloaded_path |
|
|
|
|
|
def get_downloadable_image( |
|
download_path, need_exif_info, retry_max_time=0, retry_interval=3 |
|
): |
|
""" |
|
Get downloadable with exif info and image processing |
|
""" |
|
|
|
def get_image_exif(image): |
|
exif_data = image._getexif() |
|
exif_info = {} |
|
if exif_data is not None: |
|
for tag, value in exif_data.items(): |
|
tag_name = TAGS.get(tag, tag) |
|
                exif_info[tag_name] = value.strip() if isinstance(value, str) else value
|
return exif_info |
|
|
|
def has_transparent_background(img): |
|
"""has_transparent_background""" |
|
if img.mode in ("RGBA", "LA") or ( |
|
img.mode == "P" and "transparency" in img.info |
|
): |
|
|
|
alpha = img.convert("RGBA").split()[-1] |
|
if alpha.getextrema()[0] < 255: |
|
return True |
|
return False |
|
|
|
def add_white_background(img): |
|
""" |
|
Add a white background to a transparent background image |
|
""" |
|
if img.mode != "RGBA": |
|
img = img.convert("RGBA") |
|
|
|
img_white_background = Image.new("RGBA", img.size, (255, 255, 255)) |
|
|
|
|
|
img_white_background.paste(img, (0, 0), img) |
|
|
|
return img_white_background |
|
|
|
def change_I16_to_L(img): |
|
""" |
|
Convert image from I;16 mode to L mode |
|
""" |
|
|
|
|
|
return img.point(lambda i: i * (1 / 256)).convert("L") |
|
|
|
image = get_downloadable( |
|
download_path, |
|
save_to_disk=False, |
|
retry=retry_max_time, |
|
retry_interval=retry_interval, |
|
) |
|
if isinstance(image, Image.Image): |
|
pil_image = image |
|
else: |
|
pil_image = Image.open(io.BytesIO(image)) |
|
if need_exif_info: |
|
try: |
|
exif_info = get_image_exif(pil_image) |
|
except Exception as why: |
|
exif_info = {} |
|
else: |
|
exif_info = {} |
|
|
|
try: |
|
if pil_image.mode == "I;16": |
|
pil_image = change_I16_to_L(pil_image) |
|
if has_transparent_background(pil_image): |
|
pil_image = add_white_background(pil_image) |
|
except Exception as e: |
|
pass |
|
|
|
return pil_image.convert("RGB"), exif_info |
|
|
|
|
|
def read_video_decord(video_path, save_to_disk): |
|
"""get reader and meta by decord""" |
|
video_path = get_downloadable(video_path, save_to_disk=save_to_disk) |
|
if isinstance(video_path, VideoReaderWrapper): |
|
video_reader = video_path |
|
else: |
|
if isinstance(video_path, bytes): |
|
video_path = io.BytesIO(video_path) |
|
video_reader = VideoReaderWrapper(video_path, num_threads=1) |
|
vlen = len(video_reader) |
|
fps = video_reader.get_avg_fps() |
|
duration = vlen / float(fps) |
|
|
|
video_meta = {"fps": fps, "duration": duration, "num_of_frame": vlen} |
|
|
|
return video_reader, video_meta, video_path |
|
|
|
|
|
def get_frame_indices( |
|
vlen, |
|
target_frames=-1, |
|
target_fps=-1, |
|
frames_sample="middle", |
|
fix_start=None, |
|
input_fps=-1, |
|
):
    """Compute the frame indices to sample: either a fixed number of frames (target_frames) or at a fixed rate (target_fps)."""
|
assert frames_sample in ["rand", "middle", "leading"] |
|
if target_frames > 0: |
|
        assert target_fps <= 0, "target_fps must not be positive if target_frames is given."
|
if target_frames > vlen: |
|
acc_samples = vlen |
|
logger.info( |
|
f"target_frames={target_frames} is larger than video length {vlen}, " |
|
f"will sample {acc_samples} frames." |
|
) |
|
else: |
|
acc_samples = target_frames |
|
logger.debug( |
|
f"sampling at target_frames={target_frames}, frames_sample={frames_sample}" |
|
) |
|
|
|
|
|
intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) |
|
ranges = [] |
|
for idx, interv in enumerate(intervals[:-1]): |
|
ranges.append((interv, intervals[idx + 1] - 1)) |
|
if frames_sample == "rand": |
|
try: |
|
frame_indices = [random.choice(range(x[0], x[1])) for x in ranges] |
|
except Exception as e: |
|
frame_indices = np.random.permutation(vlen)[:acc_samples] |
|
frame_indices.sort() |
|
frame_indices = list(frame_indices) |
|
elif fix_start is not None: |
|
frame_indices = [x[0] + fix_start for x in ranges] |
|
elif frames_sample == "leading": |
|
frame_indices = [x[0] for x in ranges] |
|
elif frames_sample == "middle": |
|
frame_indices = [(x[0] + x[1]) // 2 for x in ranges] |
|
else: |
|
raise NotImplementedError |
|
|
|
elif target_fps > 0: |
|
assert ( |
|
target_frames <= 0 |
|
        ), "target_frames must not be positive if target_fps is given."
|
assert input_fps > 0, "input_fps must be provided if target_fps is given." |
|
logger.info(f"sampling at fps={target_fps}, frames_sample={frames_sample}") |
|
duration = float(vlen) / input_fps |
|
        delta = 1 / target_fps
|
if frames_sample == "middle": |
|
frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta) |
|
elif frames_sample == "leading": |
|
frame_seconds = np.arange(0, duration, delta) |
|
        elif frames_sample == "rand":
|
frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta) |
|
rand_offset = np.random.rand(*(frame_seconds.shape)) - 0.5 |
|
frame_seconds += rand_offset * delta |
|
frame_indices = np.around(frame_seconds * input_fps).astype(int) |
|
frame_indices = [e for e in frame_indices if e < vlen] |
|
|
|
else: |
|
raise ValueError( |
|
"Must provide either positive target_fps or positive target_frames." |
|
) |
|
|
|
return frame_indices |
|
|
|
|
|
def read_frames_decord( |
|
video_path, |
|
video_reader, |
|
video_meta, |
|
target_frames=-1, |
|
target_fps=-1, |
|
frames_sample="middle", |
|
fix_start=None, |
|
save_to_disk=False, |
|
cache_dir=EXTRACTED_FRAME_DIR, |
|
frame_indices=None, |
|
tol=10, |
|
): |
|
"""get frames by decord""" |
|
|
|
if frame_indices is None: |
|
frame_indices = get_frame_indices( |
|
video_meta["num_of_frame"], |
|
target_frames=target_frames, |
|
target_fps=target_fps, |
|
frames_sample=frames_sample, |
|
fix_start=fix_start, |
|
input_fps=video_meta["fps"], |
|
) |
|
|
|
frames = [] |
|
for frame_indice_index in range(0, len(frame_indices)): |
|
frame_indice = frame_indices[frame_indice_index] |
|
try: |
|
frames.append(video_reader[frame_indice].asnumpy()) |
|
except Exception as e: |
|
logger.debug(f"encounter error when get frame: {frame_indice}, error: {e}") |
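            # The frame failed to decode: alternately probe earlier and later neighbours
            # (up to `tol` steps each way, doubled at the clip boundaries) and keep the first
            # frame that decodes successfully.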
|
previous_counter = 1 |
|
later_counter = 1 |
|
previous_after_flag = True |
|
if frame_indice == 0 or frame_indice == len(video_reader) - 1: |
|
cur_tol = tol * 2 |
|
else: |
|
cur_tol = tol |
|
while previous_counter < cur_tol or later_counter < cur_tol: |
|
if previous_after_flag: |
|
if frame_indice - previous_counter < 0: |
|
previous_counter += 1 |
|
previous_after_flag = not previous_after_flag |
|
continue |
|
try: |
|
frames.append( |
|
video_reader[frame_indice - previous_counter].asnumpy() |
|
) |
|
logger.info( |
|
f"replace {frame_indice}-th frame with {frame_indice-previous_counter}-th frame" |
|
) |
|
frame_indices[frame_indice_index] = ( |
|
frame_indice - previous_counter |
|
) |
|
break |
|
except Exception as e: |
|
previous_counter += 1 |
|
else: |
|
if frame_indice + later_counter >= len(video_reader): |
|
later_counter += 1 |
|
previous_after_flag = not previous_after_flag |
|
continue |
|
try: |
|
frames.append( |
|
video_reader[frame_indice + later_counter].asnumpy() |
|
) |
|
logger.info( |
|
f"replace {frame_indice}-th frame with {frame_indice+later_counter}-th frame" |
|
) |
|
frame_indices[frame_indice_index] = frame_indice + later_counter |
|
break |
|
except Exception as e: |
|
later_counter += 1 |
|
previous_after_flag = not previous_after_flag |
|
|
|
frames = np.stack(frames, axis=0) |
|
assert len(frames) == len( |
|
frame_indices |
|
), f"len(frames): {len(frames)} != len(frame_indices): {len(frame_indices)}" |
|
|
|
ret = [] |
|
|
|
url_sha1 = get_filename() |
|
for idx, frame in enumerate(frames): |
|
tmp = Image.fromarray(frame, "RGB") |
|
if save_to_disk: |
|
save_path = os.path.join(cache_dir, f"{url_sha1}", f"{idx}.png") |
|
if not os.path.exists(os.path.dirname(save_path)): |
|
os.makedirs(os.path.dirname(save_path)) |
|
tmp.save(save_path) |
|
tmp = save_path |
|
ret.append(tmp) |
|
|
|
time_stamps = [ |
|
frame_idx * video_meta["duration"] / video_meta["num_of_frame"] |
|
for frame_idx in frame_indices |
|
] |
|
|
|
return ret, frame_indices, time_stamps |
|
|
|
|
|
def render_single_image_with_timestamp( |
|
    image: Image.Image, number: str, rate: float, font_path: str = FONT_PATH
|
): |
|
""" |
|
    Render a timestamp string onto a PIL image.

    The font size is `rate` * min(width, height); the text is drawn in black with a white
    outline whose width is 10% of the font size.

    Returns the modified Image object.
|
""" |
|
draw = ImageDraw.Draw(image) |
|
width, height = image.size |
|
font_size = int(min(width, height) * rate) |
|
outline_size = int(font_size * 0.1) |
|
font = ImageFont.truetype(font_path, font_size) |
|
x = 0 |
|
y = 0 |
|
|
|
|
|
draw.text( |
|
(x, y), |
|
number, |
|
font=font, |
|
fill=(0, 0, 0), |
|
stroke_width=outline_size, |
|
stroke_fill=(255, 255, 255), |
|
) |
|
|
|
return image |
|
|
|
|
|
def timestamp_converting(time_stamp_in_seconds): |
|
""" |
|
convert timestamp format from seconds to hr:min:sec |
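
    e.g. timestamp_converting(3661.5) -> "01:01:01.50"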
|
""" |
|
|
|
hours = 0 |
|
while time_stamp_in_seconds >= 3600: |
|
hours += 1 |
|
time_stamp_in_seconds -= 3600 |
|
|
|
mins = 0 |
|
while time_stamp_in_seconds >= 60: |
|
mins += 1 |
|
time_stamp_in_seconds -= 60 |
|
time_hours = f"{int(hours):02d}" |
|
time_mins = f"{int(mins):02d}" |
|
time_secs = f"{time_stamp_in_seconds:05.02f}" |
|
fi_time_stamp = time_hours + ":" + time_mins + ":" + time_secs |
|
|
|
return fi_time_stamp |
|
|
|
|
|
def render_frame_timestamp(frame, timestamp, font_rate=0.1): |
|
""" |
|
    Render a timestamp onto a frame, in the upper-left corner.

    frame: the frame, a PIL.Image object
    timestamp: the timestamp, in seconds
    font_rate: the ratio of the font size to min(width, height)
|
""" |
|
time_stamp = "time: " + timestamp_converting(timestamp) |
|
new_frame = render_single_image_with_timestamp(frame, time_stamp, font_rate) |
|
|
|
return new_frame |
|
|
|
|
|
IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3} |
|
|
|
|
|
class Ernie_45T_VLProcessor(ProcessorMixin): |
|
""" |
|
Processes multimodal chat messages into model-ready inputs, |
|
handling text, images, and videos with 3D positional embeddings. |
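
    Example (illustrative; assumes `processor` was loaded from an ERNIE 4.5 VL checkpoint
    and that a chat template is configured):
        text = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        images, videos = processor.process_vision_info(messages)
        inputs = processor(text=[text], images=images, videos=videos)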
|
""" |
|
|
|
attributes = ["image_processor", "tokenizer"] |
|
valid_kwargs = [ |
|
"chat_template", |
|
"spatial_conv_size", |
|
"temporal_conv_size", |
|
"image_min_pixels", |
|
"image_max_pixels", |
|
"video_min_pixels", |
|
"video_max_pixels", |
|
"video_target_frames", |
|
"video_frames_sample", |
|
"video_max_frames", |
|
"video_min_frames", |
|
"video_fps", |
|
] |
|
image_processor_class = "AutoImageProcessor" |
|
tokenizer_class = "AutoTokenizer" |
|
|
|
CLS_TOKEN = "<|begin_of_sentence|>" |
|
SEP_TOKEN = "<|end_of_sentence|>" |
|
IMG_START = "<|IMAGE_START|>" |
|
IMG_END = "<|IMAGE_END|>" |
|
VID_START = "<|VIDEO_START|>" |
|
VID_END = "<|VIDEO_END|>" |
|
|
|
def __init__( |
|
self, |
|
image_processor=None, |
|
tokenizer=None, |
|
chat_template=None, |
|
spatial_conv_size: int = 2, |
|
temporal_conv_size: int = 2, |
|
image_min_pixels: int = 4 * 28 * 28, |
|
image_max_pixels: int = 6177 * 28 * 28, |
|
video_min_pixels: int = 299 * 28 * 28, |
|
video_max_pixels: int = 1196 * 28 * 28, |
|
video_target_frames: int = -1, |
|
video_frames_sample: str = "leading", |
|
video_max_frames: int = 180, |
|
video_min_frames: int = 16, |
|
video_fps: int = 2, |
|
**kwargs, |
|
): |
|
super().__init__(image_processor, tokenizer, chat_template=chat_template) |
|
self.tokenizer.ignored_index = -100 |
|
|
|
|
|
self.spatial_conv_size = spatial_conv_size |
|
self.temporal_conv_size = temporal_conv_size |
|
|
|
|
|
self.image_min_pixels = image_min_pixels |
|
self.image_max_pixels = image_max_pixels |
|
self.video_min_pixels = video_min_pixels |
|
self.video_max_pixels = video_max_pixels |
|
|
|
|
|
self.target_frames = video_target_frames |
|
self.frames_sample = video_frames_sample |
|
self.max_frames = video_max_frames |
|
self.min_frames = video_min_frames |
|
self.fps = video_fps |
|
|
|
|
|
self.cls_token = self.CLS_TOKEN |
|
self.sep_token = self.SEP_TOKEN |
|
self.image_start = self.IMG_START |
|
self.image_end = self.IMG_END |
|
self.video_start = self.VID_START |
|
self.video_end = self.VID_END |
|
self.image_patch_id = self.tokenizer.convert_tokens_to_ids( |
|
"<|IMAGE_PLACEHOLDER|>" |
|
) |
|
|
|
self.token_type_mapping = self._build_token_type_mapping() |
|
self.is_training = True |
|
self.role_prefixes = {"system": "", "user": "User: ", "bot": "Assistant: "} |
|
|
|
def _build_token_type_mapping(self) -> Dict[Any, int]: |
|
mapping = defaultdict(lambda: IDS_TYPE_FLAG["text"]) |
|
for token in (self.IMG_START, self.IMG_END, self.VID_START, self.VID_END): |
|
mapping[token] = IDS_TYPE_FLAG["image"] |
|
mapping[self.image_patch_id] = IDS_TYPE_FLAG["image"] |
|
return mapping |
|
|
|
def train(self) -> None: |
|
"""Enable training mode (produces labels).""" |
|
self.is_training = True |
|
|
|
def eval(self) -> None: |
|
"""Enable evaluation mode (doesn't produce labels).""" |
|
self.is_training = False |
|
|
|
def _download_image( |
|
self, |
|
item: Dict, |
|
): |
|
"""Download image from url and resize it to the specified size.""" |
|
url_info = item.get("image_url", {}) |
|
url = url_info.get("url") |
|
w = url_info.get("image_width", None) |
|
h = url_info.get("image_height", None) |
|
data = get_downloadable(url, download_dir=RAW_IMAGE_DIR, save_to_disk=False) |
|
|
|
img = Image.open(io.BytesIO(data) if isinstance(data, bytes) else data) |
|
if w and h: |
|
img = img.resize((w, h)) |
|
return img |
|
|
|
def _download_video(self, item: Dict): |
|
"""Download video from url and resize it to the specified size.""" |
|
url_info = item.get("video_url", {}) |
|
url = url_info.get("url") |
|
|
|
frames = self._load_and_process_video(url, item) |
|
|
|
pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0) |
|
return pixel_stack |
|
|
|
def process_vision_info(self, messages: List[Dict[str, Any]]): |
|
"""Preprocess messages into lists of text, images, and videos.""" |
|
images = [] |
|
videos = [] |
|
|
|
for msg in messages: |
|
content_items = msg.get("content") |
|
if not isinstance(content_items, list): |
|
content_items = [content_items] |
|
|
|
for item in content_items: |
|
if item.get("type") == "image_url": |
|
img = self._download_image(item) |
|
images.append(img) |
|
elif item.get("type") == "video_url": |
|
pixel_stack = self._download_video(item) |
|
videos.append(pixel_stack) |
|
|
|
return images, videos |
|
|
|
def __call__( |
|
self, |
|
text: List[str], |
|
images: List[Image.Image], |
|
videos: List[List[Image.Image]], |
|
**kwargs, |
|
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]: |
|
""" |
|
Convert chat messages into model inputs. |
|
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels. |
|
""" |
|
outputs = { |
|
"input_ids": [], |
|
"token_type_ids": [], |
|
"position_ids": [], |
|
"images": [], |
|
"grid_thw": [], |
|
"image_type_ids": [], |
|
"cur_position": 0, |
|
"pic_cnt": 0, |
|
"video_cnt": 0, |
|
} |
|
texts = text[0] |
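
        # The prompt contains <|VIDEO_START|><|video@placeholder|><|VIDEO_END|> and
        # <|IMAGE_START|><|image@placeholder|><|IMAGE_END|> placeholders; split on them and
        # interleave text, image, and video features in the original order.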
|
|
|
new_video_seg = True |
|
for text_with_image in texts.split(self.VID_START + "<|video@placeholder|>" + self.VID_END): |
|
new_text_seg = True |
|
if not new_video_seg: |
|
self._add_video(videos[outputs["video_cnt"]], outputs) |
|
            for text_seg in text_with_image.split(self.IMG_START + "<|image@placeholder|>" + self.IMG_END):
                if not new_text_seg:
                    self._add_image(images[outputs["pic_cnt"]], outputs)
                self._add_text(text_seg, outputs)
|
new_text_seg = False |
|
new_video_seg = False |
|
|
|
for key in ["cur_position", "pic_cnt", "video_cnt"]: |
|
outputs.pop(key, None) |
|
|
|
outputs = self._pack_outputs(outputs) |
|
for key in outputs.keys(): |
|
if isinstance(outputs[key], np.ndarray): |
|
if key in ["images", "grid_thw"]: |
|
outputs[key] = torch.tensor(np.array(outputs[key])) |
|
else: |
|
outputs[key] = torch.tensor(np.array([outputs[key]])) |
|
|
|
return BatchFeature(data=outputs) |
|
|
|
def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None: |
|
"""add special token to outputs""" |
|
token_id = ( |
|
token |
|
if isinstance(token, int) |
|
else self.tokenizer.convert_tokens_to_ids(token) |
|
) |
|
outputs["input_ids"].append(token_id) |
|
outputs["token_type_ids"].append(self.token_type_mapping[token]) |
|
pos = outputs["cur_position"] |
|
outputs["position_ids"].append([pos] * 3) |
|
outputs["cur_position"] += 1 |
|
|
|
def _add_text(self, text: str, outputs: Dict) -> None: |
|
"""add text to outputs""" |
|
tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) |
|
outputs["input_ids"].extend(tokens) |
|
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * len(tokens)) |
|
|
|
start = outputs["cur_position"] |
|
for i in range(len(tokens)): |
|
outputs["position_ids"].append([start + i] * 3) |
|
outputs["cur_position"] += len(tokens) |
|
|
|
def _add_image(self, img: Image.Image, outputs: Dict) -> None: |
|
"""add image to outputs""" |
|
outputs["pic_cnt"] += 1 |
|
self._add_special_token(self.IMG_START, outputs) |
|
|
|
patches_h, patches_w = self.image_processor.get_smarted_resize( |
|
img.height, |
|
img.width, |
|
min_pixels=self.image_min_pixels, |
|
max_pixels=self.image_max_pixels, |
|
)[1] |
|
num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2) |
|
|
|
outputs["input_ids"].extend([self.image_patch_id] * num_tokens) |
|
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens) |
|
|
|
pos_ids = self._compute_3d_positions( |
|
1, patches_h, patches_w, outputs["cur_position"] |
|
) |
|
outputs["position_ids"].extend(pos_ids) |
|
outputs["cur_position"] = np.max(pos_ids) + 1 |
|
|
|
|
|
ret = self.image_processor.preprocess( |
|
images=[img.convert("RGB")], |
|
do_normalize=False, |
|
do_rescale=False, |
|
predetermined_grid_thw=np.array([[patches_h, patches_w]]), |
|
do_convert_rgb=True, |
|
input_data_format=ChannelDimension.LAST, |
|
) |
|
outputs["images"].append(ret["pixel_values"]) |
|
outputs["grid_thw"].append(ret["image_grid_thw"]) |
|
outputs["image_type_ids"].append(0) |
|
|
|
self._add_special_token(self.IMG_END, outputs) |
|
|
|
def _add_video( |
|
        self, pixel_stack: np.ndarray, outputs: Dict
|
) -> None: |
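        """Add video frames (a stacked pixel array) to outputs."""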
|
outputs["video_cnt"] += 1 |
|
self._add_special_token(self.VID_START, outputs) |
|
|
|
patches_h, patches_w = self.image_processor.get_smarted_resize( |
|
pixel_stack.shape[1], |
|
pixel_stack.shape[2], |
|
min_pixels=self.video_min_pixels, |
|
max_pixels=self.video_max_pixels, |
|
)[1] |
|
num_frames = pixel_stack.shape[0] |
|
num_tokens = (num_frames * patches_h * patches_w) // ( |
|
self.spatial_conv_size**2 * self.temporal_conv_size |
|
) |
|
|
|
ret = self.image_processor.preprocess( |
|
images=None, |
|
videos=pixel_stack, |
|
do_normalize=False, |
|
do_rescale=False, |
|
predetermined_grid_thw=np.array([[patches_h, patches_w]] * num_frames), |
|
do_convert_rgb=True, |
|
input_data_format=ChannelDimension.LAST, |
|
) |
|
outputs["images"].append(ret["pixel_values_videos"]) |
|
outputs["grid_thw"].append(ret["video_grid_thw"]) |
|
outputs["image_type_ids"].extend([1] * num_frames) |
|
|
|
outputs["input_ids"].extend([self.image_patch_id] * num_tokens) |
|
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens) |
|
|
|
pos_ids = self._compute_3d_positions( |
|
num_frames, patches_h, patches_w, outputs["cur_position"] |
|
) |
|
outputs["position_ids"].extend(pos_ids) |
|
outputs["cur_position"] = np.max(pos_ids) + 1 |
|
|
|
self._add_special_token(self.VID_END, outputs) |
|
|
|
def _load_and_process_video(self, url: str, item: Dict) -> List[Image.Image]: |
|
reader, meta, path = read_video_decord(url, save_to_disk=False) |
|
|
|
video_frame_args = dict() |
|
video_frame_args["fps"] = item.get("fps", self.fps) |
|
video_frame_args["min_frames"] = item.get("min_frames", self.min_frames) |
|
video_frame_args["max_frames"] = item.get("max_frames", self.max_frames) |
|
video_frame_args["target_frames"] = item.get( |
|
"target_frames", self.target_frames |
|
) |
|
video_frame_args["frames_sample"] = item.get( |
|
"frames_sample", self.frames_sample |
|
) |
|
|
|
video_frame_args = self._set_video_frame_args(video_frame_args, meta) |
|
|
|
frames_data, _, timestamps = read_frames_decord( |
|
path, |
|
reader, |
|
meta, |
|
target_frames=video_frame_args["target_frames"], |
|
target_fps=video_frame_args["fps"], |
|
frames_sample=video_frame_args["frames_sample"], |
|
save_to_disk=False, |
|
) |
|
|
|
frames: List[Image.Image] = [] |
|
for img_array, ts in zip(frames_data, timestamps): |
|
frames.append(render_frame_timestamp(img_array, ts)) |
|
|
|
if len(frames) % 2 != 0: |
|
frames.append(copy.deepcopy(frames[-1])) |
|
return frames |
|
|
|
def _set_video_frame_args(self, video_frame_args, video_meta): |
|
""" |
|
Set the final frame extraction parameters based on known parameters and priorities |
|
""" |
|
|
|
if video_frame_args["target_frames"] > 0: |
|
if video_frame_args["fps"] >= 0: |
|
raise ValueError("fps must be negative if target_frames is given") |
|
if ( |
|
video_frame_args["min_frames"] > 0 |
|
and video_frame_args["target_frames"] < video_frame_args["min_frames"] |
|
): |
|
raise ValueError("target_frames must be larger than min_frames") |
|
if ( |
|
video_frame_args["max_frames"] > 0 |
|
and video_frame_args["target_frames"] > video_frame_args["max_frames"] |
|
): |
|
raise ValueError("target_frames must be smaller than max_frames") |
|
else: |
|
if video_frame_args["fps"] < 0: |
|
raise ValueError( |
|
"Must provide either positive target_fps or positive target_frames." |
|
) |
|
|
|
frames_to_extract = int(video_meta["duration"] * video_frame_args["fps"]) |
|
|
|
if ( |
|
video_frame_args["min_frames"] > 0 |
|
and video_frame_args["max_frames"] > 0 |
|
and video_frame_args["min_frames"] > video_frame_args["max_frames"] |
|
): |
|
raise ValueError("min_frames must be smaller than max_frames") |
|
if ( |
|
video_frame_args["min_frames"] > 0 |
|
and frames_to_extract < video_frame_args["min_frames"] |
|
): |
|
video_frame_args["target_frames"] = video_frame_args["min_frames"] |
|
video_frame_args["fps"] = -1 |
|
if ( |
|
video_frame_args["max_frames"] > 0 |
|
and frames_to_extract > video_frame_args["max_frames"] |
|
): |
|
video_frame_args["target_frames"] = video_frame_args["max_frames"] |
|
video_frame_args["fps"] = -1 |
|
|
|
return video_frame_args |
|
|
|
def _compute_3d_positions( |
|
self, t: int, h: int, w: int, start_idx: int |
|
) -> List[List[int]]: |
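        """Build 3D (time, height, width) position ids for the merged patch grid, offset by start_idx."""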
|
|
|
t_eff = t // self.temporal_conv_size if t != 1 else 1 |
|
gh, gw = h // self.spatial_conv_size, w // self.spatial_conv_size |
|
time_idx = np.repeat(np.arange(t_eff), gh * gw) |
|
h_idx = np.tile(np.repeat(np.arange(gh), gw), t_eff) |
|
w_idx = np.tile(np.arange(gw), t_eff * gh) |
|
|
|
coords = list(zip(time_idx, h_idx, w_idx)) |
|
return [ |
|
[start_idx + ti, start_idx + hi, start_idx + wi] for ti, hi, wi in coords |
|
] |
|
|
|
def _pack_outputs(self, outs: Dict) -> Dict[str, Any]: |
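        """Stack accumulated vision features into arrays; empty vision fields become None."""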
|
|
|
if not outs["images"]: |
|
outs["images"] = None |
|
outs["grid_thw"] = None |
|
outs["image_type_ids"] = None |
|
else: |
|
outs["images"] = np.vstack(outs["images"]) |
|
outs["grid_thw"] = np.vstack(outs["grid_thw"]) |
|
outs["image_type_ids"] = np.array(outs["image_type_ids"]) |
|
|
|
|
|
outs["input_ids"] = np.array(outs["input_ids"], dtype=np.int64) |
|
outs["token_type_ids"] = np.array(outs["token_type_ids"], dtype=np.int64) |
|
outs["position_ids"] = np.array(outs["position_ids"], dtype=np.int64) |
|
return outs |
|
|
|
def batch_decode(self, *args, **kwargs): |
|
""" |
|
This method forwards all its arguments to Ernie4_5_VLTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please |
|
refer to the docstring of this method for more information. |
|
""" |
|
return self.tokenizer.batch_decode(*args, **kwargs) |
|
|
|
def decode(self, *args, **kwargs): |
|
""" |
|
This method forwards all its arguments to Ernie4_5_VLTokenizer's [`~PreTrainedTokenizer.decode`]. |
|
Please refer to the docstring of this method for more information. |
|
""" |
|
return self.tokenizer.decode(*args, **kwargs) |
|
|
|
@property |
|
def model_input_names(self): |
|
"""get model input names""" |
|
tokenizer_input_names = self.tokenizer.model_input_names |
|
image_processor_input_names = self.image_processor.model_input_names |
|
return list(tokenizer_input_names) + list(image_processor_input_names) |
|
|
|
|
|
__all__ = ["Ernie_45T_VLImageProcessor", "Ernie4_5_VLTokenizer", "Ernie_45T_VLProcessor"] |