from typing import Optional, Iterator, Callable, Any

import torch

from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer

def load_text_dataset(tokenizer: AutoTokenizer,
                      kind: str,
                      path: str,
                      name: Optional[str] = None,
                      data_dir: Optional[str] = None,
                      data_files: Optional[str] = None,
                      keep_in_memory: bool = False,
                      revision: Optional[str] = None,
                      split: str = 'train',
                      num_proc: Optional[int] = None,
                      format: Optional[Callable | str] = None) -> Any:
    """Load a plain-text dataset and render every row into a single 'text'
    column ending in the tokenizer's EOS token.

    Despite its None default, `format` is required: either a callable mapping
    a row dict to a string, or a str.format template expanded with the row's
    columns.
    """
    assert isinstance(format, str) or callable(format), f'{path=} {format=}'
    assert kind == 'base'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    EOS_TOKEN = tokenizer.eos_token

    def format_dataset(batch):
        # `batch` arrives columnar ({column: [values, ...]}); pivot it into a
        # list of per-row dicts so `format` sees one example at a time.
        texts: list = []
        rows = [dict(zip(batch.keys(), values)) for values in zip(*batch.values())]

        for row in rows:
            if callable(format):
                text = format(row)
            else:
                text = format.format(**row)

            # Substitute a placeholder for empty or None renderings so every
            # example remains a non-empty string.
            if not text:
                text = '[NONE]'

            texts.append(text + EOS_TOKEN)

        return {'text': texts}

    dataset = dataset.map(format_dataset, batched=True)
    return dataset
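
# A hypothetical usage sketch for load_text_dataset; the 'gpt2' checkpoint,
# the 'wikitext' dataset, and its 'text' column are illustrative assumptions,
# not requirements of this module:
#
#   tokenizer = AutoTokenizer.from_pretrained('gpt2')
#   dataset = load_text_dataset(tokenizer,
#                               kind='base',
#                               path='wikitext',
#                               name='wikitext-2-raw-v1',
#                               format='{text}')
#   dataset[0]['text']  # row text with tokenizer.eos_token appended
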
def load_chat_dataset(tokenizer: AutoTokenizer,
                      kind: str,
                      path: str,
                      name: Optional[str] = None,
                      data_dir: Optional[str] = None,
                      data_files: Optional[str] = None,
                      keep_in_memory: bool = False,
                      revision: Optional[str] = None,
                      split: str = 'train',
                      num_proc: Optional[int] = None,
                      field: Optional[str] = None,
                      transform: Optional[Callable] = None) -> Any:
    """Load a chat dataset and render every conversation into a single 'text'
    column via the tokenizer's chat template, ending in its EOS token.

    `field` names the column holding the message list; `transform` optionally
    maps a row (or `row[field]` when `field` is set) into such a list. At
    least one of the two must be provided.
    """
    assert kind == 'instruct'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    EOS_TOKEN = tokenizer.eos_token

    def format_dataset(batch):
        # Pivot the columnar batch into per-row dicts, as in load_text_dataset.
        texts: list = []
        rows = [dict(zip(batch.keys(), values)) for values in zip(*batch.values())]

        for row in rows:
            if callable(transform):
                # A transform receives either the selected column or the whole
                # row and must return a chat-template-compatible message list.
                messages = transform(row[field]) if field else transform(row)
            elif field:
                messages = row[field]
            else:
                raise ValueError('either `field` or `transform` must be provided')

            text = tokenizer.apply_chat_template(messages, tokenize=False)
            texts.append(text + EOS_TOKEN)

        return {'text': texts}

    dataset = dataset.map(format_dataset, batched=True)
    return dataset
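
# A hypothetical usage sketch for load_chat_dataset; the checkpoint, dataset
# path, split name, and 'messages' field are illustrative assumptions:
#
#   tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
#   dataset = load_chat_dataset(tokenizer,
#                               kind='instruct',
#                               path='HuggingFaceH4/ultrachat_200k',
#                               split='train_sft',
#                               field='messages')
#   dataset[0]['text']  # conversation rendered via the chat template + EOS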