Christina Theodoris committed
Commit: d319fef
Parent(s): 31bf641

update with V2 models

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- MANIFEST.in +9 -4
- README.md +9 -16
- config.json +7 -7
- examples/cell_classification.ipynb +6 -7
- examples/extract_and_plot_cell_embeddings.ipynb +6 -7
- examples/gene_classification.ipynb +10 -9
- examples/in_silico_perturbation.ipynb +10 -13
- examples/multitask_cell_classification.ipynb +3 -3
- examples/tokenizing_scRNAseq_data.ipynb +4 -8
- fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json +0 -24
- fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin +0 -3
- fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json +0 -35
- fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/optimizer.pt +0 -3
- fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin +0 -3
- fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/rng_state.pth +0 -3
- fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/scheduler.pt +0 -3
- fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json +0 -150
- fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/training_args.bin +0 -3
- geneformer/__init__.py +9 -4
- geneformer/classifier.py +13 -2
- geneformer/emb_extractor.py +15 -0
- geneformer/ensembl_mapping_dict_gc95M.pkl +0 -3
- geneformer/evaluation_utils.py +3 -8
- geneformer/gene_median_dictionary_gc95M.pkl +0 -3
- geneformer/gene_name_id_dict_gc95M.pkl +0 -3
- geneformer/in_silico_perturber.py +20 -0
- geneformer/in_silico_perturber_stats.py +12 -1
- geneformer/perturber_utils.py +28 -65
- geneformer/token_dictionary_gc95M.pkl +0 -3
- geneformer/tokenizer.py +24 -8
- generation_config.json +1 -1
- gf-12L-30M-i2048/config.json +0 -23
- gf-12L-30M-i2048/pytorch_model.bin +0 -3
- gf-12L-30M-i2048/training_args.bin +0 -3
- gf-12L-95M-i4096/config.json +0 -24
- gf-12L-95M-i4096/generation_config.json +0 -5
- gf-12L-95M-i4096/model.safetensors +0 -3
- gf-12L-95M-i4096/training_args.bin +0 -3
- gf-12L-95M-i4096_CLcancer/config.json +0 -25
- gf-12L-95M-i4096_CLcancer/generation_config.json +0 -5
- gf-12L-95M-i4096_CLcancer/model.safetensors +0 -3
- gf-12L-95M-i4096_CLcancer/training_args.bin +0 -3
- gf-20L-95M-i4096/config.json +0 -24
- gf-20L-95M-i4096/generation_config.json +0 -5
- gf-20L-95M-i4096/model.safetensors +0 -3
- gf-20L-95M-i4096/training_args.bin +0 -3
- gf-6L-30M-i2048/config.json +0 -23
- gf-6L-30M-i2048/model.safetensors +0 -3
- gf-6L-30M-i2048/pytorch_model.bin +0 -3
- gf-6L-30M-i2048/training_args.bin +0 -3
MANIFEST.in
CHANGED
@@ -1,4 +1,9 @@
-include geneformer/
-include geneformer/
-include geneformer/
-include geneformer/
+include geneformer/gene_median_dictionary_gc104m.pkl
+include geneformer/gene_name_id_dict_gc104m.pkl
+include geneformer/ensembl_mapping_dict_gc104m.pkl
+include geneformer/token_dictionary_gc104m.pkl
+
+include geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30m.pkl
+include geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30m.pkl
+include geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30m.pkl
+include geneformer/gene_dictionaries_30m/token_dictionary_gc30m.pkl
README.md
CHANGED
@@ -9,35 +9,28 @@ tags:
 Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
 
 - See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies.
-- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~
+- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model, now trained on ~104 million transcriptomes, and our continual learning, multitask learning, and quantization strategies.
 - See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation.
 
 # Model Description
-Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation.
+Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer V1 was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million human single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. The current updated Geneformer V2 is pretrained on ~104 million human single cell transcriptomes (non-cancer). The cancer continual learning V2 variant was continually pretrained on ~14 million cancer transcriptomes to yield a cancer domain-tuned model.
 
-Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
+Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus (-30M for V1, -104M for V2). The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
 
 The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
 
 We detail applications and results in [our manuscript](https://rdcu.be/ddrx0).
 
-During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational
+During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational AI model pretrained on a large-scale corpus of human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets.
 
 The repository includes the following pretrained models:
 
-
-
-i=input size\
-(pretraining date)
-
-
-- GF-12L-30M-i2048 (June 2021)
-- GF-12L-95M-i4096 (April 2024)
-- GF-20L-95M-i4096 (April 2024)
-
-The
-
-The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer.
+- Geneformer-V1-10M: original model trained June 2021 on ~30M human single cell transcriptomes, 10M parameters, input size 2048, vocabulary ~25K protein-coding or non-coding RNA genes
+- Geneformer-V2-104M and Geneformer-V2-316M: updated model trained Dec 2024 on ~104M human single cell transcriptomes, 104M or 316M parameters, input size 4096, vocabulary ~20K protein-coding genes
+
+The current default model in the main directory of the repository is Geneformer-V2-316M.
+
+The repository also contains fine-tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, Geneformer-V2-104M_CLcancer.
 
 # Application
 The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
@@ -87,7 +80,7 @@ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main
 
 Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications.
 
-Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.).
+Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.). Importantly, as usual for deep learning models, there are no uniformly applicable default hyperparameters for Geneformer.
 
 # Citations
 - C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors)
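To make the rank value encoding described in the README concrete, here is a minimal sketch of the idea under stated assumptions: it is not the package's tokenizer (that lives in geneformer/tokenizer.py), and the dictionary contents, function name, and gene names are illustrative only.

```python
# Minimal sketch of the rank value encoding idea, assuming a precomputed
# corpus-wide scaling factor per gene (in the package, such factors are shipped
# in the gene median dictionary files listed in MANIFEST.in above).
def rank_value_encode(cell_counts: dict, gene_norm_factors: dict, max_len: int = 4096):
    """Order genes from most to least distinguishing for one cell.

    cell_counts: gene -> raw count in this cell
    gene_norm_factors: gene -> corpus-wide scaling factor for that gene
    """
    total = sum(cell_counts.values())
    scaled = {
        gene: (count / total) / gene_norm_factors[gene]
        for gene, count in cell_counts.items()
        if gene in gene_norm_factors and count > 0
    }
    # Highest scaled expression first: ubiquitously highly expressed housekeeping
    # genes (large corpus-wide factors) are pushed toward the end of the encoding.
    ranked = sorted(scaled, key=scaled.get, reverse=True)
    return ranked[:max_len]

# Toy example: a lowly expressed but cell-state-specific gene outranks a
# highly expressed housekeeping gene after corpus-wide scaling.
cell = {"HOUSEKEEPING_GENE": 500, "TF_GENE": 20, "OTHER_GENE": 40}
norms = {"HOUSEKEEPING_GENE": 300.0, "TF_GENE": 2.0, "OTHER_GENE": 30.0}
print(rank_value_encode(cell, norms))  # TF_GENE ranks first
```

In the package itself, the per-gene scaling factors come from the gene median dictionary that matches the model version in use (gc30m for V1, gc104m for V2), as reflected in the MANIFEST.in changes above.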
config.json
CHANGED
@@ -2,22 +2,22 @@
 "architectures": [
 "BertForMaskedLM"
 ],
-"attention_probs_dropout_prob": 0.
+"attention_probs_dropout_prob": 0.1,
 "classifier_dropout": null,
 "hidden_act": "relu",
-"hidden_dropout_prob": 0.
+"hidden_dropout_prob": 0.1,
-"hidden_size":
+"hidden_size": 1152,
 "initializer_range": 0.02,
-"intermediate_size":
+"intermediate_size": 4608,
 "layer_norm_eps": 1e-12,
 "max_position_embeddings": 4096,
 "model_type": "bert",
-"num_attention_heads":
+"num_attention_heads": 18,
-"num_hidden_layers":
+"num_hidden_layers": 18,
 "pad_token_id": 0,
 "position_embedding_type": "absolute",
 "torch_dtype": "float32",
-"transformers_version": "4.
+"transformers_version": "4.44.2",
 "type_vocab_size": 2,
 "use_cache": true,
 "vocab_size": 20275
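The updated config above (hidden size 1152, 18 layers, 18 attention heads) corresponds to the default model now served from the repository root, Geneformer-V2-316M per the README. As a hedged sketch using the standard Transformers auto classes (not a step prescribed by this commit), the default checkpoint can be inspected and loaded as follows:

```python
from transformers import AutoConfig, AutoModelForMaskedLM

# Inspect the config shipped at the repository root of ctheodoris/Geneformer.
config = AutoConfig.from_pretrained("ctheodoris/Geneformer")
print(config.hidden_size, config.num_hidden_layers, config.num_attention_heads)  # 1152 18 18

# "BertForMaskedLM" in the architectures field means the masked-LM auto class
# resolves to the correct model type when loading the pretrained weights.
model = AutoModelForMaskedLM.from_pretrained("ctheodoris/Geneformer")
```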
examples/cell_classification.ipynb
CHANGED
@@ -13,7 +13,7 @@
 "id": "1792e51c-86c3-406f-be5a-273c4e4aec20",
 "metadata": {},
 "source": [
-"### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization."
+"### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters."
 ]
 },
 {
@@ -69,9 +69,7 @@
 " \"seed\": 73,\n",
 "}\n",
 "\n",
-"# OF NOTE:
-"# (otherwise the Classifier will use the current default model dictionary)\n",
-"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
+"# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
 "cc = Classifier(classifier=\"cell\",\n",
 " cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n",
 " filter_data=filter_data_dict,\n",
@@ -80,6 +78,7 @@
 " freeze_layers = 2,\n",
 " num_crossval_splits = 1,\n",
 " forward_batch_size=200,\n",
+" model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
 " nproc=16)"
 ]
 },
@@ -264,8 +263,8 @@
 " \"train\": train_ids,\n",
 " \"eval\": eval_ids}\n",
 "\n",
-"#
-"all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\"
+"# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n",
+"all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n",
 " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n",
 " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
 " output_directory=output_dir,\n",
@@ -450,7 +449,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.
+"version": "3.10.13"
 }
 },
 "nbformat": 4,
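The markdown cell above points to n_hyperopt_trials in cc.validate() for hyperparameter search. Below is a short, hedged sketch of how that argument could be passed, continuing the notebook's own variables (cc, output_dir, output_prefix); the path is a placeholder, the argument list is abbreviated, and the trial count is an arbitrary choice rather than a recommendation.

```python
# Sketch continuing the notebook's cc, output_dir, and output_prefix from the cells above:
# setting n_hyperopt_trials > 0 runs that many hyperparameter optimization trials
# during validation instead of using fixed hyperparameters.
all_metrics = cc.validate(
    model_directory="/path/to/Geneformer",  # placeholder; point to the model version configured above
    prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled_train.dataset",
    id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
    output_directory=output_dir,
    # (remaining arguments as in the notebook's validate call above)
    n_hyperopt_trials=10,  # >0 enables the search; tune the trial count per task
)
```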
examples/extract_and_plot_cell_embeddings.ipynb
CHANGED
@@ -18,8 +18,7 @@
 "outputs": [],
 "source": [
 "# initiate EmbExtractor\n",
-"# OF NOTE:
-"# (otherwise the EmbExtractor will use the current default model dictionary)\n",
+"# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
 "embex = EmbExtractor(model_type=\"CellClassifier\",\n",
 " num_classes=3,\n",
 " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
@@ -28,13 +27,13 @@
 " emb_label=[\"disease\",\"cell_type\"],\n",
 " labels_to_plot=[\"disease\"],\n",
 " forward_batch_size=200,\n",
-"
-"
+" model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
+" nproc=16)\n",
 "\n",
 "# extracts embedding from input data\n",
 "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
-"# example dataset for
-"embs = embex.extract_embs(\"../fine_tuned_models/
+"# example dataset for V1 model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
+"embs = embex.extract_embs(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n",
 " \"path/to/input_data/\",\n",
 " \"path/to/output_directory/\",\n",
 " \"output_prefix\")\n"
@@ -132,7 +131,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.
+"version": "3.10.13"
 }
 },
 "nbformat": 4,
examples/gene_classification.ipynb
CHANGED
@@ -13,7 +13,7 @@
 "id": "79539e95-2c9c-4162-835c-f0d158abb15d",
 "metadata": {},
 "source": [
-"### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications."
+"### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters."
 ]
 },
 {
@@ -71,15 +71,14 @@
 }
 ],
 "source": [
-"# OF NOTE:
-"# (otherwise the Classifier will use the current default model dictionary)\n",
-"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
+"# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
 "cc = Classifier(classifier=\"gene\",\n",
 " gene_class_dict = gene_class_dict,\n",
 " max_ncells = 10_000,\n",
 " freeze_layers = 4,\n",
 " num_crossval_splits = 5,\n",
 " forward_batch_size=200,\n",
+" model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
 " nproc=16)"
 ]
 },
@@ -843,8 +842,8 @@
 }
 ],
 "source": [
-"#
-"all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\"
+"# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n",
+"all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n",
 " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
 " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
 " output_directory=output_dir,\n",
@@ -1066,12 +1065,14 @@
 }
 ],
 "source": [
+"# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
 "cc = Classifier(classifier=\"gene\",\n",
 " gene_class_dict = gene_class_dict,\n",
 " max_ncells = 10_000,\n",
 " freeze_layers = 4,\n",
 " num_crossval_splits = 0,\n",
 " forward_batch_size=200,\n",
+" model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
 " nproc=16)"
 ]
 },
@@ -1218,8 +1219,8 @@
 }
 ],
 "source": [
-"#
-"trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\"
+"# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n",
+"trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n",
 " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
 " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
 " output_directory=output_dir,\n",
@@ -1243,7 +1244,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.
+"version": "3.10.13"
 }
 },
 "nbformat": 4,
examples/in_silico_perturbation.ipynb
CHANGED
@@ -39,9 +39,7 @@
 "\n",
 "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
 "\n",
-"# OF NOTE:
-"# (otherwise the EmbExtractor will use the current default model dictionary)\n",
-"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
+"# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
 "embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
 " num_classes=3,\n",
 " filter_data=filter_data_dict,\n",
@@ -49,6 +47,7 @@
 " emb_layer=0,\n",
 " summary_stat=\"exact_mean\",\n",
 " forward_batch_size=256,\n",
+" model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
 " nproc=16)\n",
 "\n",
 "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
@@ -67,9 +66,7 @@
 },
 "outputs": [],
 "source": [
-"# OF NOTE:
-"# (otherwise the InSilicoPerturber will use the current default model dictionary)\n",
-"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
+"# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
 "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
 " perturb_rank_shift=None,\n",
 " genes_to_perturb=\"all\",\n",
@@ -77,7 +74,7 @@
 " anchor_gene=None,\n",
 " model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
 " num_classes=3,\n",
-" emb_mode=\"cell\"
+" emb_mode=\"cell\", # OF NOTE: SET TO \"CELL\" FOR V1 MODEL. FOR V2, SHOULD BE \"CLS\" (current default).\n",
 " cell_emb_style=\"mean_pool\",\n",
 " filter_data=filter_data_dict,\n",
 " cell_states_to_model=cell_states_to_model,\n",
@@ -85,6 +82,7 @@
 " max_ncells=2000,\n",
 " emb_layer=0,\n",
 " forward_batch_size=400,\n",
+" model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n",
 " nproc=16)"
 ]
 },
@@ -97,7 +95,7 @@
 "source": [
 "# outputs intermediate files from in silico perturbation\n",
 "\n",
-"isp.perturb_data(\"../fine_tuned_models/
+"isp.perturb_data(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n",
 " \"path/to/input_data\",\n",
 " \"path/to/isp_output_directory\",\n",
 " \"output_prefix\")"
@@ -110,14 +108,13 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# OF NOTE:
-"# (otherwise the InSilicoPerturberStats will use the current default model dictionary)\n",
-"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
+"# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n",
 "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
 " genes_perturbed=\"all\",\n",
 " combos=0,\n",
 " anchor_gene=None,\n",
-" cell_states_to_model=cell_states_to_model
+" cell_states_to_model=cell_states_to_model,\n",
+" model_version=\"V1\") # OF NOTE: SET TO V1 MODEL SINCE V1 WAS USED FOR IN SILICO PERTURBATION ABOVE"
 ]
 },
 {
@@ -151,7 +148,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.
+"version": "3.10.13"
 }
 },
 "nbformat": 4,
examples/multitask_cell_classification.ipynb
CHANGED
@@ -286,7 +286,7 @@
 " filter_data_dict=filter_data_dict,\n",
 " max_ncells=1000, # Number of cells to extract embeddings for\n",
 " emb_layer=0, # Use the second to last layer\n",
-" emb_mode = \"cls\"
+" emb_mode = \"cls\", # Use CLS token embedding for V2 model\n",
 " summary_stat=\"exact_mean\",\n",
 " forward_batch_size=8, # Adjust based on available GPU memory\n",
 " nproc=4\n",
@@ -324,7 +324,7 @@
 " perturb_type=perturb_type,\n",
 " genes_to_perturb=\"all\", # Perturb all genes\n",
 " model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
-" emb_mode=\"cls\", # Use CLS token embedding\n",
+" emb_mode=\"cls\", # Use CLS token embedding for V2 model\n",
 " cell_states_to_model=cell_states_to_model,\n",
 " state_embs_dict=state_embs_dict,\n",
 " max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
@@ -412,7 +412,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.
+"version": "3.10.13"
 }
 },
 "nbformat": 4,
examples/tokenizing_scRNAseq_data.ipynb
CHANGED
@@ -34,12 +34,8 @@
 "metadata": {},
 "source": [
 "**********************************************************************************************************\n",
-"#### OF NOTE:
-"####
-"\n",
-"#### ADDITIONALLY:\n",
-"#### The 95M model series require the special_token argument to be set to True and model_input_size to be 4096. (current defaults)\n",
-"#### The 30M model series require the special_token argument to be set to False and the model_input_size to be 2048."
+"#### OF NOTE: Please ensure the correct token dictionary, gene median file, special token setting, and model input size are used for the correct model version.\n",
+"#### Current defaults are for V2 model series. To auto-select the correct settings for V1, set model_version argument to \"V1\"."
 ]
 },
 {
@@ -59,7 +55,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)\n",
+"tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16) # for V1 model, set model_version=\"V1\"\n",
 "tk.tokenize_data(\"loom_data_directory\", \n",
 " \"output_directory\", \n",
 " \"output_prefix\", \n",
@@ -83,7 +79,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.
+"version": "3.10.13"
 }
 },
 "nbformat": 4,
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json
DELETED
@@ -1,24 +0,0 @@
-{
-"architectures": [
-"BertForMaskedLM"
-],
-"attention_probs_dropout_prob": 0.02,
-"classifier_dropout": null,
-"hidden_act": "relu",
-"hidden_dropout_prob": 0.02,
-"hidden_size": 512,
-"initializer_range": 0.02,
-"intermediate_size": 1024,
-"layer_norm_eps": 1e-12,
-"max_position_embeddings": 4096,
-"model_type": "bert",
-"num_attention_heads": 8,
-"num_hidden_layers": 12,
-"pad_token_id": 0,
-"position_embedding_type": "absolute",
-"torch_dtype": "float32",
-"transformers_version": "4.37.2",
-"type_vocab_size": 2,
-"use_cache": true,
-"vocab_size": 20275
-}
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4
-size 152363342
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json
DELETED
@@ -1,35 +0,0 @@
-{
-"_name_or_path": "/n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/",
-"architectures": [
-"BertForSequenceClassification"
-],
-"attention_probs_dropout_prob": 0.02,
-"gradient_checkpointing": false,
-"hidden_act": "relu",
-"hidden_dropout_prob": 0.02,
-"hidden_size": 256,
-"id2label": {
-"0": "LABEL_0",
-"1": "LABEL_1",
-"2": "LABEL_2"
-},
-"initializer_range": 0.02,
-"intermediate_size": 512,
-"label2id": {
-"LABEL_0": 0,
-"LABEL_1": 1,
-"LABEL_2": 2
-},
-"layer_norm_eps": 1e-12,
-"max_position_embeddings": 2048,
-"model_type": "bert",
-"num_attention_heads": 4,
-"num_hidden_layers": 6,
-"pad_token_id": 0,
-"position_embedding_type": "absolute",
-"problem_type": "single_label_classification",
-"transformers_version": "4.6.0",
-"type_vocab_size": 2,
-"use_cache": true,
-"vocab_size": 25426
-}
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/optimizer.pt
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3ced328122d57a847fc3914732337674500e259a82e64437c67b4954ac2f4e07
-size 73720721
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:12ead3bad8cf4b853bac87eadeb79c9308ae492e9d29f32da1a2c85e8586108d
-size 41115113
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/rng_state.pth
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:dd8c0a739c2fe6a9ab4bb8f4a62ad8d7b879efcdceb5376b128a2040ff1bbe62
-size 14657
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/scheduler.pt
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3d0797845afdae765a74ddab7966e0e1837617fd8171af8ee6aef9dedce248f2
-size 623
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json
DELETED
@@ -1,150 +0,0 @@
-{
-"best_metric": 0.39658036828041077,
-"best_model_checkpoint": "/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/models/220224_geneformer_27M_SequenceClassifier_tuning_hCMdCM_L2048_B12_LR1e-05_LScosine_WU500_E1_Oadamw_F2/run-8429a330/checkpoint-7020",
-"epoch": 0.9,
-"global_step": 7020,
-"is_hyper_param_search": true,
-"is_local_process_zero": true,
-"is_world_process_zero": true,
-"log_history": [
-{
-"epoch": 0.1,
-"learning_rate": 0.00034606438343856935,
-"loss": 0.911,
-"step": 780
-},
-{
-"epoch": 0.1,
-"eval_accuracy": 0.4531576503366612,
-"eval_loss": 1.4550466537475586,
-"eval_runtime": 66.5164,
-"eval_samples_per_second": 259.004,
-"step": 780
-},
-{
-"epoch": 0.2,
-"learning_rate": 0.0006921287668771387,
-"loss": 0.6273,
-"step": 1560
-},
-{
-"epoch": 0.2,
-"eval_accuracy": 0.5953680055723242,
-"eval_loss": 0.846651554107666,
-"eval_runtime": 66.1267,
-"eval_samples_per_second": 260.53,
-"step": 1560
-},
-{
-"epoch": 0.3,
-"learning_rate": 0.0007330550166223805,
-"loss": 0.5592,
-"step": 2340
-},
-{
-"epoch": 0.3,
-"eval_accuracy": 0.5935105641978176,
-"eval_loss": 1.0599186420440674,
-"eval_runtime": 66.2608,
-"eval_samples_per_second": 260.003,
-"step": 2340
-},
-{
-"epoch": 0.4,
-"learning_rate": 0.0006283471571048975,
-"loss": 0.3714,
-"step": 3120
-},
-{
-"epoch": 0.4,
-"eval_accuracy": 0.686324587880195,
-"eval_loss": 1.184874415397644,
-"eval_runtime": 66.1411,
-"eval_samples_per_second": 260.473,
-"step": 3120
-},
-{
-"epoch": 0.5,
-"learning_rate": 0.0005236392975874146,
-"loss": 0.2976,
-"step": 3900
-},
-{
-"epoch": 0.5,
-"eval_accuracy": 0.7681100534014396,
-"eval_loss": 0.6318939328193665,
-"eval_runtime": 66.3309,
-"eval_samples_per_second": 259.728,
-"step": 3900
-},
-{
-"epoch": 0.6,
-"learning_rate": 0.0004189314380699318,
-"loss": 0.2564,
-"step": 4680
-},
-{
-"epoch": 0.6,
-"eval_accuracy": 0.7807058277223126,
-"eval_loss": 0.7283642888069153,
-"eval_runtime": 66.3416,
-"eval_samples_per_second": 259.686,
-"step": 4680
-},
-{
-"epoch": 0.7,
-"learning_rate": 0.0003142235785524487,
-"loss": 0.2336,
-"step": 5460
-},
-{
-"epoch": 0.7,
-"eval_accuracy": 0.8563965637334572,
-"eval_loss": 0.5184123516082764,
-"eval_runtime": 66.3416,
-"eval_samples_per_second": 259.686,
-"step": 5460
-},
-{
-"epoch": 0.8,
-"learning_rate": 0.0002095157190349659,
-"loss": 0.1731,
-"step": 6240
-},
-{
-"epoch": 0.8,
-"eval_accuracy": 0.8288832133735778,
-"eval_loss": 0.5823884010314941,
-"eval_runtime": 66.1535,
-"eval_samples_per_second": 260.425,
-"step": 6240
-},
-{
-"epoch": 0.9,
-"learning_rate": 0.00010480785951748295,
-"loss": 0.1451,
-"step": 7020
-},
-{
-"epoch": 0.9,
-"eval_accuracy": 0.886812166241003,
-"eval_loss": 0.39658036828041077,
-"eval_runtime": 66.3555,
-"eval_samples_per_second": 259.632,
-"step": 7020
-}
-],
-"max_steps": 7800,
-"num_train_epochs": 1,
-"total_flos": 0,
-"trial_name": null,
-"trial_params": {
-"learning_rate": 0.0008039341830649843,
-"lr_scheduler_type": "polynomial",
-"num_train_epochs": 1,
-"per_device_train_batch_size": 12,
-"seed": 73.15243080311434,
-"warmup_steps": 1812.6785581609881,
-"weight_decay": 0.2588277764570262
-}
-}
fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/training_args.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4ffee119596c99b50a422b2f80103f4c44f7e25c2ea0e457fe224bad59f1f955
-size 2607
geneformer/__init__.py
CHANGED
@@ -4,10 +4,15 @@ from pathlib import Path
 
 warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")  # noqa # isort:skip
 
-GENE_MEDIAN_FILE = Path(__file__).parent / "
-TOKEN_DICTIONARY_FILE = Path(__file__).parent / "
-ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "
-ENSEMBL_MAPPING_FILE = Path(__file__).parent / "
+GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc104M.pkl"
+TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc104M.pkl"
+ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc104M.pkl"
+ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc104M.pkl"
+
+GENE_MEDIAN_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl"
+TOKEN_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/token_dictionary_gc30M.pkl"
+ENSEMBL_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl"
+ENSEMBL_MAPPING_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl"
 
 from . import (
     collator_for_classification,
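Given the new constants above, here is a minimal sketch of loading the dictionaries directly; the mapping direction (Ensembl ID to token id) follows the docstrings elsewhere in this commit, and the printout is only for orientation.

```python
import pickle

from geneformer import TOKEN_DICTIONARY_FILE, TOKEN_DICTIONARY_FILE_30M

# Default (V2, 104M-cell corpus) token dictionary.
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
    token_dict_v2 = pickle.load(f)  # Ensembl ID -> token id, per the docstrings

# V1 (30M-cell corpus) token dictionary, now kept under gene_dictionaries_30m/.
with open(TOKEN_DICTIONARY_FILE_30M, "rb") as f:
    token_dict_v1 = pickle.load(f)

print(len(token_dict_v2), len(token_dict_v1))  # vocabulary sizes differ between versions
```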
geneformer/classifier.py
CHANGED
@@ -92,6 +92,7 @@ class Classifier:
 "no_eval": {bool},
 "stratify_splits_col": {None, str},
 "forward_batch_size": {int},
+"model_version": {"V1", "V2"},
 "token_dictionary_file": {None, str},
 "nproc": {int},
 "ngpu": {int},
@@ -115,6 +116,7 @@ class Classifier:
 stratify_splits_col=None,
 no_eval=False,
 forward_batch_size=100,
+model_version="V2",
 token_dictionary_file=None,
 nproc=4,
 ngpu=1,
@@ -191,6 +193,9 @@ class Classifier:
 | Otherwise, will perform eval during training.
 forward_batch_size : int
 | Batch size for forward pass (for evaluation, not training).
+model_version : str
+| To auto-select settings for model version other than current default.
+| Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
 token_dictionary_file : None, str
 | Default is to use token dictionary file from Geneformer
 | Otherwise, will load custom gene token dictionary.
@@ -225,14 +230,20 @@ class Classifier:
 self.stratify_splits_col = stratify_splits_col
 self.no_eval = no_eval
 self.forward_batch_size = forward_batch_size
+self.model_version = model_version
 self.token_dictionary_file = token_dictionary_file
 self.nproc = nproc
 self.ngpu = ngpu
 
+if self.model_version == "V1":
+    from . import TOKEN_DICTIONARY_FILE_30M
+    self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
+
 if self.training_args is None:
     logger.warning(
         "Hyperparameter tuning is highly recommended for optimal results. "
-        "No training_args provided; using default hyperparameters."
+        "No training_args provided; using default hyperparameters. "
+        "Please note: these defaults are not recommended to be used uniformly across tasks."
     )
 
 self.validate_options()
@@ -1319,7 +1330,7 @@ class Classifier:
 ##### Evaluate the model #####
 labels = id_class_dict.keys()
 y_pred, y_true, logits_list = eu.classifier_predict(
-    model, self.classifier, eval_data, self.forward_batch_size
+    model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict
 )
 conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
     y_pred, y_true, logits_list, num_classes, labels
geneformer/emb_extractor.py
CHANGED
@@ -402,6 +402,7 @@ class EmbExtractor:
         "emb_label": {None, list},
         "labels_to_plot": {None, list},
         "forward_batch_size": {int},
+        "model_version": {"V1", "V2"},
         "token_dictionary_file": {None, str},
         "nproc": {int},
         "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
@@ -422,6 +423,7 @@ class EmbExtractor:
         forward_batch_size=100,
         nproc=4,
         summary_stat=None,
+        model_version="V2",
         token_dictionary_file=None,
     ):
         """
@@ -472,6 +474,9 @@ class EmbExtractor:
             | If mean or median, outputs only approximated mean or median embedding of input data.
             | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
             | Non-exact is slower but more memory-efficient.
+        model_version : str
+            | To auto-select settings for model version other than current default.
+            | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
         token_dictionary_file : Path
             | Default is the Geneformer token dictionary
             | Path to pickle file containing token dictionary (Ensembl ID:token).
@@ -502,6 +507,7 @@ class EmbExtractor:
         self.emb_layer = emb_layer
         self.emb_label = emb_label
         self.labels_to_plot = labels_to_plot
+        self.model_version = model_version
         self.token_dictionary_file = token_dictionary_file
         self.forward_batch_size = forward_batch_size
         self.nproc = nproc
@@ -512,6 +518,15 @@ class EmbExtractor:
         self.summary_stat = summary_stat
         self.exact_summary_stat = None

+        if self.model_version == "V1":
+            from . import TOKEN_DICTIONARY_FILE_30M
+            self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
+            if self.emb_mode == "cls":
+                self.emb_mode = "cell"
+                logger.warning(
+                    "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
+                )
+
         self.validate_options()

         # load token dictionary (Ensembl IDs:token)
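Usage sketch (not part of the diff): the EmbExtractor behavior added above, with illustrative argument values only. With model_version="V1", emb_mode="cls" falls back to "cell" because V1 models lack a <cls> token.

    from geneformer import EmbExtractor

    embex = EmbExtractor(
        model_type="Pretrained",
        num_classes=0,
        emb_mode="cls",        # switched to "cell" (with a warning) when model_version="V1"
        model_version="V1",
        max_ncells=1000,
        forward_batch_size=100,
        nproc=4,
    )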
geneformer/ensembl_mapping_dict_gc95M.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0819bcbd869cfa14279449b037eb9ed1d09a91310e77bd1a19d927465030e95c
-size 3957652
geneformer/evaluation_utils.py
CHANGED
@@ -20,20 +20,15 @@ from sklearn.metrics import (
 )
 from tqdm.auto import trange

-from . import TOKEN_DICTIONARY_FILE
 from .emb_extractor import make_colorbar

 logger = logging.getLogger(__name__)


-def preprocess_classifier_batch(cell_batch, max_len, label_name):
+def preprocess_classifier_batch(cell_batch, max_len, label_name, gene_token_dict):
     if max_len is None:
         max_len = max([len(i) for i in cell_batch["input_ids"]])

-    # load token dictionary (Ensembl IDs:token)
-    with open(TOKEN_DICTIONARY_FILE, "rb") as f:
-        gene_token_dict = pickle.load(f)
-
     def pad_label_example(example):
         example[label_name] = np.pad(
             example[label_name],
@@ -81,7 +76,7 @@ def py_softmax(vector):
     return e / e.sum()


-def classifier_predict(model, classifier_type, evalset, forward_batch_size):
+def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict):
     if classifier_type == "gene":
         label_name = "labels"
     elif classifier_type == "cell":
@@ -104,7 +99,7 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size):
         max_range = min(i + forward_batch_size, evalset_len)
         batch_evalset = evalset.select([i for i in range(i, max_range)])
         padded_batch = preprocess_classifier_batch(
-            batch_evalset, max_evalset_len, label_name
+            batch_evalset, max_evalset_len, label_name, gene_token_dict
         )
         padded_batch.set_format(type="torch")
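Usage sketch (not part of the diff): as shown above, these helpers no longer load the package-level token dictionary; callers now pass gene_token_dict explicitly. A minimal call, assuming model, eval_data, and token_dictionary_file already exist in the caller's session.

    import pickle

    from geneformer import evaluation_utils as eu

    with open(token_dictionary_file, "rb") as f:  # caller-supplied path
        gene_token_dict = pickle.load(f)

    y_pred, y_true, logits_list = eu.classifier_predict(
        model, "cell", eval_data, 100, gene_token_dict
    )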
geneformer/gene_median_dictionary_gc95M.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a51c53f6a771d64508dfaf61529df70e394c53bd20856926117ae5d641a24bf5
-size 1512661

geneformer/gene_name_id_dict_gc95M.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:fabfa0c2f49c598c59ae432a32c3499a5908c033756c663b5e0cddf58deea8e1
-size 1660882
geneformer/in_silico_perturber.py
CHANGED
@@ -72,6 +72,7 @@ class InSilicoPerturber:
         "max_ncells": {None, int},
         "cell_inds_to_perturb": {"all", dict},
         "emb_layer": {-1, 0},
+        "model_version": {"V1", "V2"},
         "token_dictionary_file": {None, str},
         "forward_batch_size": {int},
         "nproc": {int},
@@ -96,6 +97,7 @@ class InSilicoPerturber:
         emb_layer=-1,
         forward_batch_size=100,
         nproc=4,
+        model_version="V2",
         token_dictionary_file=None,
         clear_mem_ncells=1000,
     ):
@@ -184,6 +186,9 @@ class InSilicoPerturber:
             | Batch size for forward pass.
         nproc : int
             | Number of CPU processes to use.
+        model_version : str
+            | To auto-select settings for model version other than current default.
+            | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
         token_dictionary_file : Path
             | Path to pickle file containing token dictionary (Ensembl ID:token).
         clear_mem_ncells : int
@@ -224,9 +229,24 @@ class InSilicoPerturber:
         self.emb_layer = emb_layer
         self.forward_batch_size = forward_batch_size
         self.nproc = nproc
+        self.model_version = model_version
         self.token_dictionary_file = token_dictionary_file
         self.clear_mem_ncells = clear_mem_ncells

+        if self.model_version == "V1":
+            from . import TOKEN_DICTIONARY_FILE_30M
+            self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
+            if self.emb_mode == "cls":
+                self.emb_mode = "cell"
+                logger.warning(
+                    "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token."
+                )
+            if self.emb_mode == "cls_and_gene":
+                self.emb_mode = "cell_and_gene"
+                logger.warning(
+                    "model_version selected as V1 so changing emb_mode from 'cls_and_gene' to 'cell_and_gene' as V1 models do not have a <cls> token."
+                )
+
         self.validate_options()

         # load token dictionary (Ensembl IDs:token)
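Usage sketch (not part of the diff): the InSilicoPerturber change above, with illustrative perturbation settings. With model_version="V1", the "cls"/"cls_and_gene" embedding modes fall back to "cell"/"cell_and_gene".

    from geneformer import InSilicoPerturber

    isp = InSilicoPerturber(
        perturb_type="delete",
        genes_to_perturb="all",
        emb_mode="cls_and_gene",   # auto-switched to "cell_and_gene" when model_version="V1"
        model_version="V1",
        forward_batch_size=100,
        nproc=4,
    )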
geneformer/in_silico_perturber_stats.py
CHANGED
@@ -676,6 +676,7 @@ class InSilicoPerturberStats:
         "anchor_gene": {None, str},
         "cell_states_to_model": {None, dict},
         "pickle_suffix": {None, str},
+        "model_version": {"V1", "V2"},
     }

     def __init__(
@@ -686,6 +687,7 @@ class InSilicoPerturberStats:
         anchor_gene=None,
         cell_states_to_model=None,
         pickle_suffix="_raw.pickle",
+        model_version="V2",
         token_dictionary_file=TOKEN_DICTIONARY_FILE,
         gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
     ):
@@ -713,7 +715,7 @@ class InSilicoPerturberStats:
             | analyzes data for anchor gene perturbed in combination with each other gene.
             | However, if combos=0 and anchor_gene="ENSG00000136574":
             | analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
-        cell_states_to_model: None, dict
+        cell_states_to_model : None, dict
             | Cell states to model if testing perturbations that achieve goal state change.
             | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
             | state_key: key specifying name of column in .dataset that defines the start/goal states
@@ -724,6 +726,9 @@ class InSilicoPerturberStats:
             | "start_state": "dcm",
             | "goal_state": "nf",
             | "alt_states": ["hcm", "other1", "other2"]}
+        model_version : str
+            | To auto-select settings for model version other than current default.
+            | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
         token_dictionary_file : Path
             | Path to pickle file containing token dictionary (Ensembl ID:token).
         gene_name_id_dictionary_file : Path
@@ -736,9 +741,15 @@ class InSilicoPerturberStats:
         self.anchor_gene = anchor_gene
         self.cell_states_to_model = cell_states_to_model
         self.pickle_suffix = pickle_suffix
+        self.model_version = model_version

         self.validate_options()

+        if self.model_version == "V1":
+            from . import ENSEMBL_DICTIONARY_FILE_30M, TOKEN_DICTIONARY_FILE_30M
+            token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
+            gene_name_id_dictionary_file = ENSEMBL_DICTIONARY_FILE_30M
+
         # load token dictionary (Ensembl IDs:token)
         with open(token_dictionary_file, "rb") as f:
             self.gene_token_dict = pickle.load(f)
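Usage sketch (not part of the diff): the matching InSilicoPerturberStats call. The cell states mirror the docstring example above; the remaining values are placeholders.

    from geneformer import InSilicoPerturberStats

    ispstats = InSilicoPerturberStats(
        mode="goal_state_shift",
        genes_perturbed="all",
        model_version="V1",   # auto-selects the 30M token and gene-name dictionaries
        cell_states_to_model={
            "state_key": "disease",
            "start_state": "dcm",
            "goal_state": "nf",
            "alt_states": ["hcm"],
        },
    )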
geneformer/perturber_utils.py
CHANGED
@@ -17,11 +17,6 @@ from transformers import (
     BitsAndBytesConfig,
 )

-from . import (
-    TOKEN_DICTIONARY_FILE,
-    ENSEMBL_DICTIONARY_FILE,
-)
-
 logger = logging.getLogger(__name__)


@@ -127,7 +122,10 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
     output_hidden_states = (mode == "eval")

     # Quantization logic
-    if quantize:
+    if isinstance(quantize, dict):
+        quantize_config = quantize.get("bnb_config", None)
+        peft_config = quantize.get("peft_config", None)
+    elif quantize:
         if inference_only:
             quantize_config = BitsAndBytesConfig(load_in_8bit=True)
             peft_config = None
@@ -138,19 +136,22 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
                 bnb_4bit_quant_type="nf4",
                 bnb_4bit_compute_dtype=torch.bfloat16,
             )
-            lora_config_params = {
-                "lora_alpha": 128,
-                "lora_dropout": 0.1,
-                "r": 64,
-                "bias": "none"
-            }
-
-            # Try with TokenClassification first, fallback to TOKEN_CLS if needed
             try:
-                peft_config = LoraConfig(
-
-
-
+                peft_config = LoraConfig(
+                    lora_alpha=128,
+                    lora_dropout=0.1,
+                    r=64,
+                    bias="none",
+                    task_type="TokenClassification",
+                )
+            except ValueError as e:
+                peft_config = LoraConfig(
+                    lora_alpha=128,
+                    lora_dropout=0.1,
+                    r=64,
+                    bias="none",
+                    task_type="TOKEN_CLS",
+                )
     else:
         quantize_config = None
         peft_config = None
@@ -187,14 +188,22 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
         model.eval()

     # Handle device placement and PEFT
+    adapter_config_path = os.path.join(model_directory, "adapter_config.json")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     if not quantize:
         # Only move non-quantized models
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         model = model.to(device)
+    elif os.path.exists(adapter_config_path):
+        # If adapter files exist, load them into the model using PEFT's from_pretrained
+        model = PeftModel.from_pretrained(model, model_directory)
+        model = model.to(device)
+        print("loading lora weights")
     elif peft_config:
         # Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
         model.enable_input_require_grads()
         model = get_peft_model(model, peft_config)
+        model = model.to(device)

     return model

@@ -883,50 +892,4 @@ def validate_cell_states_to_model(cell_states_to_model):
             "'goal_state': 'nf', "
             "'alt_states': ['hcm', 'other1', 'other2']}"
         )
-        raise
-
-
-class GeneIdHandler:
-    def __init__(self, raise_errors=False):
-        def invert_dict(dict_obj):
-            return {v: k for k, v in dict_obj.items()}
-
-        self.raise_errors = raise_errors
-
-        with open(TOKEN_DICTIONARY_FILE, "rb") as f:
-            self.gene_token_dict = pickle.load(f)
-            self.token_gene_dict = invert_dict(self.gene_token_dict)
-
-        with open(ENSEMBL_DICTIONARY_FILE, "rb") as f:
-            self.id_gene_dict = pickle.load(f)
-            self.gene_id_dict = invert_dict(self.id_gene_dict)
-
-    def ens_to_token(self, ens_id):
-        if not self.raise_errors:
-            return self.gene_token_dict.get(ens_id, ens_id)
-        else:
-            return self.gene_token_dict[ens_id]
-
-    def token_to_ens(self, token):
-        if not self.raise_errors:
-            return self.token_gene_dict.get(token, token)
-        else:
-            return self.token_gene_dict[token]
-
-    def ens_to_symbol(self, ens_id):
-        if not self.raise_errors:
-            return self.gene_id_dict.get(ens_id, ens_id)
-        else:
-            return self.gene_id_dict[ens_id]
-
-    def symbol_to_ens(self, symbol):
-        if not self.raise_errors:
-            return self.id_gene_dict.get(symbol, symbol)
-        else:
-            return self.id_gene_dict[symbol]
-
-    def token_to_symbol(self, token):
-        return self.ens_to_symbol(self.token_to_ens(token))
-
-    def symbol_to_token(self, symbol):
-        return self.ens_to_token(self.symbol_to_ens(symbol))
+        raise
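Usage sketch (not part of the diff): as the load_model hunks above show, quantize may now be a dict carrying pre-built configs instead of a bare bool. The model directory below is a placeholder path.

    from transformers import BitsAndBytesConfig

    from geneformer import perturber_utils as pu

    quantize = {
        "bnb_config": BitsAndBytesConfig(load_in_8bit=True),
        "peft_config": None,  # or a peft.LoraConfig if adapters should be attached
    }
    model = pu.load_model(
        "Pretrained", 0, "/path/to/model_directory", mode="eval", quantize=quantize
    )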
geneformer/token_dictionary_gc95M.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:67c445f4385127adfc48dcc072320cd65d6822829bf27dd38070e6e787bc597f
-size 425590
geneformer/tokenizer.py
CHANGED
@@ -32,9 +32,7 @@ Geneformer tokenizer.

 | If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.

-| OF NOTE:
-
-| OF NOTE: For 95M model series, special_token should be True and model_input_size should be 4096. For 30M model series, special_token should be False and model_input_size should be 2048.
+| OF NOTE: Use model_version to auto-select settings for model version other than current default. For V1 model series (original Geneformer pretrained in 2021 on ~30M cells), one must use correct corresponding token dictionary and gene median file, set special_token to False, and set model_input_size to 2048. This argument enables auto-selection of these settings. (For V2 model series, special_token must be True and model_input_size is 4096.)

 """

@@ -299,6 +297,7 @@ class TranscriptomeTokenizer:
         model_input_size=4096,
         special_token=True,
         collapse_gene_ids=True,
+        model_version="V2",
         gene_median_file=GENE_MEDIAN_FILE,
         token_dictionary_file=TOKEN_DICTIONARY_FILE,
         gene_mapping_file=ENSEMBL_MAPPING_FILE,
@@ -318,15 +317,18 @@ class TranscriptomeTokenizer:
             | Chunk size for anndata tokenizer.
         model_input_size : int = 4096
             | Max input size of model to truncate input to.
-            | For the
+            | For the V1 model series, should be 2048. For the V2 model series, should be 4096.
         special_token : bool = True
             | Adds CLS token before and EOS token after rank value encoding.
-            | For the
+            | For the V1 model series, should be False. For the V2 model series, should be True.
         collapse_gene_ids : bool = True
             | Whether to collapse gene IDs based on gene mapping dictionary.
+        model_version : str
+            | To auto-select settings for model version other than current default.
+            | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
         gene_median_file : Path
             | Path to pickle file containing dictionary of non-zero median
-            | gene expression values across Genecorpus
+            | gene expression values across Genecorpus.
         token_dictionary_file : Path
             | Path to pickle file containing token dictionary (Ensembl IDs:token).
         gene_mapping_file : None, Path
@@ -348,8 +350,22 @@ class TranscriptomeTokenizer:
         # add CLS and EOS tokens
         self.special_token = special_token

+        # CHANGE DEFAULTS TO BE FOR MODEL OTHER THAN CURRENT
+        self.model_version = model_version
+        if self.model_version not in ["V1", "V2"]:
+            logger.error(
+                "Unrecognized model version. Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells."
+            )
+        elif self.model_version == "V1":
+            self.model_input_size = 2048
+            self.special_token = False
+            from . import ENSEMBL_MAPPING_FILE_30M, GENE_MEDIAN_FILE_30M, TOKEN_DICTIONARY_FILE_30M
+            gene_median_file = GENE_MEDIAN_FILE_30M
+            token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
+            gene_mapping_file = ENSEMBL_MAPPING_FILE_30M
+
         # load dictionary of gene normalization factors
-        # (non-zero median value of expression across Genecorpus
+        # (non-zero median value of expression across Genecorpus)
         with open(gene_median_file, "rb") as f:
             self.gene_median_dict = pickle.load(f)

@@ -372,7 +388,7 @@ class TranscriptomeTokenizer:
                 "<eos>" in self.gene_token_dict.keys()
             ):
                 logger.warning(
-                    "<cls> and <eos> are in gene_token_dict but special_token = False. Please note that for
+                    "<cls> and <eos> are in gene_token_dict but special_token = False. Please note that for V2 model series, special_token should be True."
                 )

         # if collapsing duplicate gene IDs
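Usage sketch (not part of the diff): the tokenizer change above. Directory paths and the custom attribute mapping are placeholders.

    from geneformer import TranscriptomeTokenizer

    # Default remains the V2 (104M-cell) series: model_input_size=4096, special_token=True.
    tk = TranscriptomeTokenizer({"cell_type": "cell_type"}, nproc=4)

    # model_version="V1" auto-selects the 30M dictionaries and overrides
    # model_input_size to 2048 and special_token to False.
    tk_v1 = TranscriptomeTokenizer({"cell_type": "cell_type"}, nproc=4, model_version="V1")
    tk_v1.tokenize_data("loom_data_dir", "output_dir", "output_prefix", file_format="loom")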
generation_config.json
CHANGED
@@ -1,5 +1,5 @@
 {
   "_from_model_config": true,
   "pad_token_id": 0,
-  "transformers_version": "4.
+  "transformers_version": "4.44.2"
 }
gf-12L-30M-i2048/config.json
DELETED
@@ -1,23 +0,0 @@
-{
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "gradient_checkpointing": false,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 512,
-  "initializer_range": 0.02,
-  "intermediate_size": 1024,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 2048,
-  "model_type": "bert",
-  "num_attention_heads": 8,
-  "num_hidden_layers": 12,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "transformers_version": "4.6.0",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 25426
-}

gf-12L-30M-i2048/pytorch_model.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:812f8d85e5ecf9d64c268f052f6ece2c1906bc4f1aecf70d5144b2598386b615
-size 158467410

gf-12L-30M-i2048/training_args.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:259cf6067211e24e198690d00f0a222ee5550ad57e23d04ced0d0ca2e1b3738e
-size 2607
gf-12L-95M-i4096/config.json
DELETED
@@ -1,24 +0,0 @@
-{
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "classifier_dropout": null,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 512,
-  "initializer_range": 0.02,
-  "intermediate_size": 1024,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 4096,
-  "model_type": "bert",
-  "num_attention_heads": 8,
-  "num_hidden_layers": 12,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "torch_dtype": "float32",
-  "transformers_version": "4.37.1",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 20275
-}

gf-12L-95M-i4096/generation_config.json
DELETED
@@ -1,5 +0,0 @@
-{
-  "_from_model_config": true,
-  "pad_token_id": 0,
-  "transformers_version": "4.37.1"
-}

gf-12L-95M-i4096/model.safetensors
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4365ba23e393fcfa0e65a94ac64a0983cd788bd23a8d4914f4ab66f85cfe043c
-size 152012980

gf-12L-95M-i4096/training_args.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:21a45980734b138029422e95a5601def858821a9ec02cd473938b9f525ac108d
-size 4920
gf-12L-95M-i4096_CLcancer/config.json
DELETED
@@ -1,25 +0,0 @@
-{
-  "_name_or_path": "/gladstone/theodoris/lab/pretrained_models/encoder/240402_194213_geneformer_94M_L12_emb512_SL4096_E3_B4_LR0.0005_LScosine_WU5000_Oadamw_DS8/models",
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "classifier_dropout": null,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 512,
-  "initializer_range": 0.02,
-  "intermediate_size": 1024,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 4096,
-  "model_type": "bert",
-  "num_attention_heads": 8,
-  "num_hidden_layers": 12,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "torch_dtype": "float32",
-  "transformers_version": "4.37.1",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 20275
-}

gf-12L-95M-i4096_CLcancer/generation_config.json
DELETED
@@ -1,5 +0,0 @@
-{
-  "_from_model_config": true,
-  "pad_token_id": 0,
-  "transformers_version": "4.37.1"
-}

gf-12L-95M-i4096_CLcancer/model.safetensors
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:2451adeed240c165634fea60ccba17063da8a2843ea9fcdcc0ce185720bf0dc2
-size 152012980

gf-12L-95M-i4096_CLcancer/training_args.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:37074f3ea62a6ba0a312c38526c20c2dccbb068a2c7ee8c7c73b435dd90ab7b1
-size 5048
gf-20L-95M-i4096/config.json
DELETED
@@ -1,24 +0,0 @@
-{
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "classifier_dropout": null,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 896,
-  "initializer_range": 0.02,
-  "intermediate_size": 1792,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 4096,
-  "model_type": "bert",
-  "num_attention_heads": 14,
-  "num_hidden_layers": 20,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "torch_dtype": "float32",
-  "transformers_version": "4.37.1",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 20275
-}

gf-20L-95M-i4096/generation_config.json
DELETED
@@ -1,5 +0,0 @@
-{
-  "_from_model_config": true,
-  "pad_token_id": 0,
-  "transformers_version": "4.37.1"
-}

gf-20L-95M-i4096/model.safetensors
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:db85c081a6d392448955c7d0185e26aba74507518df991ca8c69ee9108ce8bbf
-size 605292732

gf-20L-95M-i4096/training_args.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5afed602918d6f0c4916c1b9335bcdb619bca2c6fd6c7e0dd2a86d195264b8cc
-size 5048
gf-6L-30M-i2048/config.json
DELETED
@@ -1,23 +0,0 @@
-{
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "gradient_checkpointing": false,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 256,
-  "initializer_range": 0.02,
-  "intermediate_size": 512,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 2048,
-  "model_type": "bert",
-  "num_attention_heads": 4,
-  "num_hidden_layers": 6,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "transformers_version": "4.6.0",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 25426
-}

gf-6L-30M-i2048/model.safetensors
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a5e33a757431643b3697de7ef6127950cdc49e06e58d4266b3a3ab191b683f14
-size 41183536

gf-6L-30M-i2048/pytorch_model.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8d860e2125884475dd42bc2cd9a0e60c60808a7351241e08f2154931ffc142da
-size 41216562

gf-6L-30M-i2048/training_args.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f0ec3459454205174c9d2e4d6c6930f6b0fbf3364fc03a6f4d99c4d3add2012b
-size 2607