Grpp committed · Commit 17ebf79 · verified · Parent: 9fccb26

Update README.md

Files changed (1):
  1. README.md +162 -304
README.md CHANGED
@@ -79,17 +79,28 @@ The repository includes complete training and inference code. Key components:
  ## Example Code
  ```python
  import os
- import re
- import json
- import random
- from tqdm import tqdm
- import numpy as np
  from pathlib import Path

  import torch
  from torch import nn
- from torch.utils.data import DataLoader, Dataset
- from transformers import GPT2TokenizerFast
  from adam_atan2_pytorch import AdoptAtan2

  from titans_pytorch import (
@@ -98,24 +109,17 @@ from titans_pytorch import (
      MemoryAttention
  )

- import os
- import json
- import random
- from pathlib import Path
- from typing import List, Dict
- import numpy as np
- from tqdm import tqdm
- from datasets import load_dataset
- import torch
- from torch.utils.data import Dataset, DataLoader
- from transformers import GPT2TokenizerFast

- # CUDA memory-management settings
- import os
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:32'

-
  # Constants
  NUM_BATCHES = int(1e5)
  BATCH_SIZE = 4
  GRADIENT_ACCUMULATE_EVERY = 4
@@ -146,173 +150,59 @@ STORE_ATTN_POOL_CHUNKS = True
  MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
  NEURAL_MEM_WEIGHT_RESIDUAL = True

- # Initialize the tokenizer
- tokenizer = GPT2TokenizerFast.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
-
-
- class WikiDatasetPreprocessor:
-     def __init__(self, cache_dir: str = 'cache', output_dir: str = 'processed_data'):
-         self.cache_dir = Path(cache_dir)
-         self.output_dir = Path(output_dir)
-         self.cache_dir.mkdir(parents=True, exist_ok=True)
-         self.output_dir.mkdir(parents=True, exist_ok=True)
-
-         # Initialize the tokenizer
-         self.tokenizer = GPT2TokenizerFast.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
-
-     def load_wiki_dataset(self):
-         """Load the dataset from Hugging Face"""
-         print("Loading Wikipedia dataset...")
-         dataset = load_dataset("misterkirill/ru-wikipedia", cache_dir=str(self.cache_dir))
-         print(f"Dataset loaded. Size: {len(dataset['train'])} articles")
-         return dataset
-
-     def clean_text(self, text: str) -> str:
-         """Basic text cleanup"""
-         # Collapse repeated whitespace and line breaks
-         return ' '.join(text.split())
-
-     def process_wiki_article(self, text: str) -> List[str]:
-         """Split a single Wikipedia article into overlapping token chunks"""
-         processed_chunks = []
-
-         clean_text = self.clean_text(text)
-         tokens = self.tokenizer.encode(clean_text)
-
-         # Smaller chunks to keep CUDA memory usage down
-         chunk_size = 256  # was 512
-         stride = 192      # was 384
-
-         for i in range(0, len(tokens), stride):
-             chunk = tokens[i:i + chunk_size]
-             if len(chunk) > 50:  # drop chunks below the minimum length
-                 processed_chunks.append(chunk)
-
-         return processed_chunks
-
-     def process_and_save(self, batch_size: int = 1000, test_size: float = 0.1, max_articles: int = 10000):
-         """Process a limited number of articles from the dataset and save the results"""
-         dataset = self.load_wiki_dataset()

-         # Cap the dataset size
-         total_articles = min(len(dataset['train']), max_articles)
-         print(f"Processing {total_articles} articles out of {len(dataset['train'])}")
-
-         # Collect all chunks first
-         all_chunks = []
-
-         for i in tqdm(range(0, total_articles, batch_size), desc="Processing articles"):
-             batch = dataset['train'][i:i + batch_size]
-             for text in batch['text']:
-                 chunks = self.process_wiki_article(text)
-                 all_chunks.extend(chunks)
-
-                 # Cap the number of chunks to speed up training
-                 if len(all_chunks) > 50000:
-                     break
-
-             if len(all_chunks) > 50000:
-                 break
-
-         print(f"Total chunks created: {len(all_chunks)}")
-
-         # Shuffle the chunks
-         random.seed(42)
-         random.shuffle(all_chunks)
-
-         # Split into train and test
-         n_test = int(len(all_chunks) * test_size)
-         train_chunks = all_chunks[:-n_test]
-         test_chunks = all_chunks[-n_test:]
-
-         print(f"Saving {len(train_chunks)} training chunks and {len(test_chunks)} test chunks...")
-         torch.save({
-             'train': train_chunks,
-             'test': test_chunks
-         }, self.output_dir / 'processed_wiki.pt')


- class WikiTextDataset(Dataset):
-     def __init__(self, chunks: List[List[int]], seq_len: int = 512):
-         self.chunks = chunks
-         self.seq_len = seq_len

-     def __len__(self):
-         return len(self.chunks)
-
-     def __getitem__(self, idx):
-         chunk = self.chunks[idx]

-         # Pad chunks that are shorter than the target length
-         if len(chunk) < self.seq_len + 1:
-             chunk = chunk + [50256] * (self.seq_len + 1 - len(chunk))
-         # Truncate the rest
-         else:
-             chunk = chunk[:self.seq_len + 1]
-
-         return torch.tensor(chunk, device='cuda').long()  # tensors are created directly on the GPU
-
- def create_dataloaders(
-     processed_data_path: str,
-     batch_size: int = 4,
-     seq_len: int = 512,
-     train_test_split: float = 0.9
- ) -> tuple:
-     """Create the training and validation data loaders"""
-
-     print(f"Loading processed data from {processed_data_path}")
-     data = torch.load(processed_data_path)
-     train_chunks = data['train']
-     test_chunks = data['test']
-
-     # Build the datasets
-     train_dataset = WikiTextDataset(train_chunks, seq_len)
-     test_dataset = WikiTextDataset(test_chunks, seq_len)
-
-     print(f"Created datasets with {len(train_dataset)} training and {len(test_dataset)} test samples")
-
-     # Build the data loaders
-     train_loader = DataLoader(
-         train_dataset,
-         batch_size=batch_size,
-         shuffle=True,
-         num_workers=0,    # single process for easier debugging
-         pin_memory=False  # the tensors are already on the GPU
-     )
-
-     val_loader = DataLoader(
-         test_dataset,
-         batch_size=batch_size,
-         shuffle=False,
-         num_workers=0,    # single process for easier debugging
-         pin_memory=False  # the tensors are already on the GPU
-     )
-
-     return train_loader, val_loader
-
- def cycle(loader):
-     """Infinite iterator over a data loader"""
-     while True:
-         for data in loader:
-             yield data

- def create_model():
-     try:
-         if USE_MEM_ATTENTION_MODEL:
-             neural_memory_model = MemoryAttention(dim=64)
-         else:
-             neural_memory_model = MemoryMLP(dim=64, depth=NEURAL_MEMORY_DEPTH)
-
-         model = MemoryAsContextTransformer(
-             num_tokens=len(tokenizer),
-             dim=384,
-             depth=8,
-             segment_len=WINDOW_SIZE,
-             num_persist_mem_tokens=NUM_PERSIST_MEM,
-             num_longterm_mem_tokens=NUM_LONGTERM_MEM,
-             neural_memory_layers=NEURAL_MEM_LAYERS,
              neural_memory_segment_len=NEURAL_MEM_SEGMENT_LEN,
              neural_memory_batch_size=NEURAL_MEM_BATCH_SIZE,
              neural_mem_gate_attn_output=NEURAL_MEM_GATE_ATTN_OUTPUT,
@@ -331,139 +221,107 @@ def create_model():
              use_accelerated_scan=True,
              per_parameter_lr_modulation=MEMORY_MODEL_PER_LAYER_LEARNED_LR
          )
-         ).cuda()

-         # Check that the model is on the GPU
-         assert next(model.parameters()).is_cuda, "Model is not on CUDA"

-         return model

-     except Exception as e:
-         print(f"Error creating model: {e}")
-         raise e

- def train_model(model, train_loader, val_loader, num_batches=int(1e4)):
-     optim = AdoptAtan2(model.parameters(), lr=2e-4)
-
-     # Clear the CUDA cache before training
-     torch.cuda.empty_cache()

-     pbar = tqdm(range(num_batches), desc='Training')
-     running_loss = 0.0

-     try:
-         for i in pbar:
-             model.train()
-
-             total_loss = 0
-             # Training with gradient accumulation
-             for __ in range(GRADIENT_ACCUMULATE_EVERY):
-                 batch = next(train_loader)
-                 loss = model(batch, return_loss=True)
-                 loss = loss / GRADIENT_ACCUMULATE_EVERY  # normalize the loss across accumulation steps
-                 loss.backward()
-                 total_loss += loss.item()
-
-             # Gradient clipping
-             torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
-             optim.step()
-             optim.zero_grad()
-
-             # Clear the CUDA cache every 100 iterations
-             if i % 100 == 0:
-                 torch.cuda.empty_cache()
-
-             avg_loss = total_loss  # already averaged over the accumulation steps
-             running_loss = 0.9 * running_loss + 0.1 * avg_loss if i > 0 else avg_loss
-
-             pbar.set_postfix({
-                 'loss': f'{running_loss:.4f}',
-                 'batch_loss': f'{avg_loss:.4f}'
-             })
-
-             # Validation
-             if i % 100 == 0:
-                 model.eval()
-                 with torch.no_grad():
-                     val_batch = next(val_loader)
-                     val_loss = model(val_batch, return_loss=True)
-                     pbar.set_postfix({
-                         'train_loss': f'{running_loss:.4f}',
-                         'val_loss': f'{val_loss.item():.4f}'
-                     })
-
-             # Save a checkpoint
-             if i % 1000 == 0 and i > 0:
-                 torch.save({
-                     'epoch': i,
-                     'model_state_dict': model.state_dict(),
-                     'optimizer_state_dict': optim.state_dict(),
-                     'loss': running_loss,
-                 }, f'checkpoint_{i}.pt')
-
-     except KeyboardInterrupt:
-         print("\nTraining interrupted by user")
-     except Exception as e:
-         print(f"\nTraining stopped due to error: {e}")
-         raise e
-
-     return model
-
- def main():
-     try:
-         if not torch.cuda.is_available():
-             raise RuntimeError("CUDA is not available. This code requires a GPU.")
-
-         print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
-
-         # Parameters
-         BATCH_SIZE = 4
-         SEQ_LEN = 512
-         CACHE_DIR = 'cache'
-         PROCESSED_DATA_DIR = 'processed_data'
-         NUM_BATCHES = 10000  # fewer iterations than the default
-
-         # Data preparation
-         preprocessor = WikiDatasetPreprocessor(CACHE_DIR, PROCESSED_DATA_DIR)
-
-         processed_data_path = Path(PROCESSED_DATA_DIR) / 'processed_wiki.pt'
-         if not processed_data_path.exists():
-             print("Processing Wikipedia dataset...")
-             preprocessor.process_and_save(max_articles=10000)  # cap the number of articles
-
-         # Create the data loaders
-         train_loader, val_loader = create_dataloaders(
-             processed_data_path,
-             batch_size=BATCH_SIZE,
-             seq_len=SEQ_LEN
          )
-
-         # Wrap them in infinite iterators
-         train_loader = cycle(train_loader)
-         val_loader = cycle(val_loader)
-
-         # Build and train the model
-         model = create_model()
-         model = train_model(model, train_loader, val_loader, num_batches=NUM_BATCHES)
-
-         # Save the final model
-         torch.save(model.state_dict(), 'final_model.pt')
-
-         return model, train_loader, val_loader
-
-     except Exception as e:
-         print(f"Error in main: {e}")
-         raise e

  if __name__ == "__main__":
-     # Seed everything for reproducibility
      torch.manual_seed(42)
      torch.cuda.manual_seed_all(42)
-
-     # Enable CUDA optimizations
-     torch.backends.cudnn.benchmark = True
-
-     model, train_loader, val_loader = main()
  ```
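The training loop above writes `checkpoint_{i}.pt` files but never shows how to resume from one. Below is a minimal resume sketch, assuming the same `create_model()` definition and optimizer settings as in the script; the filename `checkpoint_1000.pt` is only an example:

```python
# Hypothetical resume helper, not part of the original script.
# Assumes create_model() and AdoptAtan2 are defined exactly as above
# and that the training loop has already written checkpoint_1000.pt.
import torch
from adam_atan2_pytorch import AdoptAtan2

model = create_model()
optim = AdoptAtan2(model.parameters(), lr=2e-4)

checkpoint = torch.load('checkpoint_1000.pt')               # dict written by train_model
model.load_state_dict(checkpoint['model_state_dict'])       # restore model weights
optim.load_state_dict(checkpoint['optimizer_state_dict'])   # restore optimizer state
start_iter = checkpoint['epoch']                            # iteration the checkpoint was taken at
print(f"Resuming from iteration {start_iter}, loss {checkpoint['loss']:.4f}")
```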
  # Training
 
  ## Example Code
  ```python
  import os
+ import warnings
  from pathlib import Path
+ from typing import List, Dict, Optional, Tuple

  import torch
  from torch import nn
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import (
+     GPT2TokenizerFast,
+     PreTrainedModel,
+     PreTrainedTokenizer,
+     AutoConfig,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     PretrainedConfig,
+     GenerationMixin,
+     pipeline
+ )
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+ from huggingface_hub import HfApi, login
+ from datasets import load_dataset
+ from tqdm import tqdm
  from adam_atan2_pytorch import AdoptAtan2

  from titans_pytorch import (
      MemoryAttention
  )

+ # Suppress warnings
+ warnings.filterwarnings("ignore", category=UserWarning)
+ torch._dynamo.config.suppress_errors = True
+ torch._dynamo.config.cache_size_limit = 100000
+ torch._dynamo.config.disable = True

+ # CUDA settings
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:32'

  # Constants
+ repo_id = 'Grpp/memory-transformer-ru'
  NUM_BATCHES = int(1e5)
  BATCH_SIZE = 4
  GRADIENT_ACCUMULATE_EVERY = 4

  MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
  NEURAL_MEM_WEIGHT_RESIDUAL = True

+ class MemoryTransformerConfig(PretrainedConfig):
+     model_type = "memory_transformer"
+
+     def __init__(
+         self,
+         vocab_size=50257,
+         dim=384,
+         depth=8,
+         segment_len=32,
+         num_persist_mem=4,
+         num_longterm_mem=4,
+         neural_mem_layers=(2, 4, 6),
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.dim = dim
+         self.depth = depth
+         self.segment_len = segment_len
+         self.num_persist_mem = num_persist_mem
+         self.num_longterm_mem = num_longterm_mem
+         self.neural_mem_layers = neural_mem_layers
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             **kwargs
+         )


+ class MemoryTransformerForCausalLM(PreTrainedModel, GenerationMixin):
+     config_class = MemoryTransformerConfig
+     supports_gradient_checkpointing = True

+     def __init__(self, config):
+         super().__init__(config)

+         neural_memory_model = (
+             MemoryAttention(dim=64) if USE_MEM_ATTENTION_MODEL
+             else MemoryMLP(dim=64, depth=NEURAL_MEMORY_DEPTH)
+         )

+         self.transformer = MemoryAsContextTransformer(
+             num_tokens=config.vocab_size,
+             dim=config.dim,
+             depth=config.depth,
+             segment_len=config.segment_len,
+             num_persist_mem_tokens=config.num_persist_mem,
+             num_longterm_mem_tokens=config.num_longterm_mem,
+             neural_memory_layers=config.neural_mem_layers,
              neural_memory_segment_len=NEURAL_MEM_SEGMENT_LEN,
              neural_memory_batch_size=NEURAL_MEM_BATCH_SIZE,
              neural_mem_gate_attn_output=NEURAL_MEM_GATE_ATTN_OUTPUT,

              use_accelerated_scan=True,
              per_parameter_lr_modulation=MEMORY_MODEL_PER_LAYER_LEARNED_LR
          )
+         )

+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs
+     ):
+         # attention_mask and kwargs are accepted for API compatibility but unused
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         outputs = self.transformer(input_ids)  # logits

+         if labels is not None:
+             # the titans transformer computes the LM loss internally from the input sequence
+             loss = self.transformer(input_ids, return_loss=True)
+             return CausalLMOutputWithCrossAttentions(
+                 loss=loss,
+                 logits=outputs,
+                 past_key_values=None,
+                 hidden_states=None,
+                 attentions=None,
+                 cross_attentions=None
+             )
+
+         return CausalLMOutputWithCrossAttentions(
+             loss=None,
+             logits=outputs,
+             past_key_values=None,
+             hidden_states=None,
+             attentions=None,
+             cross_attentions=None
+         )

+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past=None,
+         attention_mask=None,
+         **kwargs
+     ):
+         if past:
+             input_ids = input_ids[:, -1].unsqueeze(-1)
+
+         return {
+             "input_ids": input_ids,
+             "past_key_values": past,
+             "attention_mask": attention_mask,
+         }
+
+     @property
+     def device(self):
+         return next(self.parameters()).device

+
+ def setup_custom_model():
+     """Register the custom architecture with the Auto classes"""
+     AutoConfig.register("memory_transformer", MemoryTransformerConfig)
+     AutoModelForCausalLM.register(MemoryTransformerConfig, MemoryTransformerForCausalLM)
+
+
+ def generate_example(model, tokenizer, text, max_length=100):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = model.to(device)
+     model.eval()

+     input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
+     attention_mask = torch.ones_like(input_ids, device=device)

+     print(f"Model device: {next(model.parameters()).device}")
+     print(f"Input device: {input_ids.device}")
+
+     with torch.no_grad():
+         outputs = model.generate(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             max_length=max_length,
+             num_return_sequences=1,
+             no_repeat_ngram_size=2,
+             do_sample=True,
+             top_k=50,
+             top_p=0.95,
+             temperature=0.7,
+             pad_token_id=tokenizer.pad_token_id,
+             eos_token_id=tokenizer.eos_token_id,
          )
+
+     return tokenizer.decode(outputs[0], skip_special_tokens=True)


  if __name__ == "__main__":
      torch.manual_seed(42)
      torch.cuda.manual_seed_all(42)
+
+     setup_custom_model()
+     config = AutoConfig.from_pretrained(repo_id)
+     model = AutoModelForCausalLM.from_pretrained(repo_id)
+     tokenizer = AutoTokenizer.from_pretrained(repo_id)
+
+     test_text = "Московский кремль является"  # "The Moscow Kremlin is"
+     generated_text = generate_example(model, tokenizer, test_text)
+     print(generated_text)
  ```
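The example imports `pipeline`, `HfApi`, and `login` without using them; presumably they belong to the publish-and-consume workflow around `Grpp/memory-transformer-ru`. A hedged sketch of that workflow, assuming `model` and `tokenizer` exist as created in the `__main__` block above and that you hold a write token for the repo:

```python
# Hypothetical Hub workflow, not part of the committed example.
from huggingface_hub import login
from transformers import pipeline

setup_custom_model()  # register MemoryTransformerConfig / MemoryTransformerForCausalLM

# Publishing (assumes a valid write token for the repo):
login()  # prompts for a Hugging Face access token
model.push_to_hub('Grpp/memory-transformer-ru')
tokenizer.push_to_hub('Grpp/memory-transformer-ru')

# Consuming through the high-level pipeline API; registration above lets
# the Auto classes resolve the custom "memory_transformer" model type.
generator = pipeline('text-generation', model='Grpp/memory-transformer-ru')
print(generator('Московский кремль является', max_length=100)[0]['generated_text'])
```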
  # Training