Christina Theodoris committed on
Commit d319fef · 1 Parent(s): 31bf641

update with V2 models

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. MANIFEST.in +9 -4
  2. README.md +9 -16
  3. config.json +7 -7
  4. examples/cell_classification.ipynb +6 -7
  5. examples/extract_and_plot_cell_embeddings.ipynb +6 -7
  6. examples/gene_classification.ipynb +10 -9
  7. examples/in_silico_perturbation.ipynb +10 -13
  8. examples/multitask_cell_classification.ipynb +3 -3
  9. examples/tokenizing_scRNAseq_data.ipynb +4 -8
  10. fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json +0 -24
  11. fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin +0 -3
  12. fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json +0 -35
  13. fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/optimizer.pt +0 -3
  14. fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin +0 -3
  15. fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/rng_state.pth +0 -3
  16. fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/scheduler.pt +0 -3
  17. fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json +0 -150
  18. fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/training_args.bin +0 -3
  19. geneformer/__init__.py +9 -4
  20. geneformer/classifier.py +13 -2
  21. geneformer/emb_extractor.py +15 -0
  22. geneformer/ensembl_mapping_dict_gc95M.pkl +0 -3
  23. geneformer/evaluation_utils.py +3 -8
  24. geneformer/gene_median_dictionary_gc95M.pkl +0 -3
  25. geneformer/gene_name_id_dict_gc95M.pkl +0 -3
  26. geneformer/in_silico_perturber.py +20 -0
  27. geneformer/in_silico_perturber_stats.py +12 -1
  28. geneformer/perturber_utils.py +28 -65
  29. geneformer/token_dictionary_gc95M.pkl +0 -3
  30. geneformer/tokenizer.py +24 -8
  31. generation_config.json +1 -1
  32. gf-12L-30M-i2048/config.json +0 -23
  33. gf-12L-30M-i2048/pytorch_model.bin +0 -3
  34. gf-12L-30M-i2048/training_args.bin +0 -3
  35. gf-12L-95M-i4096/config.json +0 -24
  36. gf-12L-95M-i4096/generation_config.json +0 -5
  37. gf-12L-95M-i4096/model.safetensors +0 -3
  38. gf-12L-95M-i4096/training_args.bin +0 -3
  39. gf-12L-95M-i4096_CLcancer/config.json +0 -25
  40. gf-12L-95M-i4096_CLcancer/generation_config.json +0 -5
  41. gf-12L-95M-i4096_CLcancer/model.safetensors +0 -3
  42. gf-12L-95M-i4096_CLcancer/training_args.bin +0 -3
  43. gf-20L-95M-i4096/config.json +0 -24
  44. gf-20L-95M-i4096/generation_config.json +0 -5
  45. gf-20L-95M-i4096/model.safetensors +0 -3
  46. gf-20L-95M-i4096/training_args.bin +0 -3
  47. gf-6L-30M-i2048/config.json +0 -23
  48. gf-6L-30M-i2048/model.safetensors +0 -3
  49. gf-6L-30M-i2048/pytorch_model.bin +0 -3
  50. gf-6L-30M-i2048/training_args.bin +0 -3
MANIFEST.in CHANGED
@@ -1,4 +1,9 @@
1
- include geneformer/gene_median_dictionary_gc95M.pkl
2
- include geneformer/gene_name_id_dict_gc95M.pkl
3
- include geneformer/ensembl_mapping_dict_gc95M.pkl
4
- include geneformer/token_dictionary_gc95M.pkl
 
 
 
 
 
 
1
+ include geneformer/gene_median_dictionary_gc104m.pkl
2
+ include geneformer/gene_name_id_dict_gc104m.pkl
3
+ include geneformer/ensembl_mapping_dict_gc104m.pkl
4
+ include geneformer/token_dictionary_gc104m.pkl
5
+
6
+ include geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30m.pkl
7
+ include geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30m.pkl
8
+ include geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30m.pkl
9
+ include geneformer/gene_dictionaries_30m/token_dictionary_gc30m.pkl
README.md CHANGED
@@ -9,35 +9,28 @@ tags:
9
  Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
10
 
11
  - See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies.
12
- - See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies.
13
  - See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
14
 
15
  # Model Description
16
- Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Then, in April 2024, Geneformer was pretrained on ~95 million non-cancer transcriptomes, followed by continual learning on ~14 million cancer transcriptomes to yield a cancer domain-tuned model.
17
 
18
- Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
19
 
20
  The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
21
 
22
  We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
23
 
24
- During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
25
 
26
  The repository includes the following pretrained models:
27
 
28
- L=layers\
29
- M=millions of cells used for pretraining\
30
- i=input size\
31
- (pretraining date)
32
 
33
- - GF-6L-30M-i2048 (June 2021)
34
- - GF-12L-30M-i2048 (June 2021)
35
- - GF-12L-95M-i4096 (April 2024)
36
- - GF-20L-95M-i4096 (April 2024)
37
 
38
- The current default model in the main directory of the repository is GF-12L-95M-i4096.
39
-
40
- The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer.
41
 
42
  # Application
43
  The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
@@ -87,7 +80,7 @@ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main
87
 
88
  Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
89
 
90
- Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
91
 
92
  # Citations
93
  - C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
 
9
  Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
10
 
11
  - See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies.
12
+ - See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model, now trained on ~104 million transcriptomes, and our continual learning, multitask learning, and quantization strategies.
13
  - See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
14
 
15
  # Model Description
16
+ Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer V1 was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million human single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. The current Geneformer V2 is pretrained on ~104 million human single cell transcriptomes (non-cancer). The cancer continual learning V2 variant was continually pretrained on ~14 million cancer transcriptomes to yield a cancer domain-tuned model.
17
 
18
+ Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus (-30M for V1, -104M for V2). The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
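A minimal sketch of the rank value encoding idea, for illustration only (the actual implementation is `geneformer.TranscriptomeTokenizer`; the gene names and corpus-wide normalization factors below are hypothetical):

```python
import numpy as np

def rank_value_encode(cell_counts, corpus_norm_factors, gene_ids):
    # Scale each gene's expression in this cell by a corpus-wide normalization
    # factor (hypothetical values standing in for the nonzero medians computed
    # across the pretraining corpus), then rank genes in descending order.
    scaled = (cell_counts / cell_counts.sum()) / corpus_norm_factors
    order = np.argsort(-scaled)
    return [gene_ids[i] for i in order if cell_counts[i] > 0]

cell_counts = np.array([120.0, 3.0, 0.0, 15.0])
corpus_norm_factors = np.array([200.0, 0.5, 1.0, 10.0])  # housekeeping gene first
gene_ids = ["GAPDH", "TF_A", "GENE_C", "GENE_D"]
print(rank_value_encode(cell_counts, corpus_norm_factors, gene_ids))
# ['TF_A', 'GENE_D', 'GAPDH'] -- the lowly expressed but corpus-rare transcription
# factor outranks the ubiquitously highly expressed housekeeping gene
```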
19
 
20
  The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
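A simplified sketch of the masked learning objective described above (assumes PyTorch tensors of token ids; the production pipeline uses a Hugging Face data collator, and the standard 80/10/10 mask/random/keep split is omitted here):

```python
import torch

def mask_for_mlm(input_ids, mask_token_id, mask_prob=0.15):
    # Labels keep the original token only at masked positions; -100 is
    # ignored by the cross-entropy loss used for masked language modeling.
    labels = input_ids.clone()
    mask = torch.rand(input_ids.shape) < mask_prob
    labels[~mask] = -100
    corrupted = input_ids.clone()
    corrupted[mask] = mask_token_id  # model must recover the gene from context
    return corrupted, labels
```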
21
 
22
  We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
23
 
24
+ During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational AI model pretrained on a large-scale corpus of human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
25
 
26
  The repository includes the following pretrained models:
27
 
28
+ - Geneformer-V1-10M: original model trained June 2021 on ~30M human single cell transcriptomes, 10M parameters, input size 2048, vocabulary ~25K protein-coding or non-coding RNA genes
29
+ - Geneformer-V2-104M and Geneformer-V2-316M: updated model trained Dec 2024 on ~104M human single cell transcriptomes, 104M or 316M parameters, input size 4096, vocabulary ~20K protein-coding genes
 
 
30
 
31
+ The current default model in the main directory of the repository is Geneformer-V2-316M.
 
 
 
32
 
33
+ The repository also contains fine-tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, Geneformer-V2-104M_CLcancer.
 
 
34
 
35
  # Application
36
  The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
 
80
 
81
  Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
82
 
83
+ Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.). Importantly, as usual for deep learning models, there are no uniformly applicable default hyperparameters for Geneformer.
84
 
85
  # Citations
86
  - C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
config.json CHANGED
@@ -2,22 +2,22 @@
2
  "architectures": [
3
  "BertForMaskedLM"
4
  ],
5
- "attention_probs_dropout_prob": 0.02,
6
  "classifier_dropout": null,
7
  "hidden_act": "relu",
8
- "hidden_dropout_prob": 0.02,
9
- "hidden_size": 512,
10
  "initializer_range": 0.02,
11
- "intermediate_size": 1024,
12
  "layer_norm_eps": 1e-12,
13
  "max_position_embeddings": 4096,
14
  "model_type": "bert",
15
- "num_attention_heads": 8,
16
- "num_hidden_layers": 12,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
  "torch_dtype": "float32",
20
- "transformers_version": "4.37.1",
21
  "type_vocab_size": 2,
22
  "use_cache": true,
23
  "vocab_size": 20275
 
2
  "architectures": [
3
  "BertForMaskedLM"
4
  ],
5
+ "attention_probs_dropout_prob": 0.1,
6
  "classifier_dropout": null,
7
  "hidden_act": "relu",
8
+ "hidden_dropout_prob": 0.1,
9
+ "hidden_size": 1152,
10
  "initializer_range": 0.02,
11
+ "intermediate_size": 4608,
12
  "layer_norm_eps": 1e-12,
13
  "max_position_embeddings": 4096,
14
  "model_type": "bert",
15
+ "num_attention_heads": 18,
16
+ "num_hidden_layers": 18,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
  "torch_dtype": "float32",
20
+ "transformers_version": "4.44.2",
21
  "type_vocab_size": 2,
22
  "use_cache": true,
23
  "vocab_size": 20275
examples/cell_classification.ipynb CHANGED
@@ -13,7 +13,7 @@
13
  "id": "1792e51c-86c3-406f-be5a-273c4e4aec20",
14
  "metadata": {},
15
  "source": [
16
- "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization."
17
  ]
18
  },
19
  {
@@ -69,9 +69,7 @@
69
  " \"seed\": 73,\n",
70
  "}\n",
71
  "\n",
72
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
73
- "# (otherwise the Classifier will use the current default model dictionary)\n",
74
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
75
  "cc = Classifier(classifier=\"cell\",\n",
76
  " cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n",
77
  " filter_data=filter_data_dict,\n",
@@ -80,6 +78,7 @@
80
  " freeze_layers = 2,\n",
81
  " num_crossval_splits = 1,\n",
82
  " forward_batch_size=200,\n",
 
83
  " nproc=16)"
84
  ]
85
  },
@@ -264,8 +263,8 @@
264
  " \"train\": train_ids,\n",
265
  " \"eval\": eval_ids}\n",
266
  "\n",
267
- "# Example 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors\n",
268
- "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n",
269
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n",
270
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
271
  " output_directory=output_dir,\n",
@@ -450,7 +449,7 @@
450
  "name": "python",
451
  "nbconvert_exporter": "python",
452
  "pygments_lexer": "ipython3",
453
- "version": "3.10.15"
454
  }
455
  },
456
  "nbformat": 4,
 
13
  "id": "1792e51c-86c3-406f-be5a-273c4e4aec20",
14
  "metadata": {},
15
  "source": [
16
+ "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters."
17
  ]
18
  },
19
  {
 
69
  " \"seed\": 73,\n",
70
  "}\n",
71
  "\n",
72
+ "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
 
 
73
  "cc = Classifier(classifier=\"cell\",\n",
74
  " cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n",
75
  " filter_data=filter_data_dict,\n",
 
78
  " freeze_layers = 2,\n",
79
  " num_crossval_splits = 1,\n",
80
  " forward_batch_size=200,\n",
81
+ " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
82
  " nproc=16)"
83
  ]
84
  },
 
263
  " \"train\": train_ids,\n",
264
  " \"eval\": eval_ids}\n",
265
  "\n",
266
+ "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n",
267
+ "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n",
268
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n",
269
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
270
  " output_directory=output_dir,\n",
 
449
  "name": "python",
450
  "nbconvert_exporter": "python",
451
  "pygments_lexer": "ipython3",
452
+ "version": "3.10.13"
453
  }
454
  },
455
  "nbformat": 4,
examples/extract_and_plot_cell_embeddings.ipynb CHANGED
@@ -18,8 +18,7 @@
18
  "outputs": [],
19
  "source": [
20
  "# initiate EmbExtractor\n",
21
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
22
- "# (otherwise the EmbExtractor will use the current default model dictionary)\n",
23
  "embex = EmbExtractor(model_type=\"CellClassifier\",\n",
24
  " num_classes=3,\n",
25
  " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
@@ -28,13 +27,13 @@
28
  " emb_label=[\"disease\",\"cell_type\"],\n",
29
  " labels_to_plot=[\"disease\"],\n",
30
  " forward_batch_size=200,\n",
31
- " nproc=16,\n",
32
- " token_dictionary_file=\"./gene_dictionaries_30m/token_dictionary_gc30M.pkl\") # change from current default dictionary for 30M model series\n",
33
  "\n",
34
  "# extracts embedding from input data\n",
35
  "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
36
- "# example dataset for 30M model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
37
- "embs = embex.extract_embs(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
38
  " \"path/to/input_data/\",\n",
39
  " \"path/to/output_directory/\",\n",
40
  " \"output_prefix\")\n"
@@ -132,7 +131,7 @@
132
  "name": "python",
133
  "nbconvert_exporter": "python",
134
  "pygments_lexer": "ipython3",
135
- "version": "3.10.15"
136
  }
137
  },
138
  "nbformat": 4,
 
18
  "outputs": [],
19
  "source": [
20
  "# initiate EmbExtractor\n",
21
+ "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
 
22
  "embex = EmbExtractor(model_type=\"CellClassifier\",\n",
23
  " num_classes=3,\n",
24
  " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
 
27
  " emb_label=[\"disease\",\"cell_type\"],\n",
28
  " labels_to_plot=[\"disease\"],\n",
29
  " forward_batch_size=200,\n",
30
+ " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
31
+ " nproc=16)\n",
32
  "\n",
33
  "# extracts embedding from input data\n",
34
  "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
35
+ "# example dataset for V1 model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
36
+ "embs = embex.extract_embs(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n",
37
  " \"path/to/input_data/\",\n",
38
  " \"path/to/output_directory/\",\n",
39
  " \"output_prefix\")\n"
 
131
  "name": "python",
132
  "nbconvert_exporter": "python",
133
  "pygments_lexer": "ipython3",
134
+ "version": "3.10.13"
135
  }
136
  },
137
  "nbformat": 4,
examples/gene_classification.ipynb CHANGED
@@ -13,7 +13,7 @@
13
  "id": "79539e95-2c9c-4162-835c-f0d158abb15d",
14
  "metadata": {},
15
  "source": [
16
- "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications."
17
  ]
18
  },
19
  {
@@ -71,15 +71,14 @@
71
  }
72
  ],
73
  "source": [
74
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
75
- "# (otherwise the Classifier will use the current default model dictionary)\n",
76
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
77
  "cc = Classifier(classifier=\"gene\",\n",
78
  " gene_class_dict = gene_class_dict,\n",
79
  " max_ncells = 10_000,\n",
80
  " freeze_layers = 4,\n",
81
  " num_crossval_splits = 5,\n",
82
  " forward_batch_size=200,\n",
 
83
  " nproc=16)"
84
  ]
85
  },
@@ -843,8 +842,8 @@
843
  }
844
  ],
845
  "source": [
846
- "# 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors\n",
847
- "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n",
848
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
849
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
850
  " output_directory=output_dir,\n",
@@ -1066,12 +1065,14 @@
1066
  }
1067
  ],
1068
  "source": [
 
1069
  "cc = Classifier(classifier=\"gene\",\n",
1070
  " gene_class_dict = gene_class_dict,\n",
1071
  " max_ncells = 10_000,\n",
1072
  " freeze_layers = 4,\n",
1073
  " num_crossval_splits = 0,\n",
1074
  " forward_batch_size=200,\n",
 
1075
  " nproc=16)"
1076
  ]
1077
  },
@@ -1218,8 +1219,8 @@
1218
  }
1219
  ],
1220
  "source": [
1221
- "# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n",
1222
- "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\",\n",
1223
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
1224
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
1225
  " output_directory=output_dir,\n",
@@ -1243,7 +1244,7 @@
1243
  "name": "python",
1244
  "nbconvert_exporter": "python",
1245
  "pygments_lexer": "ipython3",
1246
- "version": "3.10.15"
1247
  }
1248
  },
1249
  "nbformat": 4,
 
13
  "id": "79539e95-2c9c-4162-835c-f0d158abb15d",
14
  "metadata": {},
15
  "source": [
16
+ "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters."
17
  ]
18
  },
19
  {
 
71
  }
72
  ],
73
  "source": [
74
+ "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
 
 
75
  "cc = Classifier(classifier=\"gene\",\n",
76
  " gene_class_dict = gene_class_dict,\n",
77
  " max_ncells = 10_000,\n",
78
  " freeze_layers = 4,\n",
79
  " num_crossval_splits = 5,\n",
80
  " forward_batch_size=200,\n",
81
+ " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
82
  " nproc=16)"
83
  ]
84
  },
 
842
  }
843
  ],
844
  "source": [
845
+ "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n",
846
+ "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n",
847
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
848
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
849
  " output_directory=output_dir,\n",
 
1065
  }
1066
  ],
1067
  "source": [
1068
+ "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
1069
  "cc = Classifier(classifier=\"gene\",\n",
1070
  " gene_class_dict = gene_class_dict,\n",
1071
  " max_ncells = 10_000,\n",
1072
  " freeze_layers = 4,\n",
1073
  " num_crossval_splits = 0,\n",
1074
  " forward_batch_size=200,\n",
1075
+ " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
1076
  " nproc=16)"
1077
  ]
1078
  },
 
1219
  }
1220
  ],
1221
  "source": [
1222
+ "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n",
1223
+ "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n",
1224
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
1225
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
1226
  " output_directory=output_dir,\n",
 
1244
  "name": "python",
1245
  "nbconvert_exporter": "python",
1246
  "pygments_lexer": "ipython3",
1247
+ "version": "3.10.13"
1248
  }
1249
  },
1250
  "nbformat": 4,
examples/in_silico_perturbation.ipynb CHANGED
@@ -39,9 +39,7 @@
39
  "\n",
40
  "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
41
  "\n",
42
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
43
- "# (otherwise the EmbExtractor will use the current default model dictionary)\n",
44
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
45
  "embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
46
  " num_classes=3,\n",
47
  " filter_data=filter_data_dict,\n",
@@ -49,6 +47,7 @@
49
  " emb_layer=0,\n",
50
  " summary_stat=\"exact_mean\",\n",
51
  " forward_batch_size=256,\n",
 
52
  " nproc=16)\n",
53
  "\n",
54
  "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
@@ -67,9 +66,7 @@
67
  },
68
  "outputs": [],
69
  "source": [
70
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
71
- "# (otherwise the InSilicoPerturber will use the current default model dictionary)\n",
72
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
73
  "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
74
  " perturb_rank_shift=None,\n",
75
  " genes_to_perturb=\"all\",\n",
@@ -77,7 +74,7 @@
77
  " anchor_gene=None,\n",
78
  " model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
79
  " num_classes=3,\n",
80
- " emb_mode=\"cell\",\n",
81
  " cell_emb_style=\"mean_pool\",\n",
82
  " filter_data=filter_data_dict,\n",
83
  " cell_states_to_model=cell_states_to_model,\n",
@@ -85,6 +82,7 @@
85
  " max_ncells=2000,\n",
86
  " emb_layer=0,\n",
87
  " forward_batch_size=400,\n",
 
88
  " nproc=16)"
89
  ]
90
  },
@@ -97,7 +95,7 @@
97
  "source": [
98
  "# outputs intermediate files from in silico perturbation\n",
99
  "\n",
100
- "isp.perturb_data(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
101
  " \"path/to/input_data\",\n",
102
  " \"path/to/isp_output_directory\",\n",
103
  " \"output_prefix\")"
@@ -110,14 +108,13 @@
110
  "metadata": {},
111
  "outputs": [],
112
  "source": [
113
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
114
- "# (otherwise the InSilicoPerturberStats will use the current default model dictionary)\n",
115
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
116
  "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
117
  " genes_perturbed=\"all\",\n",
118
  " combos=0,\n",
119
  " anchor_gene=None,\n",
120
- " cell_states_to_model=cell_states_to_model)"
 
121
  ]
122
  },
123
  {
@@ -151,7 +148,7 @@
151
  "name": "python",
152
  "nbconvert_exporter": "python",
153
  "pygments_lexer": "ipython3",
154
- "version": "3.10.15"
155
  }
156
  },
157
  "nbformat": 4,
 
39
  "\n",
40
  "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
41
  "\n",
42
+ "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
 
 
43
  "embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
44
  " num_classes=3,\n",
45
  " filter_data=filter_data_dict,\n",
 
47
  " emb_layer=0,\n",
48
  " summary_stat=\"exact_mean\",\n",
49
  " forward_batch_size=256,\n",
50
+ " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
51
  " nproc=16)\n",
52
  "\n",
53
  "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
 
66
  },
67
  "outputs": [],
68
  "source": [
69
+ "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
 
 
70
  "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
71
  " perturb_rank_shift=None,\n",
72
  " genes_to_perturb=\"all\",\n",
 
74
  " anchor_gene=None,\n",
75
  " model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
76
  " num_classes=3,\n",
77
+ " emb_mode=\"cell\", # OF NOTE: SET TO \"CELL\" FOR V1 MODEL. FOR V2, SHOULD BE \"CLS\" (current default).\n",
78
  " cell_emb_style=\"mean_pool\",\n",
79
  " filter_data=filter_data_dict,\n",
80
  " cell_states_to_model=cell_states_to_model,\n",
 
82
  " max_ncells=2000,\n",
83
  " emb_layer=0,\n",
84
  " forward_batch_size=400,\n",
85
+ " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
86
  " nproc=16)"
87
  ]
88
  },
 
95
  "source": [
96
  "# outputs intermediate files from in silico perturbation\n",
97
  "\n",
98
+ "isp.perturb_data(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n",
99
  " \"path/to/input_data\",\n",
100
  " \"path/to/isp_output_directory\",\n",
101
  " \"output_prefix\")"
 
108
  "metadata": {},
109
  "outputs": [],
110
  "source": [
111
+ "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
 
 
112
  "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
113
  " genes_perturbed=\"all\",\n",
114
  " combos=0,\n",
115
  " anchor_gene=None,\n",
116
+ " cell_states_to_model=cell_states_to_model,\n",
117
+ " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL SINCE V1 WAS USED FOR IN SILICO PERTURBATION ABOVE)"
118
  ]
119
  },
120
  {
 
148
  "name": "python",
149
  "nbconvert_exporter": "python",
150
  "pygments_lexer": "ipython3",
151
+ "version": "3.10.13"
152
  }
153
  },
154
  "nbformat": 4,
examples/multitask_cell_classification.ipynb CHANGED
@@ -286,7 +286,7 @@
286
  " filter_data_dict=filter_data_dict,\n",
287
  " max_ncells=1000, # Number of cells to extract embeddings for\n",
288
  " emb_layer=0, # Use the second to last layer\n",
289
- " emb_mode = \"cls\",\n",
290
  " summary_stat=\"exact_mean\",\n",
291
  " forward_batch_size=8, # Adjust based on available GPU memory\n",
292
  " nproc=4\n",
@@ -324,7 +324,7 @@
324
  " perturb_type=perturb_type,\n",
325
  " genes_to_perturb=\"all\", # Perturb all genes\n",
326
  " model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
327
- " emb_mode=\"cls\", # Use CLS token embedding\n",
328
  " cell_states_to_model=cell_states_to_model,\n",
329
  " state_embs_dict=state_embs_dict,\n",
330
  " max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
@@ -412,7 +412,7 @@
412
  "name": "python",
413
  "nbconvert_exporter": "python",
414
  "pygments_lexer": "ipython3",
415
- "version": "3.11.5"
416
  }
417
  },
418
  "nbformat": 4,
 
286
  " filter_data_dict=filter_data_dict,\n",
287
  " max_ncells=1000, # Number of cells to extract embeddings for\n",
288
  " emb_layer=0, # Use the second to last layer\n",
289
+ " emb_mode = \"cls\", # Use CLS token embedding for V2 model\n",
290
  " summary_stat=\"exact_mean\",\n",
291
  " forward_batch_size=8, # Adjust based on available GPU memory\n",
292
  " nproc=4\n",
 
324
  " perturb_type=perturb_type,\n",
325
  " genes_to_perturb=\"all\", # Perturb all genes\n",
326
  " model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
327
+ " emb_mode=\"cls\", # Use CLS token embedding for V2 model\n",
328
  " cell_states_to_model=cell_states_to_model,\n",
329
  " state_embs_dict=state_embs_dict,\n",
330
  " max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
 
412
  "name": "python",
413
  "nbconvert_exporter": "python",
414
  "pygments_lexer": "ipython3",
415
+ "version": "3.10.13"
416
  }
417
  },
418
  "nbformat": 4,
examples/tokenizing_scRNAseq_data.ipynb CHANGED
@@ -34,12 +34,8 @@
34
  "metadata": {},
35
  "source": [
36
  "**********************************************************************************************************\n",
37
- "#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n",
38
- "#### 95M: current defaults; 30M: https://huggingface.co/ctheodoris/Geneformer/tree/main/geneformer/gene_dictionaries_30m\n",
39
- "\n",
40
- "#### ADDITIONALLY:\n",
41
- "#### The 95M model series require the special_token argument to be set to True and model_input_size to be 4096. (current defaults)\n",
42
- "#### The 30M model series require the special_token argument to be set to False and the model_input_size to be 2048."
43
  ]
44
  },
45
  {
@@ -59,7 +55,7 @@
59
  "metadata": {},
60
  "outputs": [],
61
  "source": [
62
- "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)\n",
63
  "tk.tokenize_data(\"loom_data_directory\", \n",
64
  " \"output_directory\", \n",
65
  " \"output_prefix\", \n",
@@ -83,7 +79,7 @@
83
  "name": "python",
84
  "nbconvert_exporter": "python",
85
  "pygments_lexer": "ipython3",
86
- "version": "3.10.15"
87
  }
88
  },
89
  "nbformat": 4,
 
34
  "metadata": {},
35
  "source": [
36
  "**********************************************************************************************************\n",
37
+ "#### OF NOTE: Please ensure the correct token dictionary, gene median file, special token setting, and model input size is used for the correct model version.\n",
38
+ "#### Current defaults are for V2 model series. To auto-select the correct settings for V1, set model_version argument to \"V1\"."
 
 
 
 
39
  ]
40
  },
41
  {
 
55
  "metadata": {},
56
  "outputs": [],
57
  "source": [
58
+ "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16) # for V1 model, set model_version=\"V1\"\n",
59
  "tk.tokenize_data(\"loom_data_directory\", \n",
60
  " \"output_directory\", \n",
61
  " \"output_prefix\", \n",
 
79
  "name": "python",
80
  "nbconvert_exporter": "python",
81
  "pygments_lexer": "ipython3",
82
+ "version": "3.10.13"
83
  }
84
  },
85
  "nbformat": 4,
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json DELETED
@@ -1,24 +0,0 @@
1
- {
2
- "architectures": [
3
- "BertForMaskedLM"
4
- ],
5
- "attention_probs_dropout_prob": 0.02,
6
- "classifier_dropout": null,
7
- "hidden_act": "relu",
8
- "hidden_dropout_prob": 0.02,
9
- "hidden_size": 512,
10
- "initializer_range": 0.02,
11
- "intermediate_size": 1024,
12
- "layer_norm_eps": 1e-12,
13
- "max_position_embeddings": 4096,
14
- "model_type": "bert",
15
- "num_attention_heads": 8,
16
- "num_hidden_layers": 12,
17
- "pad_token_id": 0,
18
- "position_embedding_type": "absolute",
19
- "torch_dtype": "float32",
20
- "transformers_version": "4.37.2",
21
- "type_vocab_size": 2,
22
- "use_cache": true,
23
- "vocab_size": 20275
24
- }
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4
3
- size 152363342
 
 
 
 
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json DELETED
@@ -1,35 +0,0 @@
1
- {
2
- "_name_or_path": "/n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/",
3
- "architectures": [
4
- "BertForSequenceClassification"
5
- ],
6
- "attention_probs_dropout_prob": 0.02,
7
- "gradient_checkpointing": false,
8
- "hidden_act": "relu",
9
- "hidden_dropout_prob": 0.02,
10
- "hidden_size": 256,
11
- "id2label": {
12
- "0": "LABEL_0",
13
- "1": "LABEL_1",
14
- "2": "LABEL_2"
15
- },
16
- "initializer_range": 0.02,
17
- "intermediate_size": 512,
18
- "label2id": {
19
- "LABEL_0": 0,
20
- "LABEL_1": 1,
21
- "LABEL_2": 2
22
- },
23
- "layer_norm_eps": 1e-12,
24
- "max_position_embeddings": 2048,
25
- "model_type": "bert",
26
- "num_attention_heads": 4,
27
- "num_hidden_layers": 6,
28
- "pad_token_id": 0,
29
- "position_embedding_type": "absolute",
30
- "problem_type": "single_label_classification",
31
- "transformers_version": "4.6.0",
32
- "type_vocab_size": 2,
33
- "use_cache": true,
34
- "vocab_size": 25426
35
- }
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/optimizer.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3ced328122d57a847fc3914732337674500e259a82e64437c67b4954ac2f4e07
3
- size 73720721
 
 
 
 
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:12ead3bad8cf4b853bac87eadeb79c9308ae492e9d29f32da1a2c85e8586108d
3
- size 41115113
 
 
 
 
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/rng_state.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:dd8c0a739c2fe6a9ab4bb8f4a62ad8d7b879efcdceb5376b128a2040ff1bbe62
3
- size 14657
 
 
 
 
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/scheduler.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3d0797845afdae765a74ddab7966e0e1837617fd8171af8ee6aef9dedce248f2
3
- size 623
 
 
 
 
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json DELETED
@@ -1,150 +0,0 @@
1
- {
2
- "best_metric": 0.39658036828041077,
3
- "best_model_checkpoint": "/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/models/220224_geneformer_27M_SequenceClassifier_tuning_hCMdCM_L2048_B12_LR1e-05_LScosine_WU500_E1_Oadamw_F2/run-8429a330/checkpoint-7020",
4
- "epoch": 0.9,
5
- "global_step": 7020,
6
- "is_hyper_param_search": true,
7
- "is_local_process_zero": true,
8
- "is_world_process_zero": true,
9
- "log_history": [
10
- {
11
- "epoch": 0.1,
12
- "learning_rate": 0.00034606438343856935,
13
- "loss": 0.911,
14
- "step": 780
15
- },
16
- {
17
- "epoch": 0.1,
18
- "eval_accuracy": 0.4531576503366612,
19
- "eval_loss": 1.4550466537475586,
20
- "eval_runtime": 66.5164,
21
- "eval_samples_per_second": 259.004,
22
- "step": 780
23
- },
24
- {
25
- "epoch": 0.2,
26
- "learning_rate": 0.0006921287668771387,
27
- "loss": 0.6273,
28
- "step": 1560
29
- },
30
- {
31
- "epoch": 0.2,
32
- "eval_accuracy": 0.5953680055723242,
33
- "eval_loss": 0.846651554107666,
34
- "eval_runtime": 66.1267,
35
- "eval_samples_per_second": 260.53,
36
- "step": 1560
37
- },
38
- {
39
- "epoch": 0.3,
40
- "learning_rate": 0.0007330550166223805,
41
- "loss": 0.5592,
42
- "step": 2340
43
- },
44
- {
45
- "epoch": 0.3,
46
- "eval_accuracy": 0.5935105641978176,
47
- "eval_loss": 1.0599186420440674,
48
- "eval_runtime": 66.2608,
49
- "eval_samples_per_second": 260.003,
50
- "step": 2340
51
- },
52
- {
53
- "epoch": 0.4,
54
- "learning_rate": 0.0006283471571048975,
55
- "loss": 0.3714,
56
- "step": 3120
57
- },
58
- {
59
- "epoch": 0.4,
60
- "eval_accuracy": 0.686324587880195,
61
- "eval_loss": 1.184874415397644,
62
- "eval_runtime": 66.1411,
63
- "eval_samples_per_second": 260.473,
64
- "step": 3120
65
- },
66
- {
67
- "epoch": 0.5,
68
- "learning_rate": 0.0005236392975874146,
69
- "loss": 0.2976,
70
- "step": 3900
71
- },
72
- {
73
- "epoch": 0.5,
74
- "eval_accuracy": 0.7681100534014396,
75
- "eval_loss": 0.6318939328193665,
76
- "eval_runtime": 66.3309,
77
- "eval_samples_per_second": 259.728,
78
- "step": 3900
79
- },
80
- {
81
- "epoch": 0.6,
82
- "learning_rate": 0.0004189314380699318,
83
- "loss": 0.2564,
84
- "step": 4680
85
- },
86
- {
87
- "epoch": 0.6,
88
- "eval_accuracy": 0.7807058277223126,
89
- "eval_loss": 0.7283642888069153,
90
- "eval_runtime": 66.3416,
91
- "eval_samples_per_second": 259.686,
92
- "step": 4680
93
- },
94
- {
95
- "epoch": 0.7,
96
- "learning_rate": 0.0003142235785524487,
97
- "loss": 0.2336,
98
- "step": 5460
99
- },
100
- {
101
- "epoch": 0.7,
102
- "eval_accuracy": 0.8563965637334572,
103
- "eval_loss": 0.5184123516082764,
104
- "eval_runtime": 66.3416,
105
- "eval_samples_per_second": 259.686,
106
- "step": 5460
107
- },
108
- {
109
- "epoch": 0.8,
110
- "learning_rate": 0.0002095157190349659,
111
- "loss": 0.1731,
112
- "step": 6240
113
- },
114
- {
115
- "epoch": 0.8,
116
- "eval_accuracy": 0.8288832133735778,
117
- "eval_loss": 0.5823884010314941,
118
- "eval_runtime": 66.1535,
119
- "eval_samples_per_second": 260.425,
120
- "step": 6240
121
- },
122
- {
123
- "epoch": 0.9,
124
- "learning_rate": 0.00010480785951748295,
125
- "loss": 0.1451,
126
- "step": 7020
127
- },
128
- {
129
- "epoch": 0.9,
130
- "eval_accuracy": 0.886812166241003,
131
- "eval_loss": 0.39658036828041077,
132
- "eval_runtime": 66.3555,
133
- "eval_samples_per_second": 259.632,
134
- "step": 7020
135
- }
136
- ],
137
- "max_steps": 7800,
138
- "num_train_epochs": 1,
139
- "total_flos": 0,
140
- "trial_name": null,
141
- "trial_params": {
142
- "learning_rate": 0.0008039341830649843,
143
- "lr_scheduler_type": "polynomial",
144
- "num_train_epochs": 1,
145
- "per_device_train_batch_size": 12,
146
- "seed": 73.15243080311434,
147
- "warmup_steps": 1812.6785581609881,
148
- "weight_decay": 0.2588277764570262
149
- }
150
- }
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/training_args.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4ffee119596c99b50a422b2f80103f4c44f7e25c2ea0e457fe224bad59f1f955
3
- size 2607
 
 
 
 
geneformer/__init__.py CHANGED
@@ -4,10 +4,15 @@ from pathlib import Path
4
 
5
  warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
6
 
7
- GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
8
- TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
9
- ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
10
- ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
 
 
 
 
 
11
 
12
  from . import (
13
  collator_for_classification,
 
4
 
5
  warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
6
 
7
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc104M.pkl"
8
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc104M.pkl"
9
+ ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc104M.pkl"
10
+ ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc104M.pkl"
11
+
12
+ GENE_MEDIAN_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl"
13
+ TOKEN_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/token_dictionary_gc30M.pkl"
14
+ ENSEMBL_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl"
15
+ ENSEMBL_MAPPING_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl"
16
 
17
  from . import (
18
  collator_for_classification,
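A small sketch of how the V2 and V1 dictionary constants defined above might be used (assumes the package is installed with these pickle files present; the exact dictionary sizes are approximate, per the README):

```python
import pickle

from geneformer import TOKEN_DICTIONARY_FILE, TOKEN_DICTIONARY_FILE_30M

# Default (V2) token dictionary: Ensembl ID -> token id (~20K protein-coding genes)
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
    gene_token_dict = pickle.load(f)

# V1 (30M-series) token dictionary, needed when working with V1 checkpoints
with open(TOKEN_DICTIONARY_FILE_30M, "rb") as f:
    gene_token_dict_30m = pickle.load(f)

print(len(gene_token_dict), len(gene_token_dict_30m))
```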
geneformer/classifier.py CHANGED
@@ -92,6 +92,7 @@ class Classifier:
92
  "no_eval": {bool},
93
  "stratify_splits_col": {None, str},
94
  "forward_batch_size": {int},
 
95
  "token_dictionary_file": {None, str},
96
  "nproc": {int},
97
  "ngpu": {int},
@@ -115,6 +116,7 @@ class Classifier:
115
  stratify_splits_col=None,
116
  no_eval=False,
117
  forward_batch_size=100,
 
118
  token_dictionary_file=None,
119
  nproc=4,
120
  ngpu=1,
@@ -191,6 +193,9 @@ class Classifier:
191
  | Otherwise, will perform eval during training.
192
  forward_batch_size : int
193
  | Batch size for forward pass (for evaluation, not training).
 
 
 
194
  token_dictionary_file : None, str
195
  | Default is to use token dictionary file from Geneformer
196
  | Otherwise, will load custom gene token dictionary.
@@ -225,14 +230,20 @@ class Classifier:
225
  self.stratify_splits_col = stratify_splits_col
226
  self.no_eval = no_eval
227
  self.forward_batch_size = forward_batch_size
 
228
  self.token_dictionary_file = token_dictionary_file
229
  self.nproc = nproc
230
  self.ngpu = ngpu
231
 
 
 
 
 
232
  if self.training_args is None:
233
  logger.warning(
234
  "Hyperparameter tuning is highly recommended for optimal results. "
235
- "No training_args provided; using default hyperparameters."
 
236
  )
237
 
238
  self.validate_options()
@@ -1319,7 +1330,7 @@ class Classifier:
1319
  ##### Evaluate the model #####
1320
  labels = id_class_dict.keys()
1321
  y_pred, y_true, logits_list = eu.classifier_predict(
1322
- model, self.classifier, eval_data, self.forward_batch_size
1323
  )
1324
  conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
1325
  y_pred, y_true, logits_list, num_classes, labels
 
92
  "no_eval": {bool},
93
  "stratify_splits_col": {None, str},
94
  "forward_batch_size": {int},
95
+ "model_version": {"V1", "V2"},
96
  "token_dictionary_file": {None, str},
97
  "nproc": {int},
98
  "ngpu": {int},
 
116
  stratify_splits_col=None,
117
  no_eval=False,
118
  forward_batch_size=100,
119
+ model_version="V2",
120
  token_dictionary_file=None,
121
  nproc=4,
122
  ngpu=1,
 
193
  | Otherwise, will perform eval during training.
194
  forward_batch_size : int
195
  | Batch size for forward pass (for evaluation, not training).
196
+ model_version : str
197
+ | To auto-select settings for a model version other than the current default.
198
+ | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
199
  token_dictionary_file : None, str
200
  | Default is to use token dictionary file from Geneformer
201
  | Otherwise, will load custom gene token dictionary.
 
230
  self.stratify_splits_col = stratify_splits_col
231
  self.no_eval = no_eval
232
  self.forward_batch_size = forward_batch_size
233
+ self.model_version = model_version
234
  self.token_dictionary_file = token_dictionary_file
235
  self.nproc = nproc
236
  self.ngpu = ngpu
237
 
238
+ if self.model_version == "V1":
239
+ from . import TOKEN_DICTIONARY_FILE_30M
240
+ self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
241
+
242
  if self.training_args is None:
243
  logger.warning(
244
  "Hyperparameter tuning is highly recommended for optimal results. "
245
+ "No training_args provided; using default hyperparameters. "
246
+ "Please note: these defaults are not recommended to be used uniformly across tasks."
247
  )
248
 
249
  self.validate_options()
 
1330
  ##### Evaluate the model #####
1331
  labels = id_class_dict.keys()
1332
  y_pred, y_true, logits_list = eu.classifier_predict(
1333
+ model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict
1334
  )
1335
  conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
1336
  y_pred, y_true, logits_list, num_classes, labels
geneformer/emb_extractor.py CHANGED
@@ -402,6 +402,7 @@ class EmbExtractor:
         "emb_label": {None, list},
         "labels_to_plot": {None, list},
         "forward_batch_size": {int},
+        "model_version": {"V1", "V2"},
         "token_dictionary_file": {None, str},
         "nproc": {int},
         "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
@@ -422,6 +423,7 @@
         forward_batch_size=100,
         nproc=4,
         summary_stat=None,
+        model_version="V2",
         token_dictionary_file=None,
     ):
         """
@@ -472,6 +474,9 @@
             | If mean or median, outputs only approximated mean or median embedding of input data.
             | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
             | Non-exact is slower but more memory-efficient.
+        model_version : str
+            | To auto-select settings for model version other than current default.
+            | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
         token_dictionary_file : Path
             | Default is the Geneformer token dictionary
             | Path to pickle file containing token dictionary (Ensembl ID:token).
@@ -502,6 +507,7 @@
         self.emb_layer = emb_layer
         self.emb_label = emb_label
         self.labels_to_plot = labels_to_plot
+        self.model_version = model_version
         self.token_dictionary_file = token_dictionary_file
         self.forward_batch_size = forward_batch_size
         self.nproc = nproc
@@ -512,6 +518,15 @@
         self.summary_stat = summary_stat
         self.exact_summary_stat = None

+        if self.model_version == "V1":
+            from . import TOKEN_DICTIONARY_FILE_30M
+            self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
+            if self.emb_mode == "cls":
+                self.emb_mode = "cell"
+                logger.warning(
+                    "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
+                )
+
         self.validate_options()

         # load token dictionary (Ensembl IDs:token)
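
A short usage sketch of the same option on EmbExtractor (arguments other than model_version and emb_mode are illustrative, not part of this diff):

from geneformer import EmbExtractor

# Hedged sketch: with model_version="V1" the extractor loads the 30M token
# dictionary and, since V1 models have no <cls> token, downgrades emb_mode
# from "cls" to "cell" with the warning added in this diff.
embex = EmbExtractor(
    model_type="Pretrained",
    emb_mode="cls",          # auto-switched to "cell" for V1
    max_ncells=1000,
    emb_layer=-1,
    forward_batch_size=100,
    nproc=4,
    model_version="V1",
)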
geneformer/ensembl_mapping_dict_gc95M.pkl DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0819bcbd869cfa14279449b037eb9ed1d09a91310e77bd1a19d927465030e95c
-size 3957652
geneformer/evaluation_utils.py CHANGED
@@ -20,20 +20,15 @@ from sklearn.metrics import (
 )
 from tqdm.auto import trange

-from . import TOKEN_DICTIONARY_FILE
 from .emb_extractor import make_colorbar

 logger = logging.getLogger(__name__)


-def preprocess_classifier_batch(cell_batch, max_len, label_name):
+def preprocess_classifier_batch(cell_batch, max_len, label_name, gene_token_dict):
     if max_len is None:
         max_len = max([len(i) for i in cell_batch["input_ids"]])

-    # load token dictionary (Ensembl IDs:token)
-    with open(TOKEN_DICTIONARY_FILE, "rb") as f:
-        gene_token_dict = pickle.load(f)
-
     def pad_label_example(example):
         example[label_name] = np.pad(
             example[label_name],
@@ -81,7 +76,7 @@ def py_softmax(vector):
     return e / e.sum()


-def classifier_predict(model, classifier_type, evalset, forward_batch_size):
+def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict):
     if classifier_type == "gene":
         label_name = "labels"
     elif classifier_type == "cell":
@@ -104,7 +99,7 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size):
         max_range = min(i + forward_batch_size, evalset_len)
         batch_evalset = evalset.select([i for i in range(i, max_range)])
         padded_batch = preprocess_classifier_batch(
-            batch_evalset, max_evalset_len, label_name
+            batch_evalset, max_evalset_len, label_name, gene_token_dict
        )
         padded_batch.set_format(type="torch")
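
Since preprocess_classifier_batch and classifier_predict no longer load the token dictionary themselves, any external caller now has to pass it in. A hedged sketch of the updated call (model and eval_data are placeholders prepared elsewhere):

import pickle

from geneformer import TOKEN_DICTIONARY_FILE
from geneformer import evaluation_utils as eu

# Load the token dictionary once and hand it to classifier_predict,
# matching the new signature introduced in this diff.
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
    gene_token_dict = pickle.load(f)

y_pred, y_true, logits_list = eu.classifier_predict(
    model, "cell", eval_data, 100, gene_token_dict  # model/eval_data: placeholders
)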
geneformer/gene_median_dictionary_gc95M.pkl DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a51c53f6a771d64508dfaf61529df70e394c53bd20856926117ae5d641a24bf5
-size 1512661
geneformer/gene_name_id_dict_gc95M.pkl DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:fabfa0c2f49c598c59ae432a32c3499a5908c033756c663b5e0cddf58deea8e1
-size 1660882
geneformer/in_silico_perturber.py CHANGED
@@ -72,6 +72,7 @@ class InSilicoPerturber:
         "max_ncells": {None, int},
         "cell_inds_to_perturb": {"all", dict},
         "emb_layer": {-1, 0},
+        "model_version": {"V1", "V2"},
         "token_dictionary_file": {None, str},
         "forward_batch_size": {int},
         "nproc": {int},
@@ -96,6 +97,7 @@
         emb_layer=-1,
         forward_batch_size=100,
         nproc=4,
+        model_version="V2",
         token_dictionary_file=None,
         clear_mem_ncells=1000,
     ):
@@ -184,6 +186,9 @@
             | Batch size for forward pass.
         nproc : int
             | Number of CPU processes to use.
+        model_version : str
+            | To auto-select settings for model version other than current default.
+            | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
         token_dictionary_file : Path
             | Path to pickle file containing token dictionary (Ensembl ID:token).
         clear_mem_ncells : int
@@ -224,9 +229,24 @@
         self.emb_layer = emb_layer
         self.forward_batch_size = forward_batch_size
         self.nproc = nproc
+        self.model_version = model_version
         self.token_dictionary_file = token_dictionary_file
         self.clear_mem_ncells = clear_mem_ncells

+        if self.model_version == "V1":
+            from . import TOKEN_DICTIONARY_FILE_30M
+            self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
+            if self.emb_mode == "cls":
+                self.emb_mode = "cell"
+                logger.warning(
+                    "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
+                )
+            if self.emb_mode == "cls_and_gene":
+                self.emb_mode = "cell_and_gene"
+                logger.warning(
+                    "model_version selected as V1 so changing emb_mode from 'cls_and_gene' to 'cell_and_gene' as V1 models do not have a <cls> token."
+                )
+
         self.validate_options()

         # load token dictionary (Ensembl IDs:token)
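
The same model_version switch on InSilicoPerturber, as a hedged sketch (the other arguments are illustrative, not part of this diff):

from geneformer import InSilicoPerturber

# With model_version="V1", the perturber loads the 30M token dictionary and
# maps the <cls>-based modes back to their V1 equivalents
# ("cls" -> "cell", "cls_and_gene" -> "cell_and_gene"), as added above.
isp = InSilicoPerturber(
    perturb_type="delete",
    emb_mode="cls_and_gene",   # auto-switched to "cell_and_gene" for V1
    emb_layer=-1,
    forward_batch_size=100,
    nproc=4,
    model_version="V1",
)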
geneformer/in_silico_perturber_stats.py CHANGED
@@ -676,6 +676,7 @@ class InSilicoPerturberStats:
         "anchor_gene": {None, str},
         "cell_states_to_model": {None, dict},
         "pickle_suffix": {None, str},
+        "model_version": {"V1", "V2"},
     }

     def __init__(
@@ -686,6 +687,7 @@
         anchor_gene=None,
         cell_states_to_model=None,
         pickle_suffix="_raw.pickle",
+        model_version="V2",
         token_dictionary_file=TOKEN_DICTIONARY_FILE,
         gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
     ):
@@ -713,7 +715,7 @@
             | analyzes data for anchor gene perturbed in combination with each other gene.
             | However, if combos=0 and anchor_gene="ENSG00000136574":
             | analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
-        cell_states_to_model: None, dict
+        cell_states_to_model : None, dict
             | Cell states to model if testing perturbations that achieve goal state change.
             | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
             | state_key: key specifying name of column in .dataset that defines the start/goal states
@@ -724,6 +726,9 @@
             | "start_state": "dcm",
             | "goal_state": "nf",
             | "alt_states": ["hcm", "other1", "other2"]}
+        model_version : str
+            | To auto-select settings for model version other than current default.
+            | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
         token_dictionary_file : Path
             | Path to pickle file containing token dictionary (Ensembl ID:token).
         gene_name_id_dictionary_file : Path
@@ -736,9 +741,15 @@
         self.anchor_gene = anchor_gene
         self.cell_states_to_model = cell_states_to_model
         self.pickle_suffix = pickle_suffix
+        self.model_version = model_version

         self.validate_options()

+        if self.model_version == "V1":
+            from . import ENSEMBL_DICTIONARY_FILE_30M, TOKEN_DICTIONARY_FILE_30M
+            token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
+            gene_name_id_dictionary_file = ENSEMBL_DICTIONARY_FILE_30M
+
         # load token dictionary (Ensembl IDs:token)
         with open(token_dictionary_file, "rb") as f:
             self.gene_token_dict = pickle.load(f)
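
And for InSilicoPerturberStats, a hedged sketch (mode and genes_perturbed values below are illustrative, not part of this diff):

from geneformer import InSilicoPerturberStats

# model_version="V1" switches both the token dictionary and the gene
# name/ID dictionary to their 30M-series files before they are read,
# per the lines added above.
ispstats = InSilicoPerturberStats(
    mode="goal_state_shift",
    genes_perturbed="all",
    combos=0,
    anchor_gene=None,
    model_version="V1",
)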
geneformer/perturber_utils.py CHANGED
@@ -17,11 +17,6 @@ from transformers import (
     BitsAndBytesConfig,
 )

-from . import (
-    TOKEN_DICTIONARY_FILE,
-    ENSEMBL_DICTIONARY_FILE,
-)
-
 logger = logging.getLogger(__name__)


@@ -127,7 +122,10 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
     output_hidden_states = (mode == "eval")

     # Quantization logic
-    if quantize:
+    if isinstance(quantize, dict):
+        quantize_config = quantize.get("bnb_config", None)
+        peft_config = quantize.get("peft_config", None)
+    elif quantize:
         if inference_only:
             quantize_config = BitsAndBytesConfig(load_in_8bit=True)
             peft_config = None
@@ -138,19 +136,22 @@
                 bnb_4bit_quant_type="nf4",
                 bnb_4bit_compute_dtype=torch.bfloat16,
             )
-            lora_config_params = {
-                "lora_alpha": 128,
-                "lora_dropout": 0.1,
-                "r": 64,
-                "bias": "none"
-            }
-
-            # Try with TokenClassification first, fallback to TOKEN_CLS if needed
             try:
-                peft_config = LoraConfig(**lora_config_params, task_type="TokenClassification")
-            except ValueError:
-                # Some versions use TOKEN_CLS instead of TokenClassification
-                peft_config = LoraConfig(**lora_config_params, task_type="TOKEN_CLS")
+                peft_config = LoraConfig(
+                    lora_alpha=128,
+                    lora_dropout=0.1,
+                    r=64,
+                    bias="none",
+                    task_type="TokenClassification",
+                )
+            except ValueError as e:
+                peft_config = LoraConfig(
+                    lora_alpha=128,
+                    lora_dropout=0.1,
+                    r=64,
+                    bias="none",
+                    task_type="TOKEN_CLS",
+                )
     else:
         quantize_config = None
         peft_config = None
@@ -187,14 +188,22 @@
         model.eval()

     # Handle device placement and PEFT
+    adapter_config_path = os.path.join(model_directory, "adapter_config.json")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     if not quantize:
         # Only move non-quantized models
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         model = model.to(device)
+    elif os.path.exists(adapter_config_path):
+        # If adapter files exist, load them into the model using PEFT's from_pretrained
+        model = PeftModel.from_pretrained(model, model_directory)
+        model = model.to(device)
+        print("loading lora weights")
     elif peft_config:
         # Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
         model.enable_input_require_grads()
         model = get_peft_model(model, peft_config)
+        model = model.to(device)

     return model

@@ -883,50 +892,4 @@ def validate_cell_states_to_model(cell_states_to_model):
             "'goal_state': 'nf', "
             "'alt_states': ['hcm', 'other1', 'other2']}"
         )
-        raise
-
-
-class GeneIdHandler:
-    def __init__(self, raise_errors=False):
-        def invert_dict(dict_obj):
-            return {v: k for k, v in dict_obj.items()}
-
-        self.raise_errors = raise_errors
-
-        with open(TOKEN_DICTIONARY_FILE, "rb") as f:
-            self.gene_token_dict = pickle.load(f)
-            self.token_gene_dict = invert_dict(self.gene_token_dict)
-
-        with open(ENSEMBL_DICTIONARY_FILE, "rb") as f:
-            self.id_gene_dict = pickle.load(f)
-            self.gene_id_dict = invert_dict(self.id_gene_dict)
-
-    def ens_to_token(self, ens_id):
-        if not self.raise_errors:
-            return self.gene_token_dict.get(ens_id, ens_id)
-        else:
-            return self.gene_token_dict[ens_id]
-
-    def token_to_ens(self, token):
-        if not self.raise_errors:
-            return self.token_gene_dict.get(token, token)
-        else:
-            return self.token_gene_dict[token]
-
-    def ens_to_symbol(self, ens_id):
-        if not self.raise_errors:
-            return self.gene_id_dict.get(ens_id, ens_id)
-        else:
-            return self.gene_id_dict[ens_id]
-
-    def symbol_to_ens(self, symbol):
-        if not self.raise_errors:
-            return self.id_gene_dict.get(symbol, symbol)
-        else:
-            return self.id_gene_dict[symbol]
-
-    def token_to_symbol(self, token):
-        return self.ens_to_symbol(self.token_to_ens(token))
-
-    def symbol_to_token(self, symbol):
-        return self.ens_to_token(self.symbol_to_ens(symbol))
+        raise
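
The reworked quantization logic also accepts a dict for quantize, so callers can supply their own BitsAndBytes and LoRA configurations. A hedged sketch (the model_type and directory path below are placeholders):

import torch
from peft import LoraConfig
from transformers import BitsAndBytesConfig

from geneformer import perturber_utils as pu

# Pass explicit configs instead of the bare True/False flag; load_model reads
# them from the "bnb_config" and "peft_config" keys handled in the diff above.
quantize = {
    "bnb_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    "peft_config": LoraConfig(
        lora_alpha=128, lora_dropout=0.1, r=64, bias="none", task_type="TOKEN_CLS"
    ),
}

model = pu.load_model(
    "CellClassifier",                  # placeholder model_type
    num_classes=3,
    model_directory="/path/to/model",  # placeholder path
    mode="eval",
    quantize=quantize,
)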
geneformer/token_dictionary_gc95M.pkl DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:67c445f4385127adfc48dcc072320cd65d6822829bf27dd38070e6e787bc597f
-size 425590
geneformer/tokenizer.py CHANGED
@@ -32,9 +32,7 @@ Geneformer tokenizer.

 | If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.

-| OF NOTE: Take care that the correct token dictionary and gene median file is used for the correct model.
-
-| OF NOTE: For 95M model series, special_token should be True and model_input_size should be 4096. For 30M model series, special_token should be False and model_input_size should be 2048.
+| OF NOTE: Use model_version to auto-select settings for model version other than current default. For V1 model series (original Geneformer pretrained in 2021 on ~30M cells), one must use correct corresponding token dictionary and gene median file, set special_token to False, and set model_input_size to 2048. This argument enables auto-selection of these settings. (For V2 model series, special_token must be True and model_input_size is 4096.)

 """

@@ -299,6 +297,7 @@ class TranscriptomeTokenizer:
         model_input_size=4096,
         special_token=True,
         collapse_gene_ids=True,
+        model_version="V2",
         gene_median_file=GENE_MEDIAN_FILE,
         token_dictionary_file=TOKEN_DICTIONARY_FILE,
         gene_mapping_file=ENSEMBL_MAPPING_FILE,
@@ -318,15 +317,18 @@
             | Chunk size for anndata tokenizer.
         model_input_size : int = 4096
             | Max input size of model to truncate input to.
-            | For the 30M model series, should be 2048. For the 95M model series, should be 4096.
+            | For the V1 model series, should be 2048. For the V2 model series, should be 4096.
         special_token : bool = True
             | Adds CLS token before and EOS token after rank value encoding.
-            | For the 30M model series, should be False. For the 95M model series, should be True.
+            | For the V1 model series, should be False. For the V2 model series, should be True.
         collapse_gene_ids : bool = True
             | Whether to collapse gene IDs based on gene mapping dictionary.
+        model_version : str
+            | To auto-select settings for model version other than current default.
+            | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
         gene_median_file : Path
             | Path to pickle file containing dictionary of non-zero median
-            | gene expression values across Genecorpus-30M.
+            | gene expression values across Genecorpus.
         token_dictionary_file : Path
             | Path to pickle file containing token dictionary (Ensembl IDs:token).
         gene_mapping_file : None, Path
@@ -348,8 +350,22 @@
         # add CLS and EOS tokens
         self.special_token = special_token

+        # CHANGE DEFAULTS TO BE FOR MODEL OTHER THAN CURRENT
+        self.model_version = model_version
+        if self.model_version not in ["V1", "V2"]:
+            logger.error(
+                "Unrecognized model version. Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells."
+            )
+        elif self.model_version == "V1":
+            self.model_input_size = 2048
+            self.special_token = False
+            from . import ENSEMBL_MAPPING_FILE_30M, GENE_MEDIAN_FILE_30M, TOKEN_DICTIONARY_FILE_30M
+            gene_median_file = GENE_MEDIAN_FILE_30M
+            token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
+            gene_mapping_file = ENSEMBL_MAPPING_FILE_30M
+
         # load dictionary of gene normalization factors
-        # (non-zero median value of expression across Genecorpus-30M)
+        # (non-zero median value of expression across Genecorpus)
         with open(gene_median_file, "rb") as f:
             self.gene_median_dict = pickle.load(f)
@@ -372,7 +388,7 @@
             "<eos>" in self.gene_token_dict.keys()
         ):
             logger.warning(
-                "<cls> and <eos> are in gene_token_dict but special_token = False. Please note that for 95M model series, special_token should be True."
+                "<cls> and <eos> are in gene_token_dict but special_token = False. Please note that for V2 model series, special_token should be True."
             )

         # if collapsing duplicate gene IDs
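
Finally, a hedged sketch of the tokenizer with the new auto-selection (the paths and the custom attribute dict are placeholders, not part of this diff):

from geneformer import TranscriptomeTokenizer

# For V1 (30M-cell) models, model_version="V1" now auto-sets
# model_input_size=2048, special_token=False, and the 30M gene median,
# token, and Ensembl mapping dictionaries, per the lines added above.
tk = TranscriptomeTokenizer(
    custom_attr_name_dict={"cell_type": "cell_type"},  # placeholder
    nproc=4,
    model_version="V1",
)
tk.tokenize_data(
    "loom_data_directory",   # placeholder paths
    "output_directory",
    "output_prefix",
    file_format="loom",
)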
generation_config.json CHANGED
@@ -1,5 +1,5 @@
 {
   "_from_model_config": true,
   "pad_token_id": 0,
-  "transformers_version": "4.37.1"
+  "transformers_version": "4.44.2"
 }
gf-12L-30M-i2048/config.json DELETED
@@ -1,23 +0,0 @@
-{
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "gradient_checkpointing": false,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 512,
-  "initializer_range": 0.02,
-  "intermediate_size": 1024,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 2048,
-  "model_type": "bert",
-  "num_attention_heads": 8,
-  "num_hidden_layers": 12,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "transformers_version": "4.6.0",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 25426
-}
gf-12L-30M-i2048/pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:812f8d85e5ecf9d64c268f052f6ece2c1906bc4f1aecf70d5144b2598386b615
-size 158467410
gf-12L-30M-i2048/training_args.bin DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:259cf6067211e24e198690d00f0a222ee5550ad57e23d04ced0d0ca2e1b3738e
-size 2607
gf-12L-95M-i4096/config.json DELETED
@@ -1,24 +0,0 @@
-{
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "classifier_dropout": null,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 512,
-  "initializer_range": 0.02,
-  "intermediate_size": 1024,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 4096,
-  "model_type": "bert",
-  "num_attention_heads": 8,
-  "num_hidden_layers": 12,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "torch_dtype": "float32",
-  "transformers_version": "4.37.1",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 20275
-}
gf-12L-95M-i4096/generation_config.json DELETED
@@ -1,5 +0,0 @@
-{
-  "_from_model_config": true,
-  "pad_token_id": 0,
-  "transformers_version": "4.37.1"
-}
gf-12L-95M-i4096/model.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4365ba23e393fcfa0e65a94ac64a0983cd788bd23a8d4914f4ab66f85cfe043c
-size 152012980
gf-12L-95M-i4096/training_args.bin DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:21a45980734b138029422e95a5601def858821a9ec02cd473938b9f525ac108d
-size 4920
gf-12L-95M-i4096_CLcancer/config.json DELETED
@@ -1,25 +0,0 @@
-{
-  "_name_or_path": "/gladstone/theodoris/lab/pretrained_models/encoder/240402_194213_geneformer_94M_L12_emb512_SL4096_E3_B4_LR0.0005_LScosine_WU5000_Oadamw_DS8/models",
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "classifier_dropout": null,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 512,
-  "initializer_range": 0.02,
-  "intermediate_size": 1024,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 4096,
-  "model_type": "bert",
-  "num_attention_heads": 8,
-  "num_hidden_layers": 12,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "torch_dtype": "float32",
-  "transformers_version": "4.37.1",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 20275
-}
gf-12L-95M-i4096_CLcancer/generation_config.json DELETED
@@ -1,5 +0,0 @@
-{
-  "_from_model_config": true,
-  "pad_token_id": 0,
-  "transformers_version": "4.37.1"
-}
gf-12L-95M-i4096_CLcancer/model.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:2451adeed240c165634fea60ccba17063da8a2843ea9fcdcc0ce185720bf0dc2
-size 152012980
gf-12L-95M-i4096_CLcancer/training_args.bin DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:37074f3ea62a6ba0a312c38526c20c2dccbb068a2c7ee8c7c73b435dd90ab7b1
-size 5048
gf-20L-95M-i4096/config.json DELETED
@@ -1,24 +0,0 @@
-{
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "classifier_dropout": null,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 896,
-  "initializer_range": 0.02,
-  "intermediate_size": 1792,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 4096,
-  "model_type": "bert",
-  "num_attention_heads": 14,
-  "num_hidden_layers": 20,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "torch_dtype": "float32",
-  "transformers_version": "4.37.1",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 20275
-}
gf-20L-95M-i4096/generation_config.json DELETED
@@ -1,5 +0,0 @@
-{
-  "_from_model_config": true,
-  "pad_token_id": 0,
-  "transformers_version": "4.37.1"
-}
gf-20L-95M-i4096/model.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:db85c081a6d392448955c7d0185e26aba74507518df991ca8c69ee9108ce8bbf
-size 605292732
gf-20L-95M-i4096/training_args.bin DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5afed602918d6f0c4916c1b9335bcdb619bca2c6fd6c7e0dd2a86d195264b8cc
-size 5048
gf-6L-30M-i2048/config.json DELETED
@@ -1,23 +0,0 @@
-{
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "gradient_checkpointing": false,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 256,
-  "initializer_range": 0.02,
-  "intermediate_size": 512,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 2048,
-  "model_type": "bert",
-  "num_attention_heads": 4,
-  "num_hidden_layers": 6,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "transformers_version": "4.6.0",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 25426
-}
gf-6L-30M-i2048/model.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a5e33a757431643b3697de7ef6127950cdc49e06e58d4266b3a3ab191b683f14
-size 41183536
gf-6L-30M-i2048/pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8d860e2125884475dd42bc2cd9a0e60c60808a7351241e08f2154931ffc142da
-size 41216562
gf-6L-30M-i2048/training_args.bin DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f0ec3459454205174c9d2e4d6c6930f6b0fbf3364fc03a6f4d99c4d3add2012b
-size 2607