import gc
from typing import Optional, Iterator, Callable

import torch
from datasets import load_dataset
from litgpt.tokenizer import Tokenizer
from transformers import AutoTokenizer


def batch_text_iterator(kind: str,
                        path: str,
                        name: Optional[str]=None,
                        data_dir: Optional[str]=None,
                        data_files: Optional[str]=None,
                        keep_in_memory: bool=False,
                        revision: Optional[str]=None,
                        split: str='train',
                        num_proc: Optional[int]=None,
                        format: Optional[Callable | str]=None) -> Iterator[str]:
    # `format` is either a callable mapping a dataset row to text,
    # or a str.format template filled with the row's fields.
    assert isinstance(format, str) or callable(format), f'{path=} {format=}'
    assert kind == 'base'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if callable(format):
        for row in dataset:
            text = format(row)

            if not text:
                continue

            yield text
    else:
        for row in dataset:
            text = format.format(**row)

            if not text:
                continue

            yield text

    del dataset
    gc.collect()


def batch_chat_iterator(kind: str,
                        path: str,
                        name: Optional[str]=None,
                        data_dir: Optional[str]=None,
                        data_files: Optional[str]=None,
                        keep_in_memory: bool=False,
                        revision: Optional[str]=None,
                        split: str='train',
                        num_proc: Optional[int]=None,
                        field: Optional[str]=None,
                        transform: Optional[Callable]=None) -> Iterator[list[dict[str, str]]]:
    assert kind == 'instruct'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if callable(transform):
        for row in dataset:
            # `transform` maps a row (or the given field of a row) to a
            # list of {'role': ..., 'content': ...} messages.
            if field:
                messages = transform(row[field])
            else:
                messages = transform(row)

            if not messages:
                continue

            yield messages
    else:
        for row in dataset:
            if field:
                messages = row[field]
            else:
                # Without a transform, a field name is required to locate the messages.
                raise ValueError(field)

            if not messages:
                continue

            yield messages

    del dataset
    gc.collect()


# NOTE: used only by tokenizer trainer
def batch_dataset_iterator(dataset_config: dict) -> Iterator[str]:
    if dataset_config['kind'] == 'base':
        for text in batch_text_iterator(**dataset_config):
            yield text
    elif dataset_config['kind'] == 'instruct':
        for messages in batch_chat_iterator(**dataset_config):
            # Flatten the chat into plain text for tokenizer training.
            text = '\n'.join(n['content'] for n in messages)
            yield text
    else:
        raise ValueError(dataset_config['kind'])


def tokenize_text_fn(dataset_config: dict, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    for text in batch_text_iterator(**dataset_config):
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)
        yield text_ids


def tokenize_chat_fn(dataset_config: dict, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    for messages in batch_chat_iterator(**dataset_config):
        # The chat template already injects special tokens, so skip bos/eos here.
        text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)
        yield text_ids


def tokenize_fn(dataset_config: dict, min_len: int, max_len: int, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    if dataset_config['kind'] == 'base':
        for text in batch_text_iterator(**dataset_config):
            try:
                text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)
            except Exception as e:
                print(f'Skip base row: {e=} {type(text)=} {text=}')
                continue

            # Keep only sequences within the configured length bounds.
            if min_len <= len(text_ids) <= max_len:
                yield text_ids
    elif dataset_config['kind'] == 'instruct':
        for messages in batch_chat_iterator(**dataset_config):
            try:
                text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
                text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)
            except Exception as e:
                print(f'Skip instruct row: {e=} {type(messages)=} {messages=}')
                continue

            if min_len <= len(text_ids) <= max_len:
                yield text_ids
    else:
        raise ValueError(dataset_config['kind'])
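

# Example usage: a minimal sketch, not part of the module. The tokenizer
# checkpoints, dataset, and length bounds below are hypothetical assumptions
# chosen only to illustrate how `tokenize_fn` consumes a dataset config.
if __name__ == '__main__':
    # Hypothetical HF tokenizer checkpoint and litgpt checkpoint directory.
    hf_tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
    tokenizer = Tokenizer('checkpoints/my-model')

    dataset_config = {
        'kind': 'base',
        'path': 'wikitext',           # hypothetical dataset
        'name': 'wikitext-2-raw-v1',
        'split': 'train',
        'format': '{text}',           # str.format template over row fields
    }

    for text_ids in tokenize_fn(dataset_config, min_len=4, max_len=2048,
                                hf_tokenizer=hf_tokenizer, tokenizer=tokenizer):
        print(len(text_ids))
        break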