|
--- |
|
language: |
|
- ru |
|
license: mit |
|
datasets: |
|
- misterkirill/ru-wikipedia |
|
tags: |
|
- pytorch |
|
- neural-memory |
|
- titan |
|
- text-generation |
|
--- |
|
|
|
# Neural Memory Model for Russian Text Generation |
|
|
|
This model implements a neural memory architecture for Russian text generation using PyTorch and the Titans library. The architecture is based on the implementation from [lucidrains/titans-pytorch](https://github.com/lucidrains/titans-pytorch). |
|
|
|
## Model Description |
|
|
|
The model uses a Transformer architecture enhanced with neural memory from the Titans library, which improves context handling and long-range dependency modeling in text generation. A minimal usage sketch follows the feature list below.
|
|
|
### Architecture Source |
|
|
|
The core architecture is derived from the [Titans PyTorch implementation](https://github.com/lucidrains/titans-pytorch) by Phil Wang ([@lucidrains](https://github.com/lucidrains)). The original implementation provides the following key components that we utilize: |
|
- Memory-enhanced Transformer architecture |
|
- Flexible attention mechanisms |
|
- Neural memory layers |
|
|
|
### Key Features |
|
|
|
- Neural memory architecture with customizable depth and size |
|
- Sliding window attention mechanism |
|
- Gradient accumulation for stable training |
|
- CUDA-optimized implementation |
|
|
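As a quick illustration of how these pieces fit together, here is a minimal sketch that instantiates a small memory-as-context transformer with an MLP neural memory. It reuses the constructor arguments from the training code later in this card; it is not the full training setup.

```python
# Minimal sketch: memory-as-context transformer with an MLP neural memory,
# using the same hyperparameters as the training code in this card.
import torch
from titans_pytorch import MemoryAsContextTransformer, MemoryMLP

neural_memory_model = MemoryMLP(dim=64, depth=2)

model = MemoryAsContextTransformer(
    num_tokens=50257,               # matches the config's default vocab_size
    dim=384,
    depth=8,
    segment_len=32,                 # sliding window size
    num_persist_mem_tokens=4,
    num_longterm_mem_tokens=4,
    neural_memory_layers=(2, 4, 6),
    sliding_window_attn=True,
    neural_memory_model=neural_memory_model,
    neural_memory_kwargs=dict(dim_head=64, heads=4),
)

token_ids = torch.randint(0, 50257, (1, 512))

loss = model(token_ids, return_loss=True)   # next-token prediction loss
loss.backward()

logits = model(token_ids)                   # logits for generation / inference
```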
|
## Requirements |
|
|
|
### Environment |
|
|
|
- Python: 3.9.21 |
|
- CUDA: 11.8 |
|
- GPU with at least 16GB VRAM recommended |
|
|
|
### Key Dependencies |
|
``` |
|
Python version: 3.9.21 |
|
CUDA version: 11.8 |
|
|
|
Requirements: |
|
adam-atan2-pytorch==0.1.18 |
|
datasets==3.2.0 |
|
nvidia-cuda-cupti-cu12==12.4.127 |
|
nvidia-cuda-nvrtc-cu12==12.4.127 |
|
nvidia-cuda-runtime-cu12==12.4.127 |
|
nvidia-cudnn-cu12==9.1.0.70 |
|
nvidia-cufft-cu12==11.2.1.3 |
|
nvidia-curand-cu12==10.3.5.147 |
|
nvidia-cusolver-cu12==11.6.1.9 |
|
nvidia-cusparselt-cu12==0.6.2 |
|
nvidia-nccl-cu12==2.21.5 |
|
nvidia-nvtx-cu12==12.4.127 |
|
titans-pytorch==0.3.25 |
|
torchaudio==2.5.1 |
|
torchvision==0.20.1 |
|
transformers==4.48.3 |
|
triton==3.1.0 |
|
wandb==0.19.6 |
|
``` |
|
|
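Before installing the rest of the stack, a quick optional check that a CUDA-enabled PyTorch build is visible (the training and inference code below expects a GPU):

```python
# Optional sanity check: confirm the CUDA build of PyTorch is active.
import torch

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA build: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
```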
|
## Example
|
The repository includes complete training and inference code. Key components (the sketch after this list shows how they fit together):
|
|
|
|
|
- Data preprocessing (WikiDatasetPreprocessor) |
|
- Custom dataset implementation (WikiTextDataset) |
|
- Training loop with gradient accumulation |
|
- Validation and checkpointing |
|
|
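The sketch below mirrors `main()` from the Train Code section and shows how these components connect end to end; it assumes the training script is saved as `run_train_pep8.py`, as in the fine-tuning script.

```python
# Sketch of the end-to-end training pipeline (mirrors main() in the Train Code
# section below); assumes the training script is importable as run_train_pep8.
from pathlib import Path

from run_train_pep8 import (
    WikiDatasetPreprocessor,
    create_dataloaders,
    cycle,
    create_model,
    train_model,
)

# Tokenize and chunk Russian Wikipedia articles, then cache them to disk
preprocessor = WikiDatasetPreprocessor('cache', 'processed_data')
preprocessor.process_and_save(max_articles=10000)

# Build infinite iterators over fixed-length token sequences
train_loader, val_loader = create_dataloaders(
    Path('processed_data') / 'processed_wiki.pt',
    batch_size=4,
    seq_len=512,
)
train_loader, val_loader = cycle(train_loader), cycle(val_loader)

# Train the memory-augmented transformer with gradient accumulation
model = create_model()
model = train_model(model, train_loader, val_loader, num_batches=10000)
```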
|
### Example Code
|
```python |
|
import os |
|
import warnings |
|
from pathlib import Path |
|
from typing import List, Dict, Optional, Tuple |
|
|
|
import torch |
|
from torch import nn |
|
from torch.utils.data import Dataset, DataLoader |
|
from transformers import ( |
|
GPT2TokenizerFast, |
|
PreTrainedModel, |
|
PreTrainedTokenizer, |
|
AutoConfig, |
|
AutoModelForCausalLM, |
|
AutoTokenizer, |
|
PretrainedConfig, |
|
GenerationMixin, |
|
pipeline |
|
) |
|
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions |
|
from huggingface_hub import HfApi, login |
|
from datasets import load_dataset |
|
from tqdm import tqdm |
|
from adam_atan2_pytorch import AdoptAtan2 |
|
|
|
from titans_pytorch import ( |
|
MemoryAsContextTransformer, |
|
MemoryMLP, |
|
MemoryAttention |
|
) |
|
|
|
# Suppress warnings
|
warnings.filterwarnings("ignore", category=UserWarning) |
|
torch._dynamo.config.suppress_errors = True |
|
torch._dynamo.config.cache_size_limit = 100000 |
|
torch._dynamo.config.disable = True |
|
|
|
# CUDA settings
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:32' |
|
|
|
# Constants
|
repo_id = 'Grpp/memory-transformer-ru' |
|
NUM_BATCHES = int(1e5) |
|
BATCH_SIZE = 4 |
|
GRADIENT_ACCUMULATE_EVERY = 4 |
|
LEARNING_RATE = 2e-4 |
|
VALIDATE_EVERY = 100 |
|
GENERATE_EVERY = 500 |
|
PRIME_LENGTH = 100 |
|
GENERATE_LENGTH = 512 |
|
SHOULD_GENERATE = True |
|
SEQ_LEN = 512 |
|
|
|
# Neural memory constants
|
NEURAL_MEMORY_DEPTH = 2 |
|
NUM_PERSIST_MEM = 4 |
|
NUM_LONGTERM_MEM = 4 |
|
NEURAL_MEM_LAYERS = (2, 4, 6) |
|
NEURAL_MEM_GATE_ATTN_OUTPUT = False |
|
NEURAL_MEM_MOMENTUM = True |
|
NEURAL_MEM_MOMENTUM_ORDER = 1 |
|
NEURAL_MEM_QK_NORM = True |
|
NEURAL_MEM_MAX_LR = 1e-1 |
|
USE_MEM_ATTENTION_MODEL = False |
|
WINDOW_SIZE = 32 |
|
NEURAL_MEM_SEGMENT_LEN = 4 |
|
NEURAL_MEM_BATCH_SIZE = 128 |
|
SLIDING_WINDOWS = True |
|
STORE_ATTN_POOL_CHUNKS = True |
|
MEMORY_MODEL_PER_LAYER_LEARNED_LR = True |
|
NEURAL_MEM_WEIGHT_RESIDUAL = True |
|
|
|
|
|
class MemoryTransformerConfig(PretrainedConfig): |
|
model_type = "memory_transformer" |
|
|
|
def __init__( |
|
self, |
|
vocab_size=50257, |
|
dim=384, |
|
depth=8, |
|
segment_len=32, |
|
num_persist_mem=4, |
|
num_longterm_mem=4, |
|
neural_mem_layers=(2, 4, 6), |
|
pad_token_id=0, |
|
bos_token_id=1, |
|
eos_token_id=2, |
|
**kwargs |
|
): |
|
self.vocab_size = vocab_size |
|
self.dim = dim |
|
self.depth = depth |
|
self.segment_len = segment_len |
|
self.num_persist_mem = num_persist_mem |
|
self.num_longterm_mem = num_longterm_mem |
|
self.neural_mem_layers = neural_mem_layers |
|
super().__init__( |
|
pad_token_id=pad_token_id, |
|
bos_token_id=bos_token_id, |
|
eos_token_id=eos_token_id, |
|
**kwargs |
|
) |
|
|
|
|
|
class MemoryTransformerForCausalLM(PreTrainedModel, GenerationMixin): |
|
config_class = MemoryTransformerConfig |
|
supports_gradient_checkpointing = True |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
|
|
neural_memory_model = ( |
|
MemoryAttention(dim=64) if USE_MEM_ATTENTION_MODEL |
|
else MemoryMLP(dim=64, depth=NEURAL_MEMORY_DEPTH) |
|
) |
|
|
|
self.transformer = MemoryAsContextTransformer( |
|
num_tokens=config.vocab_size, |
|
dim=config.dim, |
|
depth=config.depth, |
|
segment_len=config.segment_len, |
|
num_persist_mem_tokens=config.num_persist_mem, |
|
num_longterm_mem_tokens=config.num_longterm_mem, |
|
neural_memory_layers=config.neural_mem_layers, |
|
neural_memory_segment_len=NEURAL_MEM_SEGMENT_LEN, |
|
neural_memory_batch_size=NEURAL_MEM_BATCH_SIZE, |
|
neural_mem_gate_attn_output=NEURAL_MEM_GATE_ATTN_OUTPUT, |
|
neural_mem_weight_residual=NEURAL_MEM_WEIGHT_RESIDUAL, |
|
use_flex_attn=True, |
|
sliding_window_attn=SLIDING_WINDOWS, |
|
neural_memory_model=neural_memory_model, |
|
neural_memory_kwargs=dict( |
|
dim_head=64, |
|
heads=4, |
|
attn_pool_chunks=STORE_ATTN_POOL_CHUNKS, |
|
qk_rmsnorm=NEURAL_MEM_QK_NORM, |
|
momentum=NEURAL_MEM_MOMENTUM, |
|
momentum_order=NEURAL_MEM_MOMENTUM_ORDER, |
|
default_step_transform_max_lr=NEURAL_MEM_MAX_LR, |
|
use_accelerated_scan=True, |
|
per_parameter_lr_modulation=MEMORY_MODEL_PER_LAYER_LEARNED_LR |
|
) |
|
) |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
return_dict: Optional[bool] = None, |
|
**kwargs |
|
): |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
outputs = self.transformer(input_ids) |
|
|
|
if labels is not None: |
|
loss = self.transformer(input_ids, return_loss=True) |
|
return CausalLMOutputWithCrossAttentions( |
|
loss=loss, |
|
logits=outputs, |
|
past_key_values=None, |
|
hidden_states=None, |
|
attentions=None, |
|
cross_attentions=None |
|
) |
|
|
|
return CausalLMOutputWithCrossAttentions( |
|
loss=None, |
|
logits=outputs, |
|
past_key_values=None, |
|
hidden_states=None, |
|
attentions=None, |
|
cross_attentions=None |
|
) |
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
input_ids, |
|
past=None, |
|
attention_mask=None, |
|
**kwargs |
|
): |
|
if past: |
|
input_ids = input_ids[:, -1].unsqueeze(-1) |
|
|
|
return { |
|
"input_ids": input_ids, |
|
"past_key_values": past, |
|
"attention_mask": attention_mask, |
|
} |
|
|
|
@property |
|
def device(self): |
|
return next(self.parameters()).device |
|
|
|
|
|
def setup_custom_model(): |
|
"""Регистрация кастомной модели""" |
|
AutoConfig.register("memory_transformer", MemoryTransformerConfig) |
|
AutoModelForCausalLM.register(MemoryTransformerConfig, MemoryTransformerForCausalLM) |
|
|
|
|
|
def generate_example(model, tokenizer, text, max_length=100): |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
model = model.to(device) |
|
model.eval() |
|
|
|
input_ids = tokenizer.encode(text, return_tensors="pt").to(device) |
|
attention_mask = torch.ones_like(input_ids, device=device) |
|
|
|
print(f"Model device: {next(model.parameters()).device}") |
|
print(f"Input device: {input_ids.device}") |
|
|
|
with torch.no_grad(): |
|
outputs = model.generate( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
max_length=max_length, |
|
num_return_sequences=1, |
|
no_repeat_ngram_size=2, |
|
do_sample=True, |
|
top_k=50, |
|
top_p=0.95, |
|
temperature=0.7, |
|
pad_token_id=tokenizer.pad_token_id, |
|
eos_token_id=tokenizer.eos_token_id, |
|
) |
|
|
|
return tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
|
|
|
if __name__ == "__main__": |
|
torch.manual_seed(42) |
|
torch.cuda.manual_seed_all(42) |
|
|
|
setup_custom_model() |
|
config = AutoConfig.from_pretrained(repo_id) |
|
model = AutoModelForCausalLM.from_pretrained(repo_id) |
|
tokenizer = AutoTokenizer.from_pretrained(repo_id) |
|
|
|
test_text = "Московский кремль является" |
|
generated_text = generate_example(model, tokenizer, test_text) |
|
print(generated_text) |
|
``` |
|
|
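With the model, tokenizer, and `generate_example` helper from the script above in scope, generating continuations for several prompts is a simple loop (the prompts here are arbitrary examples):

```python
# Generate continuations for a few prompts using the helper defined above.
prompts = [
    "Московский кремль является",
    "Россия расположена",
]

for prompt in prompts:
    print(generate_example(model, tokenizer, prompt, max_length=150))
    print("-" * 40)
```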
|
|
|
### Fine-tuning Code
|
|
|
```python |
|
import os |
|
import torch |
|
from pathlib import Path |
|
from torch.utils.data import DataLoader |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig |
|
from tqdm import tqdm |
|
from adam_atan2_pytorch import AdoptAtan2 |
|
|
|
# Import classes from the training code
|
from run_train_pep8 import ( |
|
WikiDatasetPreprocessor, |
|
WikiTextDataset, |
|
create_dataloaders, |
|
cycle |
|
) # From Train Code |
|
|
|
from test_load import setup_custom_model # From Example Code |
|
|
|
# CUDA settings
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:32' |
|
|
|
# Fine-tuning constants
|
BATCH_SIZE = 2 |
|
GRADIENT_ACCUMULATE_EVERY = 2 |
|
LEARNING_RATE = 1e-5 |
|
NUM_EPOCHS = 3 |
|
STEPS_PER_EPOCH = 1000  # Number of steps per epoch
|
SEQ_LEN = 256 |
|
PROCESSED_DATA_DIR = 'processed_data' |
|
CACHE_DIR = 'cache' |
|
REPO_ID = 'Grpp/memory-transformer-ru' |
|
|
|
def finetune_model( |
|
model, |
|
train_loader, |
|
val_loader, |
|
num_epochs, |
|
device, |
|
save_path='finetuned_model' |
|
): |
|
"""Файнтьюнинг модели.""" |
|
|
|
model = model.to(device) |
|
optimizer = AdoptAtan2(model.parameters(), lr=LEARNING_RATE) |
|
|
|
best_val_loss = float('inf') |
|
|
|
for epoch in range(num_epochs): |
|
model.train() |
|
total_train_loss = 0 |
|
train_steps = 0 |
|
|
|
        # Progress bar over a fixed number of steps per epoch
|
train_pbar = tqdm(range(STEPS_PER_EPOCH), |
|
desc=f'Epoch {epoch+1}/{num_epochs} [Train]') |
|
|
|
for step in train_pbar: |
|
total_loss = 0 |
|
|
|
            # Gradient accumulation
|
for _ in range(GRADIENT_ACCUMULATE_EVERY): |
|
batch = next(train_loader) |
|
batch = batch.to(device) |
|
|
|
                # Split the batch into inputs and labels
|
inputs = batch[:, :-1] |
|
labels = batch[:, 1:] |
|
|
|
                # Forward pass
|
outputs = model(input_ids=inputs, labels=labels) |
|
loss = outputs.loss / GRADIENT_ACCUMULATE_EVERY |
|
|
|
                # Backward pass
|
loss.backward() |
|
total_loss += loss.item() |
|
|
|
            # Update parameters
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
|
optimizer.step() |
|
optimizer.zero_grad() |
|
|
|
total_train_loss += total_loss |
|
train_steps += 1 |
|
|
|
            # Update the progress bar
|
train_pbar.set_postfix({ |
|
'loss': f'{total_loss:.4f}', |
|
'avg_loss': f'{total_train_loss/train_steps:.4f}' |
|
}) |
|
|
|
            # Validate every 100 steps
|
if step % 100 == 0: |
|
model.eval() |
|
val_loss = 0 |
|
val_steps = 0 |
|
|
|
with torch.no_grad(): |
|
                    for _ in range(10):  # limit the number of validation steps
|
val_batch = next(val_loader) |
|
val_batch = val_batch.to(device) |
|
|
|
val_inputs = val_batch[:, :-1] |
|
val_labels = val_batch[:, 1:] |
|
|
|
val_outputs = model(input_ids=val_inputs, labels=val_labels) |
|
val_loss += val_outputs.loss.item() |
|
val_steps += 1 |
|
|
|
avg_val_loss = val_loss / val_steps |
|
|
|
print(f"\nValidation loss: {avg_val_loss:.4f}") |
|
|
|
                # Save the best model
|
if avg_val_loss < best_val_loss: |
|
best_val_loss = avg_val_loss |
|
torch.save({ |
|
'epoch': epoch, |
|
'model_state_dict': model.state_dict(), |
|
'optimizer_state_dict': optimizer.state_dict(), |
|
'loss': best_val_loss, |
|
}, f'{save_path}_best.pt') |
|
|
|
model.train() |
|
|
|
        # Save a checkpoint after each epoch
|
torch.save({ |
|
'epoch': epoch, |
|
'model_state_dict': model.state_dict(), |
|
'optimizer_state_dict': optimizer.state_dict(), |
|
'loss': total_train_loss / train_steps, |
|
}, f'{save_path}_epoch_{epoch}.pt') |
|
|
|
print(f"\nEpoch {epoch+1} completed. Average loss: {total_train_loss/train_steps:.4f}") |
|
|
|
return model |
|
|
|
def main(): |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
print(f"Using device: {device}") |
|
|
|
    # Load and prepare the data
|
processed_data_path = Path(PROCESSED_DATA_DIR) / 'processed_wiki.pt' |
|
|
|
if not processed_data_path.exists(): |
|
print("Processing dataset...") |
|
preprocessor = WikiDatasetPreprocessor(CACHE_DIR, PROCESSED_DATA_DIR) |
|
preprocessor.process_and_save(max_articles=10000) |
|
|
|
print("Creating dataloaders...") |
|
train_loader, val_loader = create_dataloaders( |
|
processed_data_path, |
|
batch_size=BATCH_SIZE, |
|
seq_len=SEQ_LEN |
|
) |
|
|
|
train_loader = cycle(train_loader) |
|
val_loader = cycle(val_loader) |
|
|
|
    # Load the pretrained model
|
print("Loading pretrained model...") |
|
setup_custom_model() |
|
config = AutoConfig.from_pretrained(REPO_ID) |
|
model = AutoModelForCausalLM.from_pretrained(REPO_ID) |
|
|
|
print("Starting finetuning...") |
|
    # Fine-tune the model
|
model = finetune_model( |
|
model, |
|
train_loader, |
|
val_loader, |
|
NUM_EPOCHS, |
|
device |
|
) |
|
|
|
    # Save the final model
|
print("Saving final model...") |
|
model.save_pretrained('final_finetuned_model') |
|
|
|
return model |
|
|
|
if __name__ == "__main__": |
|
torch.manual_seed(42) |
|
torch.cuda.manual_seed_all(42) |
|
torch.backends.cudnn.benchmark = True |
|
|
|
try: |
|
model = main() |
|
print("Finetuning completed successfully!") |
|
except Exception as e: |
|
print(f"An error occurred: {str(e)}") |
|
``` |
|
|
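After fine-tuning, the best checkpoint saved by `finetune_model` can be restored into the registered model class. A sketch, assuming the Example Code is saved as `test_load.py` (as in the import above) and that it defines `generate_example`:

```python
# Restore the best fine-tuned weights saved by finetune_model()
# ('finetuned_model_best.pt') and reuse them for generation.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from test_load import setup_custom_model, generate_example  # From Example Code

setup_custom_model()
model = AutoModelForCausalLM.from_pretrained('Grpp/memory-transformer-ru')
tokenizer = AutoTokenizer.from_pretrained('Grpp/memory-transformer-ru')

checkpoint = torch.load('finetuned_model_best.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Restored epoch {checkpoint['epoch']}, val loss {checkpoint['loss']:.4f}")

print(generate_example(model, tokenizer, "Московский кремль является"))
```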
|
## Training
|
|
|
The model was trained on a cleaned subset of Russian Wikipedia articles using the following parameters (the effective batch per optimizer step is worked out after the list):
|
|
|
|
|
- Batch size: 4

- Sequence length: 512

- Learning rate: 2e-4

- Gradient accumulation steps: 4

- Neural memory depth: 2

- Window size: 32
|
|
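With gradient accumulation, each optimizer step effectively processes BATCH_SIZE × GRADIENT_ACCUMULATE_EVERY sequences:

```python
# Effective batch per optimizer step with the settings above.
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
SEQ_LEN = 512

sequences_per_step = BATCH_SIZE * GRADIENT_ACCUMULATE_EVERY  # 16 sequences
tokens_per_step = sequences_per_step * SEQ_LEN               # 8192 tokens
print(sequences_per_step, tokens_per_step)
```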
|
### Train Code
|
```python |
|
import json |
|
import os |
|
import random |
|
import re |
|
from pathlib import Path |
|
from typing import List, Dict |
|
|
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
from torch.utils.data import DataLoader, Dataset |
|
from transformers import GPT2TokenizerFast |
|
from tqdm import tqdm |
|
from datasets import load_dataset |
|
from adam_atan2_pytorch import AdoptAtan2 |
|
from titans_pytorch import ( |
|
MemoryAsContextTransformer, |
|
MemoryMLP, |
|
MemoryAttention |
|
) |
|
|
|
# CUDA memory settings |
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:32' |
|
|
|
# Training constants |
|
NUM_BATCHES = int(1e5) |
|
BATCH_SIZE = 4 |
|
GRADIENT_ACCUMULATE_EVERY = 4 |
|
LEARNING_RATE = 2e-4 |
|
VALIDATE_EVERY = 100 |
|
GENERATE_EVERY = 500 |
|
PRIME_LENGTH = 100 |
|
GENERATE_LENGTH = 512 |
|
SHOULD_GENERATE = True |
|
SEQ_LEN = 512 |
|
|
|
# Neural memory constants |
|
NEURAL_MEMORY_DEPTH = 2 |
|
NUM_PERSIST_MEM = 4 |
|
NUM_LONGTERM_MEM = 4 |
|
NEURAL_MEM_LAYERS = (2, 4, 6) |
|
NEURAL_MEM_GATE_ATTN_OUTPUT = False |
|
NEURAL_MEM_MOMENTUM = True |
|
NEURAL_MEM_MOMENTUM_ORDER = 1 |
|
NEURAL_MEM_QK_NORM = True |
|
NEURAL_MEM_MAX_LR = 1e-1 |
|
USE_MEM_ATTENTION_MODEL = False |
|
WINDOW_SIZE = 32 |
|
NEURAL_MEM_SEGMENT_LEN = 4 |
|
NEURAL_MEM_BATCH_SIZE = 128 |
|
SLIDING_WINDOWS = True |
|
STORE_ATTN_POOL_CHUNKS = True |
|
MEMORY_MODEL_PER_LAYER_LEARNED_LR = True |
|
NEURAL_MEM_WEIGHT_RESIDUAL = True |
|
|
|
# Initialize tokenizer |
|
tokenizer = GPT2TokenizerFast.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2') |
|
|
|
|
|
class WikiDatasetPreprocessor: |
|
def __init__(self, cache_dir: str = 'cache', output_dir: str = 'processed_data'): |
|
self.cache_dir = Path(cache_dir) |
|
self.output_dir = Path(output_dir) |
|
self.cache_dir.mkdir(parents=True, exist_ok=True) |
|
self.output_dir.mkdir(parents=True, exist_ok=True) |
|
self.tokenizer = GPT2TokenizerFast.from_pretrained( |
|
'sberbank-ai/rugpt3small_based_on_gpt2' |
|
) |
|
|
|
def load_wiki_dataset(self): |
|
"""Загрузка датасета из Hugging Face.""" |
|
print("Loading Wikipedia dataset...") |
|
dataset = load_dataset( |
|
"misterkirill/ru-wikipedia", |
|
cache_dir=str(self.cache_dir) |
|
) |
|
print(f"Dataset loaded. Size: {len(dataset['train'])} articles") |
|
return dataset |
|
|
|
def clean_text(self, text: str) -> str: |
|
"""Базовая очистка текста.""" |
|
return ' '.join(text.split()) |
|
|
|
def process_wiki_article(self, text: str) -> List[str]: |
|
"""Обработка одной статьи из википедии.""" |
|
processed_chunks = [] |
|
clean_text = self.clean_text(text) |
|
tokens = self.tokenizer.encode(clean_text) |
|
|
|
chunk_size = 256 |
|
stride = 192 |
|
|
|
for i in range(0, len(tokens), stride): |
|
chunk = tokens[i:i + chunk_size] |
|
if len(chunk) > 50: |
|
processed_chunks.append(chunk) |
|
|
|
return processed_chunks |
|
|
|
def process_and_save( |
|
self, |
|
batch_size: int = 1000, |
|
test_size: float = 0.1, |
|
max_articles: int = 10000 |
|
): |
|
"""Обработка статей из датасета и сохранение результатов.""" |
|
dataset = self.load_wiki_dataset() |
|
total_articles = min(len(dataset['train']), max_articles) |
|
print(f"Processing {total_articles} articles out of {len(dataset['train'])}") |
|
|
|
all_chunks = [] |
|
for i in tqdm(range(0, total_articles, batch_size), desc="Processing articles"): |
|
batch = dataset['train'][i:i + batch_size] |
|
for text in batch['text']: |
|
chunks = self.process_wiki_article(text) |
|
all_chunks.extend(chunks) |
|
|
|
if len(all_chunks) > 50000: |
|
break |
|
|
|
if len(all_chunks) > 50000: |
|
break |
|
|
|
print(f"Total chunks created: {len(all_chunks)}") |
|
|
|
random.seed(42) |
|
random.shuffle(all_chunks) |
|
|
|
test_size = int(len(all_chunks) * test_size) |
|
train_chunks = all_chunks[:-test_size] |
|
test_chunks = all_chunks[-test_size:] |
|
|
|
print(f"Saving {len(train_chunks)} training chunks and {len(test_chunks)} test chunks...") |
|
torch.save( |
|
{ |
|
'train': train_chunks, |
|
'test': test_chunks |
|
}, |
|
self.output_dir / 'processed_wiki.pt' |
|
) |
|
|
|
|
|
class WikiTextDataset(Dataset): |
|
def __init__(self, chunks: List[List[int]], seq_len: int = 512): |
|
self.chunks = chunks |
|
self.seq_len = seq_len |
|
|
|
def __len__(self): |
|
return len(self.chunks) |
|
|
|
def __getitem__(self, idx): |
|
chunk = self.chunks[idx] |
|
if len(chunk) < self.seq_len + 1: |
|
chunk = chunk + [50256] * (self.seq_len + 1 - len(chunk)) |
|
else: |
|
chunk = chunk[:self.seq_len + 1] |
|
return torch.tensor(chunk, device='cuda').long() |
|
|
|
|
|
def create_dataloaders( |
|
processed_data_path: str, |
|
batch_size: int = 4, |
|
seq_len: int = 512, |
|
train_test_split: float = 0.9 |
|
) -> tuple: |
|
"""Создание загрузчиков данных для обучения и валидации.""" |
|
print(f"Loading processed data from {processed_data_path}") |
|
data = torch.load(processed_data_path) |
|
train_chunks = data['train'] |
|
test_chunks = data['test'] |
|
|
|
train_dataset = WikiTextDataset(train_chunks, seq_len) |
|
test_dataset = WikiTextDataset(test_chunks, seq_len) |
|
|
|
print(f"Created datasets with {len(train_dataset)} training and " |
|
f"{len(test_dataset)} test samples") |
|
|
|
train_loader = DataLoader( |
|
train_dataset, |
|
batch_size=batch_size, |
|
shuffle=True, |
|
num_workers=0, |
|
pin_memory=False |
|
) |
|
|
|
val_loader = DataLoader( |
|
test_dataset, |
|
batch_size=batch_size, |
|
shuffle=False, |
|
num_workers=0, |
|
pin_memory=False |
|
) |
|
|
|
return train_loader, val_loader |
|
|
|
|
|
def cycle(loader): |
|
"""Бесконечный итератор по загрузчику данных.""" |
|
while True: |
|
for data in loader: |
|
yield data |
|
|
|
|
|
def create_model(): |
|
"""Создание модели нейронной сети.""" |
|
try: |
|
if USE_MEM_ATTENTION_MODEL: |
|
neural_memory_model = MemoryAttention(dim=64) |
|
else: |
|
neural_memory_model = MemoryMLP(dim=64, depth=NEURAL_MEMORY_DEPTH) |
|
|
|
model = MemoryAsContextTransformer( |
|
num_tokens=len(tokenizer), |
|
dim=384, |
|
depth=8, |
|
segment_len=WINDOW_SIZE, |
|
num_persist_mem_tokens=NUM_PERSIST_MEM, |
|
num_longterm_mem_tokens=NUM_LONGTERM_MEM, |
|
neural_memory_layers=NEURAL_MEM_LAYERS, |
|
neural_memory_segment_len=NEURAL_MEM_SEGMENT_LEN, |
|
neural_memory_batch_size=NEURAL_MEM_BATCH_SIZE, |
|
neural_mem_gate_attn_output=NEURAL_MEM_GATE_ATTN_OUTPUT, |
|
neural_mem_weight_residual=NEURAL_MEM_WEIGHT_RESIDUAL, |
|
use_flex_attn=True, |
|
sliding_window_attn=SLIDING_WINDOWS, |
|
neural_memory_model=neural_memory_model, |
|
neural_memory_kwargs=dict( |
|
dim_head=64, |
|
heads=4, |
|
attn_pool_chunks=STORE_ATTN_POOL_CHUNKS, |
|
qk_rmsnorm=NEURAL_MEM_QK_NORM, |
|
momentum=NEURAL_MEM_MOMENTUM, |
|
momentum_order=NEURAL_MEM_MOMENTUM_ORDER, |
|
default_step_transform_max_lr=NEURAL_MEM_MAX_LR, |
|
use_accelerated_scan=True, |
|
per_parameter_lr_modulation=MEMORY_MODEL_PER_LAYER_LEARNED_LR |
|
) |
|
).cuda() |
|
|
|
assert next(model.parameters()).is_cuda, "Model is not on CUDA" |
|
return model |
|
|
|
except Exception as e: |
|
print(f"Error creating model: {e}") |
|
raise e |
|
|
|
|
|
def train_model(model, train_loader, val_loader, num_batches=int(1e4)): |
|
"""Обучение модели.""" |
|
optim = AdoptAtan2(model.parameters(), lr=2e-4) |
|
torch.cuda.empty_cache() |
|
pbar = tqdm(range(num_batches), desc='Training') |
|
running_loss = 0.0 |
|
|
|
try: |
|
for i in pbar: |
|
model.train() |
|
total_loss = 0 |
|
|
|
for __ in range(4): |
|
batch = next(train_loader) |
|
loss = model(batch, return_loss=True) |
|
loss = loss / 4 |
|
loss.backward() |
|
total_loss += loss.item() |
|
|
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) |
|
optim.step() |
|
optim.zero_grad() |
|
|
|
if i % 100 == 0: |
|
torch.cuda.empty_cache() |
|
|
|
avg_loss = total_loss |
|
running_loss = 0.9 * running_loss + 0.1 * avg_loss if i > 0 else avg_loss |
|
|
|
pbar.set_postfix({ |
|
'loss': f'{running_loss:.4f}', |
|
'batch_loss': f'{avg_loss:.4f}' |
|
}) |
|
|
|
if i % 100 == 0: |
|
model.eval() |
|
with torch.no_grad(): |
|
val_batch = next(val_loader) |
|
val_loss = model(val_batch, return_loss=True) |
|
pbar.set_postfix({ |
|
'train_loss': f'{running_loss:.4f}', |
|
'val_loss': f'{val_loss.item():.4f}' |
|
}) |
|
|
|
if i % 1000 == 0 and i > 0: |
|
torch.save({ |
|
'epoch': i, |
|
'model_state_dict': model.state_dict(), |
|
'optimizer_state_dict': optim.state_dict(), |
|
'loss': running_loss, |
|
}, f'checkpoint_{i}.pt') |
|
|
|
except KeyboardInterrupt: |
|
print("\nTraining interrupted by user") |
|
except Exception as e: |
|
print(f"\nTraining stopped due to error: {e}") |
|
raise e |
|
|
|
return model |
|
|
|
|
|
def main(): |
|
"""Основная функция программы.""" |
|
try: |
|
if not torch.cuda.is_available(): |
|
raise RuntimeError("CUDA is not available. This code requires GPU.") |
|
|
|
print(f"Using CUDA device: {torch.cuda.get_device_name(0)}") |
|
|
|
BATCH_SIZE = 4 |
|
SEQ_LEN = 512 |
|
CACHE_DIR = 'cache' |
|
PROCESSED_DATA_DIR = 'processed_data' |
|
NUM_BATCHES = 10000 |
|
|
|
preprocessor = WikiDatasetPreprocessor(CACHE_DIR, PROCESSED_DATA_DIR) |
|
processed_data_path = Path(PROCESSED_DATA_DIR) / 'processed_wiki.pt' |
|
|
|
if not processed_data_path.exists(): |
|
print("Processing Wikipedia dataset...") |
|
preprocessor.process_and_save(max_articles=10000) |
|
|
|
train_loader, val_loader = create_dataloaders( |
|
processed_data_path, |
|
batch_size=BATCH_SIZE, |
|
seq_len=SEQ_LEN |
|
) |
|
|
|
train_loader = cycle(train_loader) |
|
val_loader = cycle(val_loader) |
|
|
|
model = create_model() |
|
model = train_model(model, train_loader, val_loader, num_batches=NUM_BATCHES) |
|
|
|
torch.save(model.state_dict(), 'final_model.pt') |
|
return model, train_loader, val_loader |
|
|
|
except Exception as e: |
|
print(f"Error in main: {e}") |
|
raise e |
|
|
|
|
|
if __name__ == "__main__": |
|
torch.manual_seed(42) |
|
torch.cuda.manual_seed_all(42) |
|
torch.backends.cudnn.benchmark = True |
|
model, train_loader, val_loader = main() |
|
``` |
|
|
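Training writes intermediate checkpoints (`checkpoint_1000.pt`, `checkpoint_2000.pt`, ...) containing both model and optimizer state. A sketch of resuming from one of them, assuming the functions and constants above are in scope (the filename is just an example):

```python
# Resume from an intermediate checkpoint written by train_model().
import torch
from adam_atan2_pytorch import AdoptAtan2

model = create_model()
optim = AdoptAtan2(model.parameters(), lr=LEARNING_RATE)

checkpoint = torch.load('checkpoint_1000.pt', map_location='cuda')
model.load_state_dict(checkpoint['model_state_dict'])
optim.load_state_dict(checkpoint['optimizer_state_dict'])
print(f"Resuming from step {checkpoint['epoch']} with loss {checkpoint['loss']:.4f}")
```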
|
## License
|
|
|
This project is licensed under the MIT License. See the LICENSE file for details.
|
|
|
|
|
## Citation
|
|
|
If you use this model in your research, please cite: |
|
```bibtex |
|
@software{neural_memory_model, |
|
title = {Neural Memory Model for Russian Text Generation}, |
|
year = {2025}, |
|
url = {https://huggingface.co/Grpp/memory-transformer-ru} |
|
} |
|
``` |