---
language:
- ru
license: mit
datasets:
- misterkirill/ru-wikipedia
tags:
- pytorch
- neural-memory
- titan
- text-generation
---

# Neural Memory Model for Russian Text Generation

This model implements a neural memory architecture for Russian text generation using PyTorch and the Titans library. The architecture is based on the implementation from [lucidrains/titans-pytorch](https://github.com/lucidrains/titans-pytorch).

## Model Description

The model uses a Transformer architecture enhanced with neural memory capabilities from the Titans library for improved handling of context and long-range dependencies in text generation.

### Architecture Source

The core architecture is derived from the [Titans PyTorch implementation](https://github.com/lucidrains/titans-pytorch) by Phil Wang ([@lucidrains](https://github.com/lucidrains)). The original implementation provides the key components used here:
- Memory-enhanced Transformer architecture
- Flexible attention mechanisms
- Neural memory layers

### Key Features

The short sketch after this list shows how these features map onto titans-pytorch parameters.

- Neural memory architecture with customizable depth and size
- Sliding window attention mechanism
- Gradient accumulation for stable training
- CUDA-optimized implementation
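
As a quick orientation, here is a minimal sketch of how these features are wired together with titans-pytorch. The dimensions and layer choices are illustrative; the full configuration actually used for this model appears in the example script below.

```python
# Minimal sketch (illustrative values): a neural memory module wired into a
# memory-as-context transformer, mirroring the calls in the full script below.
from titans_pytorch import MemoryAsContextTransformer, MemoryMLP

# Neural memory with configurable depth and size
neural_memory = MemoryMLP(dim=64, depth=2)

model = MemoryAsContextTransformer(
    num_tokens=50257,                   # use len(tokenizer) in practice
    dim=384,
    depth=8,
    segment_len=32,                     # attention window size
    sliding_window_attn=True,           # sliding window attention
    neural_memory_layers=(2, 4, 6),     # layers that receive neural memory
    neural_memory_model=neural_memory,
    neural_memory_kwargs=dict(
        dim_head=64,                    # matches the memory module's dim
        heads=4,
    ),
)
```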

## Requirements

### Environment

- Python: 3.9.21
- CUDA: 11.8
- GPU with at least 16 GB of VRAM recommended

### Key Dependencies
```
adam-atan2-pytorch==0.1.18
datasets==3.2.0
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvtx-cu12==12.4.127
titans-pytorch==0.3.25
torchaudio==2.5.1
torchvision==0.20.1
transformers==4.48.3
triton==3.1.0
wandb==0.19.6
```
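
Before running the scripts below, it can help to confirm that the environment roughly matches these versions and that a GPU is visible. This is a small, hypothetical sanity check, not part of the original code:

```python
# Quick environment sanity check (hypothetical helper, not part of the original repo).
import torch
import transformers

print("PyTorch:", torch.__version__)
print("CUDA (built into PyTorch):", torch.version.cuda)
print("Transformers:", transformers.__version__)

# The training code below assumes a CUDA-capable GPU (16 GB+ VRAM recommended).
assert torch.cuda.is_available(), "CUDA is not available; this model requires a GPU."
print("GPU:", torch.cuda.get_device_name(0))
```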

# Example

The example below contains the complete data preparation and training code; a hedged inference sketch follows the script. Key components:

- Data preprocessing (`WikiDatasetPreprocessor`)
- Custom dataset implementation (`WikiTextDataset`)
- Training loop with gradient accumulation
- Validation and checkpointing

## Example Code
```python
import os
import re
import json
import random
from pathlib import Path
from typing import List

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from datasets import load_dataset
from transformers import GPT2TokenizerFast
from adam_atan2_pytorch import AdoptAtan2

from titans_pytorch import (
    MemoryAsContextTransformer,
    MemoryMLP,
    MemoryAttention
)

# CUDA memory allocator settings
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:32'


# Constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
PRIME_LENGTH = 100
GENERATE_LENGTH = 512
SHOULD_GENERATE = True
SEQ_LEN = 512

# Neural memory constants
NEURAL_MEMORY_DEPTH = 2
NUM_PERSIST_MEM = 4
NUM_LONGTERM_MEM = 4
NEURAL_MEM_LAYERS = (2, 4, 6)
NEURAL_MEM_GATE_ATTN_OUTPUT = False
NEURAL_MEM_MOMENTUM = True
NEURAL_MEM_MOMENTUM_ORDER = 1
NEURAL_MEM_QK_NORM = True
NEURAL_MEM_MAX_LR = 1e-1
USE_MEM_ATTENTION_MODEL = False
WINDOW_SIZE = 32
NEURAL_MEM_SEGMENT_LEN = 4
NEURAL_MEM_BATCH_SIZE = 128
SLIDING_WINDOWS = True
STORE_ATTN_POOL_CHUNKS = True
MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
NEURAL_MEM_WEIGHT_RESIDUAL = True

# Initialize the tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')


class WikiDatasetPreprocessor:
    def __init__(self, cache_dir: str = 'cache', output_dir: str = 'processed_data'):
        self.cache_dir = Path(cache_dir)
        self.output_dir = Path(output_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Tokenizer used for chunking the articles
        self.tokenizer = GPT2TokenizerFast.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')

    def load_wiki_dataset(self):
        """Load the dataset from Hugging Face."""
        print("Loading Wikipedia dataset...")
        dataset = load_dataset("misterkirill/ru-wikipedia", cache_dir=str(self.cache_dir))
        print(f"Dataset loaded. Size: {len(dataset['train'])} articles")
        return dataset

    def clean_text(self, text: str) -> str:
        """Basic text cleanup: collapse repeated whitespace and line breaks."""
        text = ' '.join(text.split())
        return text

    def process_wiki_article(self, text: str) -> List[List[int]]:
        """Tokenize a single Wikipedia article and split it into overlapping chunks."""
        processed_chunks = []

        clean_text = self.clean_text(text)
        tokens = self.tokenizer.encode(clean_text)

        # Smaller chunks to reduce memory use (originally 512 / 384)
        chunk_size = 256
        stride = 192

        for i in range(0, len(tokens), stride):
            chunk = tokens[i:i + chunk_size]
            if len(chunk) > 50:  # skip very short chunks
                processed_chunks.append(chunk)

        return processed_chunks

    def process_and_save(self, batch_size: int = 1000, test_size: float = 0.1, max_articles: int = 10000):
        """Process a limited number of articles and save the resulting chunks."""
        dataset = self.load_wiki_dataset()

        # Cap the dataset size
        total_articles = min(len(dataset['train']), max_articles)
        print(f"Processing {total_articles} articles out of {len(dataset['train'])}")

        # Collect all chunks first
        all_chunks = []

        for i in tqdm(range(0, total_articles, batch_size), desc="Processing articles"):
            batch = dataset['train'][i:i + batch_size]
            for text in batch['text']:
                chunks = self.process_wiki_article(text)
                all_chunks.extend(chunks)

                # Cap the number of chunks to speed up training
                if len(all_chunks) > 50000:
                    break

            if len(all_chunks) > 50000:
                break

        print(f"Total chunks created: {len(all_chunks)}")

        # Shuffle the chunks
        random.seed(42)
        random.shuffle(all_chunks)

        # Split into train and test
        test_size = int(len(all_chunks) * test_size)
        train_chunks = all_chunks[:-test_size]
        test_chunks = all_chunks[-test_size:]

        print(f"Saving {len(train_chunks)} training chunks and {len(test_chunks)} test chunks...")
        torch.save({
            'train': train_chunks,
            'test': test_chunks
        }, self.output_dir / 'processed_wiki.pt')


class WikiTextDataset(Dataset):
    def __init__(self, chunks: List[List[int]], seq_len: int = 512):
        self.chunks = chunks
        self.seq_len = seq_len

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        chunk = self.chunks[idx]

        # Pad short chunks with the EOS token id (50256)...
        if len(chunk) < self.seq_len + 1:
            chunk = chunk + [50256] * (self.seq_len + 1 - len(chunk))
        # ...and truncate long ones
        else:
            chunk = chunk[:self.seq_len + 1]

        # Create the tensor directly on the GPU
        return torch.tensor(chunk, device='cuda').long()


def create_dataloaders(
    processed_data_path: str,
    batch_size: int = 4,
    seq_len: int = 512,
    train_test_split: float = 0.9
) -> tuple:
    """Create the training and validation data loaders."""

    print(f"Loading processed data from {processed_data_path}")
    data = torch.load(processed_data_path)
    train_chunks = data['train']
    test_chunks = data['test']

    # Build the datasets
    train_dataset = WikiTextDataset(train_chunks, seq_len)
    test_dataset = WikiTextDataset(test_chunks, seq_len)

    print(f"Created datasets with {len(train_dataset)} training and {len(test_dataset)} test samples")

    # Build the data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,    # single-process loading for easier debugging
        pin_memory=False  # the tensors are already created on the GPU
    )

    val_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=False
    )

    return train_loader, val_loader


def cycle(loader):
    """Infinite iterator over a data loader."""
    while True:
        for data in loader:
            yield data


def create_model():
    try:
        if USE_MEM_ATTENTION_MODEL:
            neural_memory_model = MemoryAttention(dim=64)
        else:
            neural_memory_model = MemoryMLP(dim=64, depth=NEURAL_MEMORY_DEPTH)

        model = MemoryAsContextTransformer(
            num_tokens=len(tokenizer),
            dim=384,
            depth=8,
            segment_len=WINDOW_SIZE,
            num_persist_mem_tokens=NUM_PERSIST_MEM,
            num_longterm_mem_tokens=NUM_LONGTERM_MEM,
            neural_memory_layers=NEURAL_MEM_LAYERS,
            neural_memory_segment_len=NEURAL_MEM_SEGMENT_LEN,
            neural_memory_batch_size=NEURAL_MEM_BATCH_SIZE,
            neural_mem_gate_attn_output=NEURAL_MEM_GATE_ATTN_OUTPUT,
            neural_mem_weight_residual=NEURAL_MEM_WEIGHT_RESIDUAL,
            use_flex_attn=True,
            sliding_window_attn=SLIDING_WINDOWS,
            neural_memory_model=neural_memory_model,
            neural_memory_kwargs=dict(
                dim_head=64,
                heads=4,
                attn_pool_chunks=STORE_ATTN_POOL_CHUNKS,
                qk_rmsnorm=NEURAL_MEM_QK_NORM,
                momentum=NEURAL_MEM_MOMENTUM,
                momentum_order=NEURAL_MEM_MOMENTUM_ORDER,
                default_step_transform_max_lr=NEURAL_MEM_MAX_LR,
                use_accelerated_scan=True,
                per_parameter_lr_modulation=MEMORY_MODEL_PER_LAYER_LEARNED_LR
            )
        ).cuda()

        # Make sure the model actually ended up on the GPU
        assert next(model.parameters()).is_cuda, "Model is not on CUDA"

        return model

    except Exception as e:
        print(f"Error creating model: {e}")
        raise e


def train_model(model, train_loader, val_loader, num_batches=int(1e4)):
    optim = AdoptAtan2(model.parameters(), lr=LEARNING_RATE)

    # Clear the CUDA cache before training
    torch.cuda.empty_cache()

    pbar = tqdm(range(num_batches), desc='Training')
    running_loss = 0.0

    try:
        for i in pbar:
            model.train()

            total_loss = 0
            # Training step with gradient accumulation
            for __ in range(GRADIENT_ACCUMULATE_EVERY):
                batch = next(train_loader)
                loss = model(batch, return_loss=True)
                loss = loss / GRADIENT_ACCUMULATE_EVERY  # normalize the loss across accumulation steps
                loss.backward()
                total_loss += loss.item()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optim.step()
            optim.zero_grad()

            # Clear the CUDA cache every 100 iterations
            if i % 100 == 0:
                torch.cuda.empty_cache()

            avg_loss = total_loss
            running_loss = 0.9 * running_loss + 0.1 * avg_loss if i > 0 else avg_loss

            pbar.set_postfix({
                'loss': f'{running_loss:.4f}',
                'batch_loss': f'{avg_loss:.4f}'
            })

            # Validation
            if i % VALIDATE_EVERY == 0:
                model.eval()
                with torch.no_grad():
                    val_batch = next(val_loader)
                    val_loss = model(val_batch, return_loss=True)
                    pbar.set_postfix({
                        'train_loss': f'{running_loss:.4f}',
                        'val_loss': f'{val_loss.item():.4f}'
                    })

            # Checkpointing
            if i % 1000 == 0 and i > 0:
                torch.save({
                    'epoch': i,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optim.state_dict(),
                    'loss': running_loss,
                }, f'checkpoint_{i}.pt')

    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
    except Exception as e:
        print(f"\nTraining stopped due to error: {e}")
        raise e

    return model


def main():
    try:
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available. This code requires a GPU.")

        print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")

        # Parameters
        BATCH_SIZE = 4
        SEQ_LEN = 512
        CACHE_DIR = 'cache'
        PROCESSED_DATA_DIR = 'processed_data'
        NUM_BATCHES = 10000  # reduced number of iterations

        # Data preparation
        preprocessor = WikiDatasetPreprocessor(CACHE_DIR, PROCESSED_DATA_DIR)

        processed_data_path = Path(PROCESSED_DATA_DIR) / 'processed_wiki.pt'
        if not processed_data_path.exists():
            print("Processing Wikipedia dataset...")
            preprocessor.process_and_save(max_articles=10000)  # cap the number of articles

        # Data loaders
        train_loader, val_loader = create_dataloaders(
            processed_data_path,
            batch_size=BATCH_SIZE,
            seq_len=SEQ_LEN
        )

        # Infinite iterators over the loaders
        train_loader = cycle(train_loader)
        val_loader = cycle(val_loader)

        # Build and train the model
        model = create_model()
        model = train_model(model, train_loader, val_loader, num_batches=NUM_BATCHES)

        # Save the final model
        torch.save(model.state_dict(), 'final_model.pt')

        return model, train_loader, val_loader

    except Exception as e:
        print(f"Error in main: {e}")
        raise e


if __name__ == "__main__":
    # Seed for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    # Enable CUDA optimizations
    torch.backends.cudnn.benchmark = True

    model, train_loader, val_loader = main()
```
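
The script above covers preprocessing and training; it does not show generation. Below is a hedged inference sketch: it assumes the final weights were saved to `final_model.pt` by `main()` above, and that `MemoryAsContextTransformer` exposes a `sample(prompt, seq_len)` method as in the upstream titans-pytorch training script. Adapt the call if your installed version differs.

```python
# Hedged inference sketch: load the trained weights and generate a continuation.
# Assumes model.sample(prompt, seq_len) exists, as in the upstream titans-pytorch examples.
import torch
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')

model = create_model()  # same builder as in the script above
model.load_state_dict(torch.load('final_model.pt', map_location='cuda'))
model.eval()

prompt = "История России начинается"  # any Russian prompt
prompt_ids = torch.tensor(tokenizer.encode(prompt), device='cuda').long()[None, :]

with torch.no_grad():
    generated = model.sample(prompt_ids, GENERATE_LENGTH)  # GENERATE_LENGTH = 512 above

print(tokenizer.decode(generated[0].tolist()))
```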

# Training

The model was trained on a cleaned subset of Russian Wikipedia articles with the following parameters:

- Batch size: 4
- Sequence length: 512
- Learning rate: 2e-4
- Gradient accumulation steps: 4
- Neural memory depth: 2
- Window size: 32

With 4 gradient accumulation steps, each optimizer step sees an effective batch of 16 sequences (16 × 512 = 8,192 tokens).

## Train Code
```python
import json
import os
import random
import re
from pathlib import Path
from typing import List, Dict

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2TokenizerFast
from tqdm import tqdm
from datasets import load_dataset
from adam_atan2_pytorch import AdoptAtan2
from titans_pytorch import (
    MemoryAsContextTransformer,
    MemoryMLP,
    MemoryAttention
)

# CUDA memory settings
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:32'

# Training constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
PRIME_LENGTH = 100
GENERATE_LENGTH = 512
SHOULD_GENERATE = True
SEQ_LEN = 512

# Neural memory constants
NEURAL_MEMORY_DEPTH = 2
NUM_PERSIST_MEM = 4
NUM_LONGTERM_MEM = 4
NEURAL_MEM_LAYERS = (2, 4, 6)
NEURAL_MEM_GATE_ATTN_OUTPUT = False
NEURAL_MEM_MOMENTUM = True
NEURAL_MEM_MOMENTUM_ORDER = 1
NEURAL_MEM_QK_NORM = True
NEURAL_MEM_MAX_LR = 1e-1
USE_MEM_ATTENTION_MODEL = False
WINDOW_SIZE = 32
NEURAL_MEM_SEGMENT_LEN = 4
NEURAL_MEM_BATCH_SIZE = 128
SLIDING_WINDOWS = True
STORE_ATTN_POOL_CHUNKS = True
MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
NEURAL_MEM_WEIGHT_RESIDUAL = True

# Initialize tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')


class WikiDatasetPreprocessor:
    def __init__(self, cache_dir: str = 'cache', output_dir: str = 'processed_data'):
        self.cache_dir = Path(cache_dir)
        self.output_dir = Path(output_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(
            'sberbank-ai/rugpt3small_based_on_gpt2'
        )

    def load_wiki_dataset(self):
        """Load the dataset from Hugging Face."""
        print("Loading Wikipedia dataset...")
        dataset = load_dataset(
            "misterkirill/ru-wikipedia",
            cache_dir=str(self.cache_dir)
        )
        print(f"Dataset loaded. Size: {len(dataset['train'])} articles")
        return dataset

    def clean_text(self, text: str) -> str:
        """Basic text cleanup."""
        return ' '.join(text.split())

    def process_wiki_article(self, text: str) -> List[List[int]]:
        """Process a single Wikipedia article into overlapping token chunks."""
        processed_chunks = []
        clean_text = self.clean_text(text)
        tokens = self.tokenizer.encode(clean_text)

        chunk_size = 256
        stride = 192

        for i in range(0, len(tokens), stride):
            chunk = tokens[i:i + chunk_size]
            if len(chunk) > 50:
                processed_chunks.append(chunk)

        return processed_chunks

    def process_and_save(
        self,
        batch_size: int = 1000,
        test_size: float = 0.1,
        max_articles: int = 10000
    ):
        """Process articles from the dataset and save the results."""
        dataset = self.load_wiki_dataset()
        total_articles = min(len(dataset['train']), max_articles)
        print(f"Processing {total_articles} articles out of {len(dataset['train'])}")

        all_chunks = []
        for i in tqdm(range(0, total_articles, batch_size), desc="Processing articles"):
            batch = dataset['train'][i:i + batch_size]
            for text in batch['text']:
                chunks = self.process_wiki_article(text)
                all_chunks.extend(chunks)

                if len(all_chunks) > 50000:
                    break

            if len(all_chunks) > 50000:
                break

        print(f"Total chunks created: {len(all_chunks)}")

        random.seed(42)
        random.shuffle(all_chunks)

        test_size = int(len(all_chunks) * test_size)
        train_chunks = all_chunks[:-test_size]
        test_chunks = all_chunks[-test_size:]

        print(f"Saving {len(train_chunks)} training chunks and {len(test_chunks)} test chunks...")
        torch.save(
            {
                'train': train_chunks,
                'test': test_chunks
            },
            self.output_dir / 'processed_wiki.pt'
        )


class WikiTextDataset(Dataset):
    def __init__(self, chunks: List[List[int]], seq_len: int = 512):
        self.chunks = chunks
        self.seq_len = seq_len

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        chunk = self.chunks[idx]
        if len(chunk) < self.seq_len + 1:
            chunk = chunk + [50256] * (self.seq_len + 1 - len(chunk))
        else:
            chunk = chunk[:self.seq_len + 1]
        return torch.tensor(chunk, device='cuda').long()


def create_dataloaders(
    processed_data_path: str,
    batch_size: int = 4,
    seq_len: int = 512,
    train_test_split: float = 0.9
) -> tuple:
    """Create the training and validation data loaders."""
    print(f"Loading processed data from {processed_data_path}")
    data = torch.load(processed_data_path)
    train_chunks = data['train']
    test_chunks = data['test']

    train_dataset = WikiTextDataset(train_chunks, seq_len)
    test_dataset = WikiTextDataset(test_chunks, seq_len)

    print(f"Created datasets with {len(train_dataset)} training and "
          f"{len(test_dataset)} test samples")

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=False
    )

    val_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=False
    )

    return train_loader, val_loader


def cycle(loader):
    """Infinite iterator over a data loader."""
    while True:
        for data in loader:
            yield data


def create_model():
    """Build the neural memory model."""
    try:
        if USE_MEM_ATTENTION_MODEL:
            neural_memory_model = MemoryAttention(dim=64)
        else:
            neural_memory_model = MemoryMLP(dim=64, depth=NEURAL_MEMORY_DEPTH)

        model = MemoryAsContextTransformer(
            num_tokens=len(tokenizer),
            dim=384,
            depth=8,
            segment_len=WINDOW_SIZE,
            num_persist_mem_tokens=NUM_PERSIST_MEM,
            num_longterm_mem_tokens=NUM_LONGTERM_MEM,
            neural_memory_layers=NEURAL_MEM_LAYERS,
            neural_memory_segment_len=NEURAL_MEM_SEGMENT_LEN,
            neural_memory_batch_size=NEURAL_MEM_BATCH_SIZE,
            neural_mem_gate_attn_output=NEURAL_MEM_GATE_ATTN_OUTPUT,
            neural_mem_weight_residual=NEURAL_MEM_WEIGHT_RESIDUAL,
            use_flex_attn=True,
            sliding_window_attn=SLIDING_WINDOWS,
            neural_memory_model=neural_memory_model,
            neural_memory_kwargs=dict(
                dim_head=64,
                heads=4,
                attn_pool_chunks=STORE_ATTN_POOL_CHUNKS,
                qk_rmsnorm=NEURAL_MEM_QK_NORM,
                momentum=NEURAL_MEM_MOMENTUM,
                momentum_order=NEURAL_MEM_MOMENTUM_ORDER,
                default_step_transform_max_lr=NEURAL_MEM_MAX_LR,
                use_accelerated_scan=True,
                per_parameter_lr_modulation=MEMORY_MODEL_PER_LAYER_LEARNED_LR
            )
        ).cuda()

        assert next(model.parameters()).is_cuda, "Model is not on CUDA"
        return model

    except Exception as e:
        print(f"Error creating model: {e}")
        raise e


def train_model(model, train_loader, val_loader, num_batches=int(1e4)):
    """Train the model."""
    optim = AdoptAtan2(model.parameters(), lr=LEARNING_RATE)
    torch.cuda.empty_cache()
    pbar = tqdm(range(num_batches), desc='Training')
    running_loss = 0.0

    try:
        for i in pbar:
            model.train()
            total_loss = 0

            for __ in range(GRADIENT_ACCUMULATE_EVERY):
                batch = next(train_loader)
                loss = model(batch, return_loss=True)
                loss = loss / GRADIENT_ACCUMULATE_EVERY
                loss.backward()
                total_loss += loss.item()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optim.step()
            optim.zero_grad()

            if i % 100 == 0:
                torch.cuda.empty_cache()

            avg_loss = total_loss
            running_loss = 0.9 * running_loss + 0.1 * avg_loss if i > 0 else avg_loss

            pbar.set_postfix({
                'loss': f'{running_loss:.4f}',
                'batch_loss': f'{avg_loss:.4f}'
            })

            if i % VALIDATE_EVERY == 0:
                model.eval()
                with torch.no_grad():
                    val_batch = next(val_loader)
                    val_loss = model(val_batch, return_loss=True)
                    pbar.set_postfix({
                        'train_loss': f'{running_loss:.4f}',
                        'val_loss': f'{val_loss.item():.4f}'
                    })

            if i % 1000 == 0 and i > 0:
                torch.save({
                    'epoch': i,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optim.state_dict(),
                    'loss': running_loss,
                }, f'checkpoint_{i}.pt')

    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
    except Exception as e:
        print(f"\nTraining stopped due to error: {e}")
        raise e

    return model


def main():
    """Program entry point."""
    try:
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available. This code requires a GPU.")

        print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")

        BATCH_SIZE = 4
        SEQ_LEN = 512
        CACHE_DIR = 'cache'
        PROCESSED_DATA_DIR = 'processed_data'
        NUM_BATCHES = 10000

        preprocessor = WikiDatasetPreprocessor(CACHE_DIR, PROCESSED_DATA_DIR)
        processed_data_path = Path(PROCESSED_DATA_DIR) / 'processed_wiki.pt'

        if not processed_data_path.exists():
            print("Processing Wikipedia dataset...")
            preprocessor.process_and_save(max_articles=10000)

        train_loader, val_loader = create_dataloaders(
            processed_data_path,
            batch_size=BATCH_SIZE,
            seq_len=SEQ_LEN
        )

        train_loader = cycle(train_loader)
        val_loader = cycle(val_loader)

        model = create_model()
        model = train_model(model, train_loader, val_loader, num_batches=NUM_BATCHES)

        torch.save(model.state_dict(), 'final_model.pt')
        return model, train_loader, val_loader

    except Exception as e:
        print(f"Error in main: {e}")
        raise e


if __name__ == "__main__":
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    torch.backends.cudnn.benchmark = True
    model, train_loader, val_loader = main()
```
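
`train_model` writes a `checkpoint_{i}.pt` file every 1000 steps containing the model state, optimizer state, step index, and running loss. A minimal sketch for resuming from such a checkpoint, reusing `create_model` and `LEARNING_RATE` from the script above (the file name here is illustrative):

```python
# Minimal resume sketch: restore model and optimizer state from a checkpoint
# produced by train_model() above (the checkpoint file name is illustrative).
import torch
from adam_atan2_pytorch import AdoptAtan2

model = create_model()
optim = AdoptAtan2(model.parameters(), lr=LEARNING_RATE)

ckpt = torch.load('checkpoint_1000.pt', map_location='cuda')
model.load_state_dict(ckpt['model_state_dict'])
optim.load_state_dict(ckpt['optimizer_state_dict'])

start_step = ckpt['epoch'] + 1
print(f"Resuming from step {start_step} with running loss {ckpt['loss']:.4f}")
```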

# License

This project is licensed under the MIT License. See the LICENSE file for details.

# Citation

If you use this model in your research, please cite:
```bibtex
@software{neural_memory_model,
  title = {Neural Memory Model for Russian Text Generation},
  year = {2024},
  url = {https://huggingface.co/Grpp/memory-transformer-ru}
}
```