Christina Theodoris
committed on
Commit
·
efc403d
1
Parent(s):
f26136d
move model to device after checking
Browse files
geneformer/perturber_utils.py
CHANGED
@@ -193,21 +193,30 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
|
193 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
194 |
if not quantize:
|
195 |
# Only move non-quantized models
|
196 |
-
|
197 |
-
model = model.to(device)
|
198 |
elif os.path.exists(adapter_config_path):
|
199 |
# If adapter files exist, load them into the model using PEFT's from_pretrained
|
200 |
model = PeftModel.from_pretrained(model, model_directory)
|
201 |
-
model
|
202 |
print("loading lora weights")
|
203 |
elif peft_config:
|
204 |
# Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
|
205 |
model.enable_input_require_grads()
|
206 |
model = get_peft_model(model, peft_config)
|
207 |
-
model
|
208 |
|
209 |
return model
|
210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
def quant_layers(model):
|
212 |
layer_nums = []
|
213 |
for name, parameter in model.named_parameters():
|
|
|
193 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
194 |
if not quantize:
|
195 |
# Only move non-quantized models
|
196 |
+
move_to_cuda(model)
|
|
|
197 |
elif os.path.exists(adapter_config_path):
|
198 |
# If adapter files exist, load them into the model using PEFT's from_pretrained
|
199 |
model = PeftModel.from_pretrained(model, model_directory)
|
200 |
+
move_to_cuda(model)
|
201 |
print("loading lora weights")
|
202 |
elif peft_config:
|
203 |
# Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
|
204 |
model.enable_input_require_grads()
|
205 |
model = get_peft_model(model, peft_config)
|
206 |
+
move_to_cuda(model)
|
207 |
|
208 |
return model
|
209 |
|
210 |
+
|
211 |
+
def move_to_cuda(model):
    """Move *model* to the CUDA device when one is available.

    Inspects the device of the model's first parameter; if the model is on
    the CPU and CUDA is available, moves it in place via ``Module.to``.
    Otherwise this is a no-op (already-placed or CPU-only environments).

    Parameters
    ----------
    model : torch.nn.Module
        Model to (conditionally) move. Assumed to have at least one
        parameter — ``next(model.parameters())`` would raise otherwise.

    Returns
    -------
    None
        The model is modified in place; nothing is returned.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Device the model currently lives on, inferred from its first parameter.
    model_device = next(model.parameters()).device
    # BUG FIX: the original compared the torch.device object directly to the
    # string "cuda" (`device == "cuda"`), which is fragile/version-dependent
    # and False for e.g. torch.device("cuda:0"). Compare the device *type*.
    if model_device.type == "cpu" and device.type == "cuda":
        model.to(device)
|
218 |
+
|
219 |
+
|
220 |
def quant_layers(model):
|
221 |
layer_nums = []
|
222 |
for name, parameter in model.named_parameters():
|