import optimizer_config
import training_loop
from transformer import models
from transformer import text_dataset

# Training setup.
training_loop.Trainer:
  model_definition = @models.DecoderOnlyLanguageModel

  num_steps = 250_000
  status_every_steps = 10
  log_every_steps = 1000
  test_every_steps = 1000
  num_test_steps = 400
  generate_every_steps = 5000
  print_input_every_steps = 5000
  checkpoint_every_steps = 5000
  save_checkpoints = True
  restore_checkpoints = True
  use_separate_metric_directories = False

  optimizer_factory = @optimizer_config.FlaxAdafactorConfig()
  learning_rate_schedule = @optimizer_config.lr_cosine_decay
  max_scheduled_steps = 0   # 0 means: use num_steps as max_scheduled_steps.
  warmup_steps = 1000
  learning_rate_multiplier = 1.0

  rng_key_names = ("dropout", "sample")
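
# A minimal sketch of how bindings like the ones above are typically consumed
# (an assumption based on the generic gin-config API; this repository's actual
# entry point and config file name may differ):
#
#   import gin
#   gin.parse_config_file("trainer_configuration.gin")  # Hypothetical file name.
#   trainer = training_loop.Trainer()  # gin supplies the parameters bound above.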

text_dataset.load_text_dataset:
  verbose = False  # If True, prints the start of every book/repo read from disk.

# Use cosine decay to max_scheduled_steps, as described in Chinchilla:
# https://arxiv.org/abs/2203.15556
optimizer_config.lr_cosine_decay:
  max_lr = 0.01
  min_lr = 0.001
  decay_after = True
  spike_steps = 0
  spike_lr = 0.0
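
# A rough sketch of the intended shape of this schedule (inferred from the
# parameter names above, not from the lr_cosine_decay implementation): after
# the Trainer's warmup_steps, the learning rate at step s is roughly
#   lr(s) ~= min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * s / max_scheduled_steps)),
# decaying from max_lr = 0.01 toward min_lr = 0.001 by max_scheduled_steps.
# The exact semantics of decay_after, spike_steps, and spike_lr are
# implementation-specific.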


# Adam optimizer configuration.
# optimizer_config.AdamConfig:
#   learning_rate = 0.05  # Will be multiplied by the LR schedule.
#   beta1 = 0.9
#   beta2 = 0.98
#   weight_decay_rate = 0.0

# Adafactor optimizer configuration.
optimizer_config.FlaxAdafactorConfig:
  learning_rate = 1.0   # Will be multiplied by the LR schedule.
  beta1 = 0.9           # May be None to disable first-moment (momentum) accumulation.
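
# Note: with learning_rate = 1.0 here and learning_rate_multiplier = 1.0 in the
# Trainer, the effective learning rate should (assuming the multipliers compose
# as the comments above indicate) simply follow lr_cosine_decay, peaking near
# max_lr = 0.01.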