File size: 1,901 Bytes
22f52a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# lightning.pytorch==2.3.3

# Global random seed (LightningCLI `seed_everything` option).
seed_everything: 0

# Trainer options; only these two are set by this file.
trainer:
  # Stop training after this many optimizer steps.
  max_steps: 50000
  # bfloat16 mixed-precision training.
  precision: bf16-mixed

# Data module: training examples come from a ranked run (see train_dataset).
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    train_batch_size: 64
    shuffle_train: true
    num_workers: 1
    train_dataset:
      class_path: lightning_ir.RunDataset
      init_args:
        # Run identifier; looks like an ir_datasets-style id — confirm how
        # RunDataset resolves it.
        run_path_or_id: msmarco-passage/train/rank-distillm/set-encoder
        # NOTE(review): presumably the run is cut to its top `depth` documents
        # per query and `sample_size` of those are drawn per query with the
        # `log_random` strategy — confirm against RunDataset.
        depth: 100
        sample_size: 8
        sampling_strategy: log_random
        # Targets are taken from the `score` field (presumably the run's
        # scores) and left unnormalized.
        targets: score
        normalize_targets: false

# Model: a bi-encoder with a SPLADE configuration on top of bert-base-uncased.
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: bert-base-uncased
    config:
      class_path: lightning_ir.SpladeConfig
      init_args:
        # Shared projection / scoring settings.
        # NOTE(review): `relu_log` is presumably log(1 + ReLU(x)) as in
        # SPLADE, and `mlm` projects through the masked-LM head — confirm.
        projection: mlm
        sparsification: relu_log
        # Presumably the bert-base-uncased vocabulary size, since the mlm
        # projection maps onto vocabulary terms — confirm.
        embedding_dim: 30522
        similarity_function: dot
        normalize: false
        add_marker_tokens: false
        # Query-side settings (length presumably in tokens).
        query_length: 32
        query_pooling_strategy: max
        query_expansion: false
        attend_to_query_expanded_tokens: false
        query_mask_scoring_tokens: null
        query_aggregation_function: sum
        # Document-side settings (length presumably in tokens).
        doc_length: 256
        doc_pooling_strategy: max
        doc_expansion: false
        attend_to_doc_expanded_tokens: false
        doc_mask_scoring_tokens: null
    # Training losses. Each entry is either a loss spec or a two-element
    # [loss, weight] list. NOTE(review): presumably the float (0.05 below)
    # scales that loss term — confirm in lightning_ir.
    loss_functions:
      - - class_path: lightning_ir.SupervisedMarginMSE
        - 0.05
      - class_path: lightning_ir.KLDivergence
      - class_path: lightning_ir.FLOPSRegularization
        init_args:
          query_weight: 0.01
          doc_weight: 0.02
      - class_path: lightning_ir.InBatchCrossEntropy
        init_args:
          pos_sampling_technique: first
          neg_sampling_technique: first
          max_num_neg_samples: 8

# Optimizer: torch AdamW; only the learning rate is set here.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    # Keep the decimal point: YAML 1.1 loaders such as PyYAML resolve a
    # dotless exponent form like `2e-5` to a string rather than a float.
    lr: 2.0e-05

# Learning-rate schedule. Per the class name, presumably a linear warmup
# followed by a constant rate — confirm in lightning_ir.
lr_scheduler:
  class_path: lightning_ir.ConstantLRSchedulerWithLinearWarmup
  init_args:
    num_delay_steps: 0
    num_warmup_steps: 5000