import copy
import json
import pickle as pkl
import random
from dataclasses import dataclass, field
from typing import Dict, Optional

import numpy as np
import torch
import transformers
from torch.utils.data import Dataset

# Label value that PyTorch's CrossEntropyLoss skips when computing the loss.
IGNORE_INDEX = -100
# Maximum number of tokens kept on the caption side of each example.
MAX_LENGTH = 2048


@dataclass
class DataArguments:
    data_path: str = field(default='./MusicCaps', metadata={"help": "Path to the training data."})
    feat_folder: Optional[str] = field(default='./MusicCaps/music_feat')


def preprocess_v1(sources: str, tokenizer: transformers.PreTrainedTokenizer, metadata,
                  prompt_pattern="USER: <Speech><SpeechHere></Speech> Describe the music in detail.\nASSISTANT:\n") -> Dict:
    """Tokenize a newline-separated list of clip captions, prefixing each
    caption with its start/end position as a percentage of the track duration."""
    sources = sources.split('\n')
    clips, duration, caption = metadata['clips'], metadata['duration'], []
    length = 0
    for clip_len, clip_caption in zip(clips, sources):
        # Drop the clip id before the first comma and anchor the caption to
        # the clip's position in the track.
        caption.append(
            f'From {int(length / duration * 100)} to {int((length + clip_len) / duration * 100)},'
            + ','.join(clip_caption.split(',')[1:])
        )
        length += clip_len

    targets = prompt_pattern + '\n'.join(caption)

    # Split prompt and target at the <SpeechHere> placeholder; the encoded
    # music features are inserted between the two halves downstream.
    targets_left, targets_right = targets.split('<SpeechHere>')
    targets_right = tokenizer(targets_right, return_tensors="pt", add_special_tokens=False).input_ids[0]

    sources_left, sources_right = prompt_pattern.split('<SpeechHere>')
    sources_left = tokenizer(sources_left, return_tensors="pt", add_special_tokens=False).input_ids[0]
    sources_right_length = tokenizer(sources_right, return_tensors="pt", add_special_tokens=False).input_ids.shape[-1]

    # The model input repeats the caption tokens; only the label copy is masked.
    sources_right = targets_right.clone()

    # Mask the prompt so the loss is computed on the caption tokens only.
    targets_left = torch.full_like(sources_left, IGNORE_INDEX)
    targets_right[:sources_right_length] = IGNORE_INDEX

    sources_right, targets_right = sources_right[:MAX_LENGTH], targets_right[:MAX_LENGTH]

    return dict(input_ids=(sources_left, sources_right), labels=(targets_left, targets_right))
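
# Illustrative sketch (not part of the pipeline): with
# metadata = {'clips': [5.0, 5.0], 'duration': 10.0} and
# sources = "id0,piano intro\nid1,drums enter", preprocess_v1 builds the
# caption lines
#   "From 0 to 50,piano intro"
#   "From 50 to 100,drums enter"
# i.e. clip boundaries expressed as percentages of the full track.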

def preprocess(sources: str, tokenizer: transformers.PreTrainedTokenizer, metadata,
               prompt_pattern="USER: <Speech><SpeechHere></Speech> Describe the music in detail.\nASSISTANT:\n") -> Dict:
    """Tokenize a single caption string; `metadata` is kept for signature
    compatibility with preprocess_v1 but is unused here."""
    targets = prompt_pattern + sources

    # Split prompt and target at the <SpeechHere> placeholder, as above.
    targets_left, targets_right = targets.split('<SpeechHere>')
    targets_right = tokenizer(targets_right, return_tensors="pt", add_special_tokens=False).input_ids[0]

    sources_left, sources_right = prompt_pattern.split('<SpeechHere>')
    sources_left = tokenizer(sources_left, return_tensors="pt", add_special_tokens=False).input_ids[0]
    sources_right_length = tokenizer(sources_right, return_tensors="pt", add_special_tokens=False).input_ids.shape[-1]

    # The model input repeats the caption tokens; only the label copy is masked.
    sources_right = targets_right.clone()

    # Mask the prompt so the loss is computed on the caption tokens only.
    targets_left = torch.full_like(sources_left, IGNORE_INDEX)
    targets_right[:sources_right_length] = IGNORE_INDEX

    sources_right, targets_right = sources_right[:MAX_LENGTH], targets_right[:MAX_LENGTH]

    return dict(input_ids=(sources_left, sources_right), labels=(targets_left, targets_right))
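
# The (left, right) halves returned above are meant to be joined around the
# encoded music features by the model. A minimal sketch of that reassembly,
# assuming a hypothetical `embed_tokens` embedding layer and a precomputed
# `speech_embeds` tensor of shape [num_frames, hidden]:
#
#   inputs_embeds = torch.cat(
#       [embed_tokens(sources_left), speech_embeds, embed_tokens(sources_right)],
#       dim=0,
#   )
#
# with matching labels targets_left, then IGNORE_INDEX for every speech
# frame, then targets_right.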

class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning; features are loaded lazily from
    per-example pickles."""

    def __init__(self, data_path, tokenizer, data_args):
        super().__init__()
        self.tokenizer = tokenizer
        with open(data_path, "r") as f:
            self.list_data_dict = json.load(f)
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, i):
        source = copy.deepcopy(self.list_data_dict[i])

        # Each feature pickle holds a dict of numpy arrays under the
        # 'speech' and 'audio' keys.
        feature_path = '{}/{}.pkl'.format(self.data_args.feat_folder, source['id'])
        with open(feature_path, 'rb') as f:
            music = pkl.load(f)
        speech = torch.from_numpy(music['speech'])
        audio = torch.from_numpy(music['audio'])

        # If several captions are available, sample one with probability
        # proportional to its character length.
        captions = source['caption']
        if not isinstance(captions, str):
            weights = np.asarray([len(c) for c in captions])
            weights = weights / weights.sum()
            captions = random.choices(captions, weights=weights, k=1)[0]

        data_dict = preprocess(captions, self.tokenizer, source['meta'])
        data_dict['speeches'] = speech
        data_dict['audios'] = audio
        return data_dict
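
    # Illustrative example (assuming the pickle layout described above):
    #   item = dataset[0]
    #   item['input_ids']  -> (left_ids, right_ids) LongTensor halves
    #   item['labels']     -> matching halves with the prompt masked to -100
    #   item['speeches']   -> float tensor of frame-level speech features
    #   item['audios']     -> float tensor of frame-level audio features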

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances):
        # Regroup per-example fields into per-field lists; no padding is
        # applied here, so downstream code receives ragged sequences.
        input_ids, labels, speeches, audios = tuple(
            [instance[key] for instance in instances]
            for key in ("input_ids", "labels", "speeches", "audios"))
        batch = dict(input_ids=input_ids, labels=labels, speeches=speeches, audios=audios)
        return batch
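
    # Minimal usage sketch (illustrative): each field stays a plain Python
    # list, so the model's forward pass is expected to pad or assemble the
    # variable-length sequences itself.
    #   collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    #   batch = collator([dataset[0], dataset[1]])
    #   batch['input_ids']  # list of (left, right) tensor pairs, one per example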

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
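
# Hedged usage sketch: `model`, `training_args`, and the tokenizer checkpoint
# path are assumptions, not defined in this file.
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained("path/to/llm")
#   data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=DataArguments())
#   trainer = transformers.Trainer(model=model, args=training_args, **data_module)
#   trainer.train()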