Christina Theodoris committed on
Commit
efc403d
·
1 Parent(s): f26136d

move model to device after checking

Browse files
Files changed (1) hide show
  1. geneformer/perturber_utils.py +13 -4
geneformer/perturber_utils.py CHANGED
@@ -193,21 +193,30 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
193
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
194
  if not quantize:
195
  # Only move non-quantized models
196
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
197
- model = model.to(device)
198
  elif os.path.exists(adapter_config_path):
199
  # If adapter files exist, load them into the model using PEFT's from_pretrained
200
  model = PeftModel.from_pretrained(model, model_directory)
201
- model = model.to(device)
202
  print("loading lora weights")
203
  elif peft_config:
204
  # Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
205
  model.enable_input_require_grads()
206
  model = get_peft_model(model, peft_config)
207
- model = model.to(device)
208
 
209
  return model
210
 
 
 
 
 
 
 
 
 
 
 
211
  def quant_layers(model):
212
  layer_nums = []
213
  for name, parameter in model.named_parameters():
 
193
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
194
  if not quantize:
195
  # Only move non-quantized models
196
+ move_to_cuda(model)
 
197
  elif os.path.exists(adapter_config_path):
198
  # If adapter files exist, load them into the model using PEFT's from_pretrained
199
  model = PeftModel.from_pretrained(model, model_directory)
200
+ move_to_cuda(model)
201
  print("loading lora weights")
202
  elif peft_config:
203
  # Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
204
  model.enable_input_require_grads()
205
  model = get_peft_model(model, peft_config)
206
+ move_to_cuda(model)
207
 
208
  return model
209
 
210
+
211
def move_to_cuda(model):
    """Move *model* to the GPU if CUDA is available and the model is on CPU.

    ``nn.Module.to`` modifies the module in place, so the caller's reference
    is updated; this function returns None. Models already on a CUDA device,
    or environments without CUDA, are left untouched.

    NOTE(review): assumes the model has at least one parameter and that all
    parameters live on a single device — confirm for exotic model wrappers.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Device of the model's first parameter, used as a proxy for "where the
    # whole model currently lives".
    model_device = next(model.parameters()).device
    # Compare device *types*, not a torch.device against a string:
    # `torch.device("cuda") == "cuda"` is a device-vs-str comparison that
    # evaluates False on older PyTorch releases, silently skipping the move.
    if model_device.type == "cpu" and device.type == "cuda":
        model.to(device)
218
+
219
+
220
  def quant_layers(model):
221
  layer_nums = []
222
  for name, parameter in model.named_parameters():