from dataclasses import dataclass
from pathlib import Path
from random import randint
from typing import Optional, Tuple

import numpy as np
import torch
from transformers import BartTokenizerFast
@dataclass
class Preprocessor:
    encodec_base_path: Path
    clap_base_path: Path
    tokenizer: BartTokenizerFast = BartTokenizerFast.from_pretrained(
        "facebook/bart-base"
    )
    max_length: int = 1024
    mcm_masking_prob: float = 0.15
    mcm_masking_span: int = 10
    label_pad_token_id: int = -100
    mask_token_id: int = 1024
    num_eval_captions: int = 5

    def __post_init__(self):
        # Accept plain strings as well as Path / tokenizer objects.
        if isinstance(self.encodec_base_path, str):
            self.encodec_base_path = Path(self.encodec_base_path)
        if isinstance(self.clap_base_path, str):
            self.clap_base_path = Path(self.clap_base_path)
        if isinstance(self.tokenizer, str):
            self.tokenizer = BartTokenizerFast.from_pretrained(self.tokenizer)

    def preprocess_train(self, example):
        path = example["file_path"]
        encodec = np.load(self.encodec_base_path / path)
        clap_embedding = np.load(self.clap_base_path / path)

        # Mark which positions hold EnCodec frames: two leading special rows,
        # the (possibly truncated) frames, and one trailing special row.
        encodec_mask = np.array(
            [0, 0] + [1] * min(encodec.shape[0], self.max_length - 3) + [0]
        )
        attention_mask = np.ones(min(encodec.shape[0] + 3, self.max_length)).astype(
            np.int64
        )
        target_text = self.tokenizer(text_target=example["caption"])

        # Randomly crop sequences that would exceed max_length once the three
        # special rows are added.
        if encodec.shape[0] + 3 > self.max_length:
            start = randint(0, encodec.shape[0] - self.max_length + 3)
            encodec = encodec[start : start + self.max_length - 3]

        num_rvq = encodec.shape[-1]  # number of residual codebooks per frame

        # Masked codec modeling (MCM): mask random spans of frames and keep the
        # original codes as labels; unmasked positions get label_pad_token_id.
        mcm_labels = None
        if self.mcm_masking_prob > 0:
            mcm_mask, _ = _compute_mask_indices(
                encodec.T.shape, self.mcm_masking_prob, self.mcm_masking_span
            )
            mcm_mask = mcm_mask.T
            mcm_labels = np.where(mcm_mask, encodec, self.label_pad_token_id)
            mcm_labels = np.concatenate(
                [
                    np.ones((2, num_rvq), dtype=np.int64) * self.label_pad_token_id,
                    mcm_labels,
                    np.ones((1, num_rvq), dtype=np.int64) * self.label_pad_token_id,
                ],
                axis=0,
            )
            encodec[mcm_mask] = self.mask_token_id

        # Surround the EnCodec frames with BOS/EOS rows so the layout matches
        # encodec_mask and attention_mask above.
        encodec = np.concatenate(
            [
                np.ones((2, num_rvq), dtype=np.int64) * self.tokenizer.bos_token_id,
                encodec,
                np.ones((1, num_rvq), dtype=np.int64) * self.tokenizer.eos_token_id,
            ],
            axis=0,
        )

        return {
            "input_ids": encodec,
            "clap_embedding": clap_embedding,
            "encodec_mask": encodec_mask,
            "attention_mask": attention_mask,
            "mcm_labels": mcm_labels,
            "labels": target_text["input_ids"],
        }

    def preprocess_eval(self, example):
        path = example["file_path"]
        encodec = np.load(self.encodec_base_path / path)
        clap_embedding = np.load(self.clap_base_path / path)
        attention_mask = np.ones(min(encodec.shape[0] + 3, self.max_length)).astype(
            np.int64
        )
        if encodec.shape[0] + 3 > self.max_length:
            encodec = encodec[: self.max_length - 3]

        captions = []
        for i in range(self.num_eval_captions):
            captions.append(example[f"caption_{i+1}"])

        return {
            "input_ids": encodec,
            "attention_mask": attention_mask,
            "clap": clap_embedding,
            "captions": captions,
        }
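
# Usage sketch (an addition, not part of the original module): one plausible way
# to apply the preprocessor with the Hugging Face `datasets` library. The data
# files, base paths, and CSV layout below are assumptions for illustration; only
# the keys "file_path", "caption", and "caption_1".."caption_5" are implied by
# the code above.
#
#   from datasets import load_dataset
#
#   preprocessor = Preprocessor(
#       encodec_base_path="data/encodec",  # hypothetical directory of .npy files
#       clap_base_path="data/clap",        # hypothetical directory of .npy files
#   )
#   train_set = load_dataset("csv", data_files={"train": "train.csv"})["train"]
#   train_set = train_set.map(
#       preprocessor.preprocess_train,
#       remove_columns=train_set.column_names,
#   )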

def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be a tuple of size 2 where
            the first element is the batch size and the second element is the length of the axis to span.
        mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
            independently generated mask spans of length `mask_length` is computed by
            `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
            actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
            each batch dimension.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}"
        )

    # epsilon is used for probabilistic rounding
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    input_lengths = (
        attention_mask.sum(-1).detach().tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        # Nothing to mask; return the same tuple form as the final return so
        # callers can always unpack (mask, span_start_indices).
        return torch.from_numpy(spec_aug_mask), np.array(spec_aug_mask_idxs)

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index that will serve as a dummy index to pad vector
        # to ensure same dimension for all batches due to probabilistic rounding
        # Picking first sample just pads those vectors twice.
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller than
            # `sequence_length` in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [
                spec_aug_mask_idx,
                np.ones(max_num_masked_span - num_masked_span, dtype=np.int32)
                * dummy_mask_idx,
            ]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(
        batch_size, max_num_masked_span * mask_length
    )

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(
        offsets, (batch_size, max_num_masked_span, mask_length)
    ).reshape(batch_size, max_num_masked_span * mask_length)
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = (
            sequence_length - 1
        )

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return torch.from_numpy(spec_aug_mask), spec_aug_mask_idxs
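
# Minimal self-check (an addition, not in the original file): illustrates the
# span masking produced by `_compute_mask_indices` on a dummy (num_rvq, frames)
# shape. Roughly `mask_prob` of each row ends up masked, in spans of
# `mask_length` consecutive positions (less when sampled spans overlap).
if __name__ == "__main__":
    mask, span_starts = _compute_mask_indices((8, 500), mask_prob=0.15, mask_length=10)
    print(mask.shape)                 # torch.Size([8, 500])
    print(mask.float().mean(dim=-1))  # fraction masked per row, at most ~0.15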