Grpp committed (verified)
Commit 1d200d1 · 1 Parent(s): 17ebf79

Update README.md

Files changed (1):
  1. README.md +196 -0
README.md CHANGED
@@ -324,6 +324,202 @@ if __name__ == "__main__":
     print(generated_text)
 ```
 
+
+## Finetune Code
+
+```python
+import os
+import torch
+from pathlib import Path
+from torch.utils.data import DataLoader
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+from tqdm import tqdm
+from adam_atan2_pytorch import AdoptAtan2
+
+# Import classes from the training code
+from run_train_pep8 import (
+    WikiDatasetPreprocessor,
+    WikiTextDataset,
+    create_dataloaders,
+    cycle
+)  # From Train Code
+
+from test_load import setup_custom_model  # From Example Code
+
+# CUDA allocator settings
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:32'
+
+# Finetuning constants
+BATCH_SIZE = 2
+GRADIENT_ACCUMULATE_EVERY = 2
+LEARNING_RATE = 1e-5
+NUM_EPOCHS = 3
+STEPS_PER_EPOCH = 1000  # Number of optimizer steps per epoch
+SEQ_LEN = 256
+PROCESSED_DATA_DIR = 'processed_data'
+CACHE_DIR = 'cache'
+REPO_ID = 'Grpp/memory-transformer-ru'
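+# Effective batch size: BATCH_SIZE * GRADIENT_ACCUMULATE_EVERY = 4 sequences per optimizer step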
+
+def finetune_model(
+    model,
+    train_loader,
+    val_loader,
+    num_epochs,
+    device,
+    save_path='finetuned_model'
+):
+    """Finetune the model."""
+
+    model = model.to(device)
+    optimizer = AdoptAtan2(model.parameters(), lr=LEARNING_RATE)
+
+    best_val_loss = float('inf')
+
+    for epoch in range(num_epochs):
+        model.train()
+        total_train_loss = 0
+        train_steps = 0
+
+        # Progress bar over a fixed number of steps
+        train_pbar = tqdm(range(STEPS_PER_EPOCH),
+                          desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
+
+        for step in train_pbar:
+            total_loss = 0
+
+            # Gradient accumulation
+            for _ in range(GRADIENT_ACCUMULATE_EVERY):
+                batch = next(train_loader)
+                batch = batch.to(device)
+
+                # Get the inputs and labels
+                inputs = batch[:, :-1]
+                labels = batch[:, 1:]
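+                # labels[t] is the token immediately after inputs[t] (next-token prediction)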
+
+                # Forward pass
+                outputs = model(input_ids=inputs, labels=labels)
+                loss = outputs.loss / GRADIENT_ACCUMULATE_EVERY
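+                # Scaling by 1/GRADIENT_ACCUMULATE_EVERY keeps the accumulated
+                # gradient equal to the mean over the effective batch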
+
+                # Backward pass
+                loss.backward()
+                total_loss += loss.item()
+
+            # Parameter update
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            optimizer.zero_grad()
+
+            total_train_loss += total_loss
+            train_steps += 1
+
+            # Update the progress bar
+            train_pbar.set_postfix({
+                'loss': f'{total_loss:.4f}',
+                'avg_loss': f'{total_train_loss/train_steps:.4f}'
+            })
+
+            # Validate every 100 steps
+            if step % 100 == 0:
+                model.eval()
+                val_loss = 0
+                val_steps = 0
+
+                with torch.no_grad():
+                    for _ in range(10):  # Limit the number of validation steps
+                        val_batch = next(val_loader)
+                        val_batch = val_batch.to(device)
+
+                        val_inputs = val_batch[:, :-1]
+                        val_labels = val_batch[:, 1:]
+
+                        val_outputs = model(input_ids=val_inputs, labels=val_labels)
+                        val_loss += val_outputs.loss.item()
+                        val_steps += 1
+
+                avg_val_loss = val_loss / val_steps
+
+                print(f"\nValidation loss: {avg_val_loss:.4f}")
+
+                # Save the best model
+                if avg_val_loss < best_val_loss:
+                    best_val_loss = avg_val_loss
+                    torch.save({
+                        'epoch': epoch,
+                        'model_state_dict': model.state_dict(),
+                        'optimizer_state_dict': optimizer.state_dict(),
+                        'loss': best_val_loss,
+                    }, f'{save_path}_best.pt')
+
+                model.train()
+
+        # Save a checkpoint after each epoch
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'loss': total_train_loss / train_steps,
+        }, f'{save_path}_epoch_{epoch}.pt')
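+        # Checkpoints include the optimizer state so training can be resumed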
+
+        print(f"\nEpoch {epoch+1} completed. Average loss: {total_train_loss/train_steps:.4f}")
+
+    return model
+
+def main():
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    # Load and prepare the data
+    processed_data_path = Path(PROCESSED_DATA_DIR) / 'processed_wiki.pt'
+
+    if not processed_data_path.exists():
+        print("Processing dataset...")
+        preprocessor = WikiDatasetPreprocessor(CACHE_DIR, PROCESSED_DATA_DIR)
+        preprocessor.process_and_save(max_articles=10000)
+
+    print("Creating dataloaders...")
+    train_loader, val_loader = create_dataloaders(
+        processed_data_path,
+        batch_size=BATCH_SIZE,
+        seq_len=SEQ_LEN
+    )
+
+    train_loader = cycle(train_loader)
+    val_loader = cycle(val_loader)
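+    # cycle() makes the loaders infinite iterators, so next() inside
+    # finetune_model never raises StopIteration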
+
+    # Load the pretrained model
+    print("Loading pretrained model...")
+    setup_custom_model()
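+    # setup_custom_model() is expected to register the custom architecture
+    # with the Auto classes so from_pretrained can resolve REPO_ID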
+    config = AutoConfig.from_pretrained(REPO_ID)
+    model = AutoModelForCausalLM.from_pretrained(REPO_ID)
+
+    print("Starting finetuning...")
+    # Finetune the model
+    model = finetune_model(
+        model,
+        train_loader,
+        val_loader,
+        NUM_EPOCHS,
+        device
+    )
+
+    # Save the final model
+    print("Saving final model...")
+    model.save_pretrained('final_finetuned_model')
+
+    return model
+
+if __name__ == "__main__":
+    torch.manual_seed(42)
+    torch.cuda.manual_seed_all(42)
+    torch.backends.cudnn.benchmark = True
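+    # Note: cudnn.benchmark = True favors speed over strict reproducibility,
+    # even with the seeds fixed above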
515
+
516
+ try:
517
+ model = main()
518
+ print("Finetuning completed successfully!")
519
+ except Exception as e:
520
+ print(f"An error occurred: {str(e)}")
521
+ ```
522
+
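+To try the finetuned weights, the best checkpoint written by `finetune_model`
+can be loaded back into the pretrained architecture. A minimal sketch, assuming
+the default `save_path` (so the file is `finetuned_model_best.pt`) and the same
+`setup_custom_model` helper used above:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM
+
+from test_load import setup_custom_model  # From Example Code
+
+# Register the custom architecture, then load the base model
+setup_custom_model()
+model = AutoModelForCausalLM.from_pretrained('Grpp/memory-transformer-ru')
+
+# Restore the finetuned weights from the checkpoint dictionary
+checkpoint = torch.load('finetuned_model_best.pt', map_location='cpu')
+model.load_state_dict(checkpoint['model_state_dict'])
+model.eval()
+```
+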
 # Training
 
 The model was trained on a cleaned subset of Russian Wikipedia articles using the following parameters: