# source commit: 22f52a4
# lightning.pytorch==2.3.3
seed_everything: 0
trainer:
  # bf16 mixed precision — needs hardware bf16 support (e.g. Ampere+ GPUs)
  precision: bf16-mixed
  # train for a fixed number of optimizer steps rather than epochs
  max_steps: 50000
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    num_workers: 1
    train_batch_size: 64
    shuffle_train: true
    train_dataset:
      class_path: lightning_ir.RunDataset
      init_args:
        # run file path or ir-datasets id carrying distillation scores
        run_path_or_id: msmarco-passage/train/rank-distillm/set-encoder
        # consider the top-100 ranked documents per query
        depth: 100
        # sample 8 documents per query from that depth
        sample_size: 8
        sampling_strategy: log_random
        # use raw teacher scores as supervision targets
        targets: score
        normalize_targets: false
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: bert-base-uncased
    config:
      class_path: lightning_ir.SpladeConfig
      init_args:
        query_pooling_strategy: max
        doc_pooling_strategy: max
        # project token embeddings through the masked-language-model head
        projection: mlm
        sparsification: relu_log
        embedding_dim: 30522  # bert-base-uncased vocabulary size
        similarity_function: dot
        query_expansion: false
        attend_to_query_expanded_tokens: false
        query_mask_scoring_tokens: null
        query_aggregation_function: sum
        doc_expansion: false
        attend_to_doc_expanded_tokens: false
        doc_mask_scoring_tokens: null
        normalize: false
        add_marker_tokens: false
        query_length: 32
        doc_length: 256
    # NOTE(review): indentation was lost in extraction; loss_functions is placed
    # under model.init_args because BiEncoderModule accepts loss functions —
    # confirm against the lightning-ir CLI schema.
    loss_functions:
      # a two-element sequence is a [loss, weight] pair
      - - class_path: lightning_ir.SupervisedMarginMSE
        - 0.05
      - class_path: lightning_ir.KLDivergence
      - class_path: lightning_ir.FLOPSRegularization
        init_args:
          query_weight: 0.01
          doc_weight: 0.02
      - class_path: lightning_ir.InBatchCrossEntropy
        init_args:
          pos_sampling_technique: first
          neg_sampling_technique: first
          max_num_neg_samples: 8
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 2.0e-05
lr_scheduler:
  class_path: lightning_ir.ConstantLRSchedulerWithLinearWarmup
  init_args:
    # linearly warm up the learning rate over the first 5000 steps,
    # then hold it constant
    num_warmup_steps: 5000
    num_delay_steps: 0