Christina Theodoris commited on
Commit
4364d31
·
1 Parent(s): 76a78a0

move V1 autoformatting to after validate_options

Browse files
geneformer/classifier.py CHANGED
@@ -234,10 +234,6 @@ class Classifier:
234
  self.token_dictionary_file = token_dictionary_file
235
  self.nproc = nproc
236
  self.ngpu = ngpu
237
-
238
- if self.model_version == "V1":
239
- from . import TOKEN_DICTIONARY_FILE_30M
240
- self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
241
 
242
  if self.training_args is None:
243
  logger.warning(
@@ -258,7 +254,10 @@ class Classifier:
258
  ] = self.cell_state_dict["states"]
259
 
260
  # load token dictionary (Ensembl IDs:token)
261
- if self.token_dictionary_file is None:
 
 
 
262
  self.token_dictionary_file = TOKEN_DICTIONARY_FILE
263
  with open(self.token_dictionary_file, "rb") as f:
264
  self.gene_token_dict = pickle.load(f)
 
234
  self.token_dictionary_file = token_dictionary_file
235
  self.nproc = nproc
236
  self.ngpu = ngpu
 
 
 
 
237
 
238
  if self.training_args is None:
239
  logger.warning(
 
254
  ] = self.cell_state_dict["states"]
255
 
256
  # load token dictionary (Ensembl IDs:token)
257
+ if self.model_version == "V1":
258
+ from . import TOKEN_DICTIONARY_FILE_30M
259
+ self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
260
+ elif self.token_dictionary_file is None:
261
  self.token_dictionary_file = TOKEN_DICTIONARY_FILE
262
  with open(self.token_dictionary_file, "rb") as f:
263
  self.gene_token_dict = pickle.load(f)
geneformer/emb_extractor.py CHANGED
@@ -518,6 +518,8 @@ class EmbExtractor:
518
  self.summary_stat = summary_stat
519
  self.exact_summary_stat = None
520
 
 
 
521
  if self.model_version == "V1":
522
  from . import TOKEN_DICTIONARY_FILE_30M
523
  self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
@@ -527,8 +529,6 @@ class EmbExtractor:
527
  "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
528
  )
529
 
530
- self.validate_options()
531
-
532
  # load token dictionary (Ensembl IDs:token)
533
  if self.token_dictionary_file is None:
534
  token_dictionary_file = TOKEN_DICTIONARY_FILE
 
518
  self.summary_stat = summary_stat
519
  self.exact_summary_stat = None
520
 
521
+ self.validate_options()
522
+
523
  if self.model_version == "V1":
524
  from . import TOKEN_DICTIONARY_FILE_30M
525
  self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
 
529
  "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
530
  )
531
 
 
 
532
  # load token dictionary (Ensembl IDs:token)
533
  if self.token_dictionary_file is None:
534
  token_dictionary_file = TOKEN_DICTIONARY_FILE
geneformer/in_silico_perturber.py CHANGED
@@ -231,7 +231,9 @@ class InSilicoPerturber:
231
  self.nproc = nproc
232
  self.model_version = model_version
233
  self.token_dictionary_file = token_dictionary_file
234
- self.clear_mem_ncells = clear_mem_ncells
 
 
235
 
236
  if self.model_version == "V1":
237
  from . import TOKEN_DICTIONARY_FILE_30M
@@ -245,10 +247,8 @@ class InSilicoPerturber:
245
  self.emb_mode = "cell_and_gene"
246
  logger.warning(
247
  "model_version selected as V1 so changing emb_mode from 'cls_and_gene' to 'cell_and_gene' as V1 models do not have a <cls> token."
248
- )
249
-
250
- self.validate_options()
251
-
252
  # load token dictionary (Ensembl IDs:token)
253
  if self.token_dictionary_file is None:
254
  token_dictionary_file = TOKEN_DICTIONARY_FILE
 
231
  self.nproc = nproc
232
  self.model_version = model_version
233
  self.token_dictionary_file = token_dictionary_file
234
+ self.clear_mem_ncells = clear_mem_ncells
235
+
236
+ self.validate_options()
237
 
238
  if self.model_version == "V1":
239
  from . import TOKEN_DICTIONARY_FILE_30M
 
247
  self.emb_mode = "cell_and_gene"
248
  logger.warning(
249
  "model_version selected as V1 so changing emb_mode from 'cls_and_gene' to 'cell_and_gene' as V1 models do not have a <cls> token."
250
+ )
251
+
 
 
252
  # load token dictionary (Ensembl IDs:token)
253
  if self.token_dictionary_file is None:
254
  token_dictionary_file = TOKEN_DICTIONARY_FILE