ctheodoris madhavanvvs commited on
Commit
d471d1b
·
verified ·
1 Parent(s): 18a2ca6

add fallback task_type for LoraConfig to support different PEFT versions (#538)

Browse files

- add fallback task_type for LoraConfig to support different PEFT versions (7e575a07f75854fb428378a36fda54d912c1136a)


Co-authored-by: Madhavan Venkatesh <[email protected]>

Files changed (1) hide show
  1. geneformer/perturber_utils.py +13 -7
geneformer/perturber_utils.py CHANGED
@@ -138,13 +138,19 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
138
  bnb_4bit_quant_type="nf4",
139
  bnb_4bit_compute_dtype=torch.bfloat16,
140
  )
141
- peft_config = LoraConfig(
142
- lora_alpha=128,
143
- lora_dropout=0.1,
144
- r=64,
145
- bias="none",
146
- task_type="TokenClassification",
147
- )
 
 
 
 
 
 
148
  else:
149
  quantize_config = None
150
  peft_config = None
 
138
  bnb_4bit_quant_type="nf4",
139
  bnb_4bit_compute_dtype=torch.bfloat16,
140
  )
141
+ lora_config_params = {
142
+ "lora_alpha": 128,
143
+ "lora_dropout": 0.1,
144
+ "r": 64,
145
+ "bias": "none"
146
+ }
147
+
148
+ # Try with TokenClassification first, fallback to TOKEN_CLS if needed
149
+ try:
150
+ peft_config = LoraConfig(**lora_config_params, task_type="TokenClassification")
151
+ except ValueError:
152
+ # Some versions use TOKEN_CLS instead of TokenClassification
153
+ peft_config = LoraConfig(**lora_config_params, task_type="TOKEN_CLS")
154
  else:
155
  quantize_config = None
156
  peft_config = None