diff --git a/MANIFEST.in b/MANIFEST.in index c3875d90a1e1ee1715279ba71ae3efc1a46643e8..c2e818b69c0b7db494ed3bf2dff62517a9e2339b 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,9 @@ -include geneformer/gene_median_dictionary_gc95M.pkl -include geneformer/gene_name_id_dict_gc95M.pkl -include geneformer/ensembl_mapping_dict_gc95M.pkl -include geneformer/token_dictionary_gc95M.pkl +include geneformer/gene_median_dictionary_gc104M.pkl +include geneformer/gene_name_id_dict_gc104M.pkl +include geneformer/ensembl_mapping_dict_gc104M.pkl +include geneformer/token_dictionary_gc104M.pkl + +include geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl +include geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl +include geneformer/gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl +include geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl diff --git a/README.md b/README.md index 2d1ad4375703f99e682e4293131484adeb939522..9347d80e3a3c3922c72c8ed03f44c3fa15db5e5e 100644 --- a/README.md +++ b/README.md @@ -9,35 +9,28 @@ tags: Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology. - See [our manuscript](https://rdcu.be/ddrx0) for details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of our in silico perturbation and cell and gene classification strategies. -- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model trained on ~95 million transcriptomes in April 2024 and our continual learning, multitask learning, and quantization strategies. +- See [our manuscript](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) for details of the expanded model, now trained on ~104 million transcriptomes, and our continual learning, multitask learning, and quantization strategies. - See [geneformer.readthedocs.io](https://geneformer.readthedocs.io) for documentation. # Model Description -Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. Then, in April 2024, Geneformer was pretrained on ~95 million non-cancer transcriptomes, followed by continual learning on ~14 million cancer transcriptomes to yield a cancer domain-tuned model. +Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes representing a broad range of human tissues. Geneformer V1 was originally pretrained in June 2021 on [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M), a corpus comprised of ~30 million human single cell transcriptomes. We excluded cells with high mutational burdens (e.g. malignant cells and immortalized cell lines) that could lead to substantial network rewiring without companion genome sequencing to facilitate interpretation. The current updated Geneformer V2 is pretrained on ~104 million human single cell transcriptomes (non-cancer).
The cancer continual learning V2 variant was continually pretrained on ~14 million cancer transcriptomes to yield a cancer domain-tuned model. -Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable. +Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell scaled by their expression across the entire Genecorpus (-30M for V1, -104M for V2). The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across the pretraining corpus to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by scaling them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable. The rank value encoding of each single cell’s transcriptome then proceeds through N layers of transformer encoder units, where N varies dependent on the model size. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels. We detail applications and results in [our manuscript](https://rdcu.be/ddrx0). -During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force.
In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational deep learning model pretrained on a large-scale corpus human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets. +During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. With both zero-shot learning and fine-tuning with limited task-specific data, Geneformer consistently boosted predictive accuracy in a diverse panel of downstream tasks relevant to chromatin and network dynamics. In silico perturbation with zero-shot learning identified a novel transcription factor in cardiomyocytes that we experimentally validated to be critical to their ability to generate contractile force. In silico treatment with limited patient data revealed candidate therapeutic targets for cardiomyopathy that we experimentally validated to significantly improve the ability of cardiomyocytes to generate contractile force in an induced pluripotent stem cell (iPSC) model of the disease. Overall, Geneformer represents a foundational AI model pretrained on a large-scale corpus of human single cell transcriptomes to gain a fundamental understanding of gene network dynamics that can now be democratized to a vast array of downstream tasks to accelerate discovery of key network regulators and candidate therapeutic targets. The repository includes the following pretrained models: -L=layers\ -M=millions of cells used for pretraining\ -i=input size\ -(pretraining date) +- Geneformer-V1-10M: original model trained June 2021 on ~30M human single cell transcriptomes, 10M parameters, input size 2048, vocabulary ~25K protein-coding or non-coding RNA genes +- Geneformer-V2-104M and Geneformer-V2-316M: updated model trained Dec 2024 on ~104M human single cell transcriptomes, 104M or 316M parameters, input size 4096, vocabulary ~20K protein-coding genes -- GF-6L-30M-i2048 (June 2021) -- GF-12L-30M-i2048 (June 2021) -- GF-12L-95M-i4096 (April 2024) -- GF-20L-95M-i4096 (April 2024) +The current default model in the main directory of the repository is Geneformer-V2-316M. -The current default model in the main directory of the repository is GF-12L-95M-i4096. - -The repository also contains fined tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, GF-12L-95M-i4096_CLcancer. +The repository also contains fine-tuned models in the fine_tuned_models directory and the cancer-tuned model following continual learning on ~14 million cancer cells, Geneformer-V2-104M_CLcancer. # Application The pretrained Geneformer model can be used directly for zero-shot learning, for example for in silico perturbation analysis, or by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
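As a quick orientation to the model-version workflow these README changes describe, the sketch below chains tokenization and zero-shot embedding extraction using the APIs touched in this PR. It is a minimal, hypothetical example rather than part of the diff: the paths are placeholders, and `model_type="Pretrained"` for zero-shot extraction is an assumption (the notebooks below use fine-tuned classifiers instead).

```python
# Minimal sketch (illustrative, not part of this PR): tokenize data, then
# extract cell embeddings with matching model_version settings.
from geneformer import EmbExtractor, TranscriptomeTokenizer

# model_version="V2" (the current default) auto-selects the gc104M
# dictionaries, special_token=True, and model_input_size=4096.
tk = TranscriptomeTokenizer({"cell_type": "cell_type"}, nproc=16, model_version="V2")
tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")

# For a V1 model, pass model_version="V1" instead so the gc30M token
# dictionary is auto-selected and emb_mode falls back from "cls" to "cell".
embex = EmbExtractor(model_type="Pretrained",  # assumption: zero-shot, not fine-tuned
                     max_ncells=1000,
                     emb_layer=-1,
                     forward_batch_size=100,
                     model_version="V2",
                     nproc=16)
embs = embex.extract_embs("/path/to/Geneformer-V2-316M",  # placeholder model path
                          "output_directory/output_prefix.dataset",
                          "emb_output_directory",
                          "emb_output_prefix")
```

The same model_version pattern applies to Classifier, InSilicoPerturber, and InSilicoPerturberStats, as the updated example notebooks below demonstrate.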
@@ -87,7 +80,7 @@ For usage, see [examples](https://huggingface.co/ctheodoris/Geneformer/tree/main Please note that the fine-tuning examples are meant to be generally applicable and the input datasets and labels will vary dependent on the downstream task. Example input files for a few of the downstream tasks demonstrated in the manuscript are located within the [example_input_files directory](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files) in the dataset repository, but these only represent a few example fine-tuning applications. -Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.). +Please note that GPU resources are required for efficient usage of Geneformer. Additionally, we strongly recommend tuning hyperparameters for each downstream fine-tuning application as this can significantly boost predictive potential in the downstream task (e.g. max learning rate, learning schedule, number of layers to freeze, etc.). Importantly, as usual for deep learning models, there are no uniformly applicable default hyperparameters for Geneformer. # Citations - C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. _**Nature**_, 31 May 2023. (#co-corresponding authors) diff --git a/config.json b/config.json index 86e20c35e6f257f0daeb00ebb92a0751d12d8fff..6bc648aa565eabb748a6a43ee4def5032a0d5237 100644 --- a/config.json +++ b/config.json @@ -2,22 +2,22 @@ "architectures": [ "BertForMaskedLM" ], - "attention_probs_dropout_prob": 0.02, + "attention_probs_dropout_prob": 0.1, "classifier_dropout": null, "hidden_act": "relu", - "hidden_dropout_prob": 0.02, - "hidden_size": 512, + "hidden_dropout_prob": 0.1, + "hidden_size": 1152, "initializer_range": 0.02, - "intermediate_size": 1024, + "intermediate_size": 4608, "layer_norm_eps": 1e-12, "max_position_embeddings": 4096, "model_type": "bert", - "num_attention_heads": 8, - "num_hidden_layers": 12, + "num_attention_heads": 18, + "num_hidden_layers": 18, "pad_token_id": 0, "position_embedding_type": "absolute", "torch_dtype": "float32", - "transformers_version": "4.37.1", + "transformers_version": "4.44.2", "type_vocab_size": 2, "use_cache": true, "vocab_size": 20275 diff --git a/examples/cell_classification.ipynb b/examples/cell_classification.ipynb index 321187b9959abe460c6efc34996d6db0cf3488ed..64dd4323cb562a31b1641a000dab8aa5f59ac951 100644 --- a/examples/cell_classification.ipynb +++ b/examples/cell_classification.ipynb @@ -13,7 +13,7 @@ "id": "1792e51c-86c3-406f-be5a-273c4e4aec20", "metadata": {}, "source": [ - "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization." 
+ "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters." ] }, { @@ -69,9 +69,7 @@ " \"seed\": 73,\n", "}\n", "\n", - "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n", - "# (otherwise the Classifier will use the current default model dictionary)\n", - "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n", + "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "cc = Classifier(classifier=\"cell\",\n", " cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n", " filter_data=filter_data_dict,\n", @@ -80,6 +78,7 @@ " freeze_layers = 2,\n", " num_crossval_splits = 1,\n", " forward_batch_size=200,\n", + " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)" ] }, @@ -264,8 +263,8 @@ " \"train\": train_ids,\n", " \"eval\": eval_ids}\n", "\n", - "# Example 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors\n", - "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n", + "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n", + "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n", " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n", " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n", " output_directory=output_dir,\n", @@ -450,7 +449,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/examples/extract_and_plot_cell_embeddings.ipynb b/examples/extract_and_plot_cell_embeddings.ipynb index f00388708664a1cd0c774bfa13f0c01d0ee6578d..8571064ab3a3a3f35d3fcf2d90a64fc0ebcb071f 100644 --- a/examples/extract_and_plot_cell_embeddings.ipynb +++ b/examples/extract_and_plot_cell_embeddings.ipynb @@ -18,8 +18,7 @@ "outputs": [], "source": [ "# initiate EmbExtractor\n", - "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n", - "# (otherwise the EmbExtractor will use the current default model dictionary)\n", + "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "embex = EmbExtractor(model_type=\"CellClassifier\",\n", " num_classes=3,\n", " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n", @@ -28,13 +27,13 @@ " emb_label=[\"disease\",\"cell_type\"],\n", " labels_to_plot=[\"disease\"],\n", " forward_batch_size=200,\n", - " nproc=16,\n", - " token_dictionary_file=\"./gene_dictionaries_30m/token_dictionary_gc30M.pkl\") # change from current default dictionary for 30M model series\n", + " model_version=\"V1\", # OF NOTE: SET TO V1 
MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", + " nproc=16)\n", "\n", "# extracts embedding from input data\n", "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n", - "# example dataset for 30M model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n", - "embs = embex.extract_embs(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n", + "# example dataset for V1 model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n", + "embs = embex.extract_embs(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n", " \"path/to/input_data/\",\n", " \"path/to/output_directory/\",\n", " \"output_prefix\")\n" @@ -132,7 +131,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/examples/gene_classification.ipynb b/examples/gene_classification.ipynb index 284da7a1cc5846566d8b599ac2b549f6dc20f4a4..b739754a95c23c8f74f9cf4a85e05da9c2af58a8 100644 --- a/examples/gene_classification.ipynb +++ b/examples/gene_classification.ipynb @@ -13,7 +13,7 @@ "id": "79539e95-2c9c-4162-835c-f0d158abb15d", "metadata": {}, "source": [ - "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications." + "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications. Importantly, these hyperparameters do not represent uniformly applicable or recommended hyperparameters." 
] }, { @@ -71,15 +71,14 @@ } ], "source": [ - "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n", - "# (otherwise the Classifier will use the current default model dictionary)\n", - "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n", + "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "cc = Classifier(classifier=\"gene\",\n", " gene_class_dict = gene_class_dict,\n", " max_ncells = 10_000,\n", " freeze_layers = 4,\n", " num_crossval_splits = 5,\n", " forward_batch_size=200,\n", + " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)" ] }, @@ -843,8 +842,8 @@ } ], "source": [ - "# 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors\n", - "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n", + "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n", + "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n", " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n", " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n", " output_directory=output_dir,\n", @@ -1066,12 +1065,14 @@ } ], "source": [ + "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "cc = Classifier(classifier=\"gene\",\n", " gene_class_dict = gene_class_dict,\n", " max_ncells = 10_000,\n", " freeze_layers = 4,\n", " num_crossval_splits = 0,\n", " forward_batch_size=200,\n", + " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)" ] }, @@ -1218,8 +1219,8 @@ } ], "source": [ - "# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n", - "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\",\n", + "# V1 model: https://huggingface.co/ctheodoris/Geneformer/blob/main/Geneformer-V1-10M/model.safetensors\n", + "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\", # OF NOTE: SET TO V1 MODEL ABOVE, PROVIDE V1 MODEL PATH HERE\n", " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n", " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n", " output_directory=output_dir,\n", @@ -1243,7 +1244,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/examples/in_silico_perturbation.ipynb b/examples/in_silico_perturbation.ipynb index f7102617ebd36956d07ba61f8e4bccdf0719515e..00607d24415b5a6034fca467ec90c2e76ae72d43 100644 --- a/examples/in_silico_perturbation.ipynb +++ b/examples/in_silico_perturbation.ipynb @@ -39,9 +39,7 @@ "\n", "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n", "\n", - "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n", - "# (otherwise the EmbExtractor will use the current default model dictionary)\n", - "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n", + "# 
OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n", " num_classes=3,\n", " filter_data=filter_data_dict,\n", @@ -49,6 +47,7 @@ " emb_layer=0,\n", " summary_stat=\"exact_mean\",\n", " forward_batch_size=256,\n", + " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)\n", "\n", "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n", @@ -67,9 +66,7 @@ }, "outputs": [], "source": [ - "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n", - "# (otherwise the InSilicoPerturber will use the current default model dictionary)\n", - "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n", + "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "isp = InSilicoPerturber(perturb_type=\"delete\",\n", " perturb_rank_shift=None,\n", " genes_to_perturb=\"all\",\n", @@ -77,7 +74,7 @@ " anchor_gene=None,\n", " model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n", " num_classes=3,\n", - " emb_mode=\"cell\",\n", + " emb_mode=\"cell\", # OF NOTE: SET TO \"cell\" FOR V1 MODEL. FOR V2, SHOULD BE \"cls\" (current default).\n", " cell_emb_style=\"mean_pool\",\n", " filter_data=filter_data_dict,\n", " cell_states_to_model=cell_states_to_model,\n", @@ -85,6 +82,7 @@ " max_ncells=2000,\n", " emb_layer=0,\n", " forward_batch_size=400,\n", + " model_version=\"V1\", # OF NOTE: SET TO V1 MODEL, PROVIDE V1 MODEL PATH IN SUBSEQUENT CODE\n", " nproc=16)" ] }, @@ -97,7 +95,7 @@ "source": [ "# outputs intermediate files from in silico perturbation\n", "\n", - "isp.perturb_data(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n", + "isp.perturb_data(\"../fine_tuned_models/Geneformer-V1-10M_CellClassifier_cardiomyopathies_220224\", # example V1 fine-tuned model\n", " \"path/to/input_data\",\n", " \"path/to/isp_output_directory\",\n", " \"output_prefix\")" ] }, @@ -110,14 +108,13 @@ "metadata": {}, "outputs": [], "source": [ - "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n", - "# (otherwise the InSilicoPerturberStats will use the current default model dictionary)\n", - "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n", + "# OF NOTE: model_version should match version of model to be used (V1 or V2) to use the correct token dictionary\n", "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n", " genes_perturbed=\"all\",\n", " combos=0,\n", " anchor_gene=None,\n", - " cell_states_to_model=cell_states_to_model)" + " cell_states_to_model=cell_states_to_model,\n", + " model_version=\"V1\") # OF NOTE: SET TO V1 MODEL SINCE V1 WAS USED FOR IN SILICO PERTURBATION ABOVE" ] }, { @@ -151,7 +148,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/examples/multitask_cell_classification.ipynb b/examples/multitask_cell_classification.ipynb index b3f13b7477c7fb8797bf871b90f943877fb61029..998e678d7eb812dbf6e5e3764b982e9d6620fd63 --- 
a/examples/multitask_cell_classification.ipynb +++ b/examples/multitask_cell_classification.ipynb @@ -286,7 +286,7 @@ " filter_data_dict=filter_data_dict,\n", " max_ncells=1000, # Number of cells to extract embeddings for\n", " emb_layer=0, # Use the second to last layer\n", - " emb_mode = \"cls\",\n", + " emb_mode = \"cls\", # Use CLS token embedding for V2 model\n", " summary_stat=\"exact_mean\",\n", " forward_batch_size=8, # Adjust based on available GPU memory\n", " nproc=4\n", @@ -324,7 +324,7 @@ " perturb_type=perturb_type,\n", " genes_to_perturb=\"all\", # Perturb all genes\n", " model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n", - " emb_mode=\"cls\", # Use CLS token embedding\n", + " emb_mode=\"cls\", # Use CLS token embedding for V2 model\n", " cell_states_to_model=cell_states_to_model,\n", " state_embs_dict=state_embs_dict,\n", " max_ncells=1000, # Number of cells to perturb (larger number increases power)\n", @@ -412,7 +412,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/examples/tokenizing_scRNAseq_data.ipynb b/examples/tokenizing_scRNAseq_data.ipynb index 58c629a166529b066ba3615c16a26e59dd46295f..7f7331fb63d46567465ddcc6cea5560e09a40e24 100644 --- a/examples/tokenizing_scRNAseq_data.ipynb +++ b/examples/tokenizing_scRNAseq_data.ipynb @@ -34,12 +34,8 @@ "metadata": {}, "source": [ "**********************************************************************************************************\n", - "#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n", - "#### 95M: current defaults; 30M: https://huggingface.co/ctheodoris/Geneformer/tree/main/geneformer/gene_dictionaries_30m\n", - "\n", - "#### ADDITIONALLY:\n", - "#### The 95M model series require the special_token argument to be set to True and model_input_size to be 4096. (current defaults)\n", - "#### The 30M model series require the special_token argument to be set to False and the model_input_size to be 2048." + "#### OF NOTE: Please ensure the correct token dictionary, gene median file, special token setting, and model input size are used for the correct model version.\n", + "#### Current defaults are for V2 model series. To auto-select the correct settings for V1, set model_version argument to \"V1\"." 
] }, { @@ -59,7 +55,7 @@ "metadata": {}, "outputs": [], "source": [ - "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)\n", + "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16) # for V1 model, set model_version=\"V1\"\n", "tk.tokenize_data(\"loom_data_directory\", \n", " \"output_directory\", \n", " \"output_prefix\", \n", @@ -83,7 +79,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json deleted file mode 100755 index bc8099f84af0bd3e35d700a7135dd417e38f6bea..0000000000000000000000000000000000000000 --- a/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "architectures": [ - "BertForMaskedLM" - ], - "attention_probs_dropout_prob": 0.02, - "classifier_dropout": null, - "hidden_act": "relu", - "hidden_dropout_prob": 0.02, - "hidden_size": 512, - "initializer_range": 0.02, - "intermediate_size": 1024, - "layer_norm_eps": 1e-12, - "max_position_embeddings": 4096, - "model_type": "bert", - "num_attention_heads": 8, - "num_hidden_layers": 12, - "pad_token_id": 0, - "position_embedding_type": "absolute", - "torch_dtype": "float32", - "transformers_version": "4.37.2", - "type_vocab_size": 2, - "use_cache": true, - "vocab_size": 20275 -} diff --git a/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin deleted file mode 100755 index 87625b1b8fe02c6aa0fc3ffd8c746275570e589d..0000000000000000000000000000000000000000 --- a/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4 -size 152363342 diff --git a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json deleted file mode 100644 index a97e9ed8ae1716c9a513469ac9fb13762af12379..0000000000000000000000000000000000000000 --- a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - "_name_or_path": "/n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/", - "architectures": [ - "BertForSequenceClassification" - ], - "attention_probs_dropout_prob": 0.02, - "gradient_checkpointing": false, - "hidden_act": "relu", - "hidden_dropout_prob": 0.02, - "hidden_size": 256, - "id2label": { - "0": "LABEL_0", - "1": "LABEL_1", - "2": "LABEL_2" - }, - "initializer_range": 0.02, - "intermediate_size": 512, - "label2id": { - "LABEL_0": 0, - "LABEL_1": 1, - "LABEL_2": 2 - }, - "layer_norm_eps": 1e-12, - "max_position_embeddings": 2048, - "model_type": "bert", - "num_attention_heads": 4, - "num_hidden_layers": 6, - "pad_token_id": 0, - "position_embedding_type": "absolute", - "problem_type": "single_label_classification", - "transformers_version": "4.6.0", - "type_vocab_size": 2, - "use_cache": true, - "vocab_size": 25426 -} diff --git 
a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/optimizer.pt b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/optimizer.pt deleted file mode 100644 index 0661685e5939e9534ad391eac3904fa8b38bc4a4..0000000000000000000000000000000000000000 --- a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/optimizer.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3ced328122d57a847fc3914732337674500e259a82e64437c67b4954ac2f4e07 -size 73720721 diff --git a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin deleted file mode 100644 index 1d8981e95688798846e2ab155154532c0d5c060d..0000000000000000000000000000000000000000 --- a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:12ead3bad8cf4b853bac87eadeb79c9308ae492e9d29f32da1a2c85e8586108d -size 41115113 diff --git a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/rng_state.pth b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/rng_state.pth deleted file mode 100644 index e94f76fed717ff05f139167be936887f93fc1162..0000000000000000000000000000000000000000 --- a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/rng_state.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:dd8c0a739c2fe6a9ab4bb8f4a62ad8d7b879efcdceb5376b128a2040ff1bbe62 -size 14657 diff --git a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/scheduler.pt b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/scheduler.pt deleted file mode 100644 index 8b981b11a1b078de5ae15690cf383dfd020f2546..0000000000000000000000000000000000000000 --- a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/scheduler.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3d0797845afdae765a74ddab7966e0e1837617fd8171af8ee6aef9dedce248f2 -size 623 diff --git a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json deleted file mode 100644 index 1a8a9258268b72e5f9c3388e83fade166c2c1050..0000000000000000000000000000000000000000 --- a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json +++ /dev/null @@ -1,150 +0,0 @@ -{ - "best_metric": 0.39658036828041077, - "best_model_checkpoint": "/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/models/220224_geneformer_27M_SequenceClassifier_tuning_hCMdCM_L2048_B12_LR1e-05_LScosine_WU500_E1_Oadamw_F2/run-8429a330/checkpoint-7020", - "epoch": 0.9, - "global_step": 7020, - "is_hyper_param_search": true, - "is_local_process_zero": true, - "is_world_process_zero": true, - "log_history": [ - { - "epoch": 0.1, - "learning_rate": 0.00034606438343856935, - "loss": 0.911, - "step": 780 - }, - { - "epoch": 0.1, - "eval_accuracy": 0.4531576503366612, - "eval_loss": 1.4550466537475586, - "eval_runtime": 66.5164, - "eval_samples_per_second": 259.004, - "step": 780 - }, - { - "epoch": 0.2, - "learning_rate": 0.0006921287668771387, - "loss": 0.6273, - "step": 1560 - }, - { - "epoch": 0.2, - "eval_accuracy": 0.5953680055723242, - "eval_loss": 
0.846651554107666, - "eval_runtime": 66.1267, - "eval_samples_per_second": 260.53, - "step": 1560 - }, - { - "epoch": 0.3, - "learning_rate": 0.0007330550166223805, - "loss": 0.5592, - "step": 2340 - }, - { - "epoch": 0.3, - "eval_accuracy": 0.5935105641978176, - "eval_loss": 1.0599186420440674, - "eval_runtime": 66.2608, - "eval_samples_per_second": 260.003, - "step": 2340 - }, - { - "epoch": 0.4, - "learning_rate": 0.0006283471571048975, - "loss": 0.3714, - "step": 3120 - }, - { - "epoch": 0.4, - "eval_accuracy": 0.686324587880195, - "eval_loss": 1.184874415397644, - "eval_runtime": 66.1411, - "eval_samples_per_second": 260.473, - "step": 3120 - }, - { - "epoch": 0.5, - "learning_rate": 0.0005236392975874146, - "loss": 0.2976, - "step": 3900 - }, - { - "epoch": 0.5, - "eval_accuracy": 0.7681100534014396, - "eval_loss": 0.6318939328193665, - "eval_runtime": 66.3309, - "eval_samples_per_second": 259.728, - "step": 3900 - }, - { - "epoch": 0.6, - "learning_rate": 0.0004189314380699318, - "loss": 0.2564, - "step": 4680 - }, - { - "epoch": 0.6, - "eval_accuracy": 0.7807058277223126, - "eval_loss": 0.7283642888069153, - "eval_runtime": 66.3416, - "eval_samples_per_second": 259.686, - "step": 4680 - }, - { - "epoch": 0.7, - "learning_rate": 0.0003142235785524487, - "loss": 0.2336, - "step": 5460 - }, - { - "epoch": 0.7, - "eval_accuracy": 0.8563965637334572, - "eval_loss": 0.5184123516082764, - "eval_runtime": 66.3416, - "eval_samples_per_second": 259.686, - "step": 5460 - }, - { - "epoch": 0.8, - "learning_rate": 0.0002095157190349659, - "loss": 0.1731, - "step": 6240 - }, - { - "epoch": 0.8, - "eval_accuracy": 0.8288832133735778, - "eval_loss": 0.5823884010314941, - "eval_runtime": 66.1535, - "eval_samples_per_second": 260.425, - "step": 6240 - }, - { - "epoch": 0.9, - "learning_rate": 0.00010480785951748295, - "loss": 0.1451, - "step": 7020 - }, - { - "epoch": 0.9, - "eval_accuracy": 0.886812166241003, - "eval_loss": 0.39658036828041077, - "eval_runtime": 66.3555, - "eval_samples_per_second": 259.632, - "step": 7020 - } - ], - "max_steps": 7800, - "num_train_epochs": 1, - "total_flos": 0, - "trial_name": null, - "trial_params": { - "learning_rate": 0.0008039341830649843, - "lr_scheduler_type": "polynomial", - "num_train_epochs": 1, - "per_device_train_batch_size": 12, - "seed": 73.15243080311434, - "warmup_steps": 1812.6785581609881, - "weight_decay": 0.2588277764570262 - } -} diff --git a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/training_args.bin b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/training_args.bin deleted file mode 100644 index 080126c3d7536675c14b51082b1534623ff17acb..0000000000000000000000000000000000000000 --- a/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/training_args.bin +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4ffee119596c99b50a422b2f80103f4c44f7e25c2ea0e457fe224bad59f1f955 -size 2607 diff --git a/geneformer/__init__.py b/geneformer/__init__.py index 52d43619d06f2a7c019b480d1958a82d287d26ff..38cfecec8ec7924d7c74bf195165417b4c891829 100644 --- a/geneformer/__init__.py +++ b/geneformer/__init__.py @@ -4,10 +4,15 @@ from pathlib import Path warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip -GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl" -TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl" -ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / 
"gene_name_id_dict_gc95M.pkl" -ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl" +GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc104M.pkl" +TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc104M.pkl" +ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc104M.pkl" +ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc104M.pkl" + +GENE_MEDIAN_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl" +TOKEN_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/token_dictionary_gc30M.pkl" +ENSEMBL_DICTIONARY_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl" +ENSEMBL_MAPPING_FILE_30M = Path(__file__).parent / "gene_dictionaries_30m/ensembl_mapping_dict_gc30M.pkl" from . import ( collator_for_classification, diff --git a/geneformer/classifier.py b/geneformer/classifier.py index 9fbe8c5fbc4a35214691e5b363341a47919d5e0e..53d0885b56bf3b2fe6fb398668bc9ae4ebb197ea 100644 --- a/geneformer/classifier.py +++ b/geneformer/classifier.py @@ -92,6 +92,7 @@ class Classifier: "no_eval": {bool}, "stratify_splits_col": {None, str}, "forward_batch_size": {int}, + "model_version": {"V1", "V2"}, "token_dictionary_file": {None, str}, "nproc": {int}, "ngpu": {int}, @@ -115,6 +116,7 @@ class Classifier: stratify_splits_col=None, no_eval=False, forward_batch_size=100, + model_version="V2", token_dictionary_file=None, nproc=4, ngpu=1, @@ -191,6 +193,9 @@ class Classifier: | Otherwise, will perform eval during training. forward_batch_size : int | Batch size for forward pass (for evaluation, not training). + model_version : str + | To auto-select settings for model version other than current default. + | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells token_dictionary_file : None, str | Default is to use token dictionary file from Geneformer | Otherwise, will load custom gene token dictionary. @@ -225,14 +230,20 @@ class Classifier: self.stratify_splits_col = stratify_splits_col self.no_eval = no_eval self.forward_batch_size = forward_batch_size + self.model_version = model_version self.token_dictionary_file = token_dictionary_file self.nproc = nproc self.ngpu = ngpu + if self.model_version == "V1": + from . import TOKEN_DICTIONARY_FILE_30M + self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M + if self.training_args is None: logger.warning( "Hyperparameter tuning is highly recommended for optimal results. " - "No training_args provided; using default hyperparameters." + "No training_args provided; using default hyperparameters. " + "Please note: these defaults are not recommended to be used uniformly across tasks." 
) self.validate_options() @@ -1319,7 +1330,7 @@ class Classifier: ##### Evaluate the model ##### labels = id_class_dict.keys() y_pred, y_true, logits_list = eu.classifier_predict( - model, self.classifier, eval_data, self.forward_batch_size + model, self.classifier, eval_data, self.forward_batch_size, self.gene_token_dict ) conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics( y_pred, y_true, logits_list, num_classes, labels diff --git a/geneformer/emb_extractor.py b/geneformer/emb_extractor.py index 2ef1103dee2e492f87f751a51d0f4f12b1ce87d0..c5d6bf4df38b7a34a1b0decd0e474ae35bfdee37 100644 --- a/geneformer/emb_extractor.py +++ b/geneformer/emb_extractor.py @@ -402,6 +402,7 @@ class EmbExtractor: "emb_label": {None, list}, "labels_to_plot": {None, list}, "forward_batch_size": {int}, + "model_version": {"V1", "V2"}, "token_dictionary_file": {None, str}, "nproc": {int}, "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"}, @@ -422,6 +423,7 @@ class EmbExtractor: forward_batch_size=100, nproc=4, summary_stat=None, + model_version="V2", token_dictionary_file=None, ): """ @@ -472,6 +474,9 @@ class EmbExtractor: | If mean or median, outputs only approximated mean or median embedding of input data. | Non-exact recommended if encountering memory constraints while generating goal embedding positions. | Non-exact is slower but more memory-efficient. + model_version : str + | To auto-select settings for model version other than current default. + | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells token_dictionary_file : Path | Default is the Geneformer token dictionary | Path to pickle file containing token dictionary (Ensembl ID:token). @@ -502,6 +507,7 @@ class EmbExtractor: self.emb_layer = emb_layer self.emb_label = emb_label self.labels_to_plot = labels_to_plot + self.model_version = model_version self.token_dictionary_file = token_dictionary_file self.forward_batch_size = forward_batch_size self.nproc = nproc @@ -512,6 +518,15 @@ class EmbExtractor: self.summary_stat = summary_stat self.exact_summary_stat = None + if self.model_version == "V1": + from . import TOKEN_DICTIONARY_FILE_30M + self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M + if self.emb_mode == "cls": + self.emb_mode = "cell" + logger.warning( + "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token." + ) + self.validate_options() # load token dictionary (Ensembl IDs:token) diff --git a/geneformer/ensembl_mapping_dict_gc95M.pkl b/geneformer/ensembl_mapping_dict_gc95M.pkl deleted file mode 100644 index 927b80d0145a186925b04b62dac2e1141db88392..0000000000000000000000000000000000000000 --- a/geneformer/ensembl_mapping_dict_gc95M.pkl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0819bcbd869cfa14279449b037eb9ed1d09a91310e77bd1a19d927465030e95c -size 3957652 diff --git a/geneformer/evaluation_utils.py b/geneformer/evaluation_utils.py index e4bbc8326d33b0de0a62778e6cde0d0c4bd86b25..1f8970f3e37d2b599cda0054112eb9d1d98f6f50 100644 --- a/geneformer/evaluation_utils.py +++ b/geneformer/evaluation_utils.py @@ -20,20 +20,15 @@ from sklearn.metrics import ( ) from tqdm.auto import trange -from . 
import TOKEN_DICTIONARY_FILE from .emb_extractor import make_colorbar logger = logging.getLogger(__name__) -def preprocess_classifier_batch(cell_batch, max_len, label_name): +def preprocess_classifier_batch(cell_batch, max_len, label_name, gene_token_dict): if max_len is None: max_len = max([len(i) for i in cell_batch["input_ids"]]) - # load token dictionary (Ensembl IDs:token) - with open(TOKEN_DICTIONARY_FILE, "rb") as f: - gene_token_dict = pickle.load(f) - def pad_label_example(example): example[label_name] = np.pad( example[label_name], @@ -81,7 +76,7 @@ def py_softmax(vector): return e / e.sum() -def classifier_predict(model, classifier_type, evalset, forward_batch_size): +def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene_token_dict): if classifier_type == "gene": label_name = "labels" elif classifier_type == "cell": @@ -104,7 +99,7 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size): max_range = min(i + forward_batch_size, evalset_len) batch_evalset = evalset.select([i for i in range(i, max_range)]) padded_batch = preprocess_classifier_batch( - batch_evalset, max_evalset_len, label_name + batch_evalset, max_evalset_len, label_name, gene_token_dict ) padded_batch.set_format(type="torch") diff --git a/geneformer/gene_median_dictionary_gc95M.pkl b/geneformer/gene_median_dictionary_gc95M.pkl deleted file mode 100644 index 76b1e84597b859f1ab323038ed7d1513c38b14e4..0000000000000000000000000000000000000000 --- a/geneformer/gene_median_dictionary_gc95M.pkl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a51c53f6a771d64508dfaf61529df70e394c53bd20856926117ae5d641a24bf5 -size 1512661 diff --git a/geneformer/gene_name_id_dict_gc95M.pkl b/geneformer/gene_name_id_dict_gc95M.pkl deleted file mode 100644 index f98c94dd0c7ff50b6b74691d75c66be5affc9fa1..0000000000000000000000000000000000000000 --- a/geneformer/gene_name_id_dict_gc95M.pkl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:fabfa0c2f49c598c59ae432a32c3499a5908c033756c663b5e0cddf58deea8e1 -size 1660882 diff --git a/geneformer/in_silico_perturber.py b/geneformer/in_silico_perturber.py index 275244f771e344435734f9ef19f3749e294f0d2c..b7c419bf64246a70f3efaf010fca7b339fe70cc2 100644 --- a/geneformer/in_silico_perturber.py +++ b/geneformer/in_silico_perturber.py @@ -72,6 +72,7 @@ class InSilicoPerturber: "max_ncells": {None, int}, "cell_inds_to_perturb": {"all", dict}, "emb_layer": {-1, 0}, + "model_version": {"V1", "V2"}, "token_dictionary_file": {None, str}, "forward_batch_size": {int}, "nproc": {int}, @@ -96,6 +97,7 @@ class InSilicoPerturber: emb_layer=-1, forward_batch_size=100, nproc=4, + model_version="V2", token_dictionary_file=None, clear_mem_ncells=1000, ): @@ -184,6 +186,9 @@ class InSilicoPerturber: | Batch size for forward pass. nproc : int | Number of CPU processes to use. + model_version : str + | To auto-select settings for model version other than current default. + | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells token_dictionary_file : Path | Path to pickle file containing token dictionary (Ensembl ID:token). clear_mem_ncells : int @@ -224,9 +229,24 @@ class InSilicoPerturber: self.emb_layer = emb_layer self.forward_batch_size = forward_batch_size self.nproc = nproc + self.model_version = model_version self.token_dictionary_file = token_dictionary_file self.clear_mem_ncells = clear_mem_ncells + if self.model_version == "V1": + from . 
import TOKEN_DICTIONARY_FILE_30M + self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M + if self.emb_mode == "cls": + self.emb_mode = "cell" + logger.warning( + "model_version selected as V1 so changing emb_mode from 'cls' to 'cell' as V1 models do not have a <cls> token." + ) + if self.emb_mode == "cls_and_gene": + self.emb_mode = "cell_and_gene" + logger.warning( + "model_version selected as V1 so changing emb_mode from 'cls_and_gene' to 'cell_and_gene' as V1 models do not have a <cls> token." + ) + self.validate_options() # load token dictionary (Ensembl IDs:token) diff --git a/geneformer/in_silico_perturber_stats.py b/geneformer/in_silico_perturber_stats.py index 9ec98a8caee4e4ca623c5ecc7c18c36210806cce..b6e472002d7dd32360cececaf40ef06f1e791418 100644 --- a/geneformer/in_silico_perturber_stats.py +++ b/geneformer/in_silico_perturber_stats.py @@ -676,6 +676,7 @@ class InSilicoPerturberStats: "anchor_gene": {None, str}, "cell_states_to_model": {None, dict}, "pickle_suffix": {None, str}, + "model_version": {"V1", "V2"}, } def __init__( @@ -686,6 +687,7 @@ anchor_gene=None, cell_states_to_model=None, pickle_suffix="_raw.pickle", + model_version="V2", token_dictionary_file=TOKEN_DICTIONARY_FILE, gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE, ): @@ -713,7 +715,7 @@ | analyzes data for anchor gene perturbed in combination with each other gene. | However, if combos=0 and anchor_gene="ENSG00000136574": | analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene. - cell_states_to_model: None, dict + cell_states_to_model : None, dict | Cell states to model if testing perturbations that achieve goal state change. | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states | state_key: key specifying name of column in .dataset that defines the start/goal states @@ -724,6 +726,9 @@ | "start_state": "dcm", | "goal_state": "nf", | "alt_states": ["hcm", "other1", "other2"]} + model_version : str + | To auto-select settings for model version other than current default. + | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells token_dictionary_file : Path | Path to pickle file containing token dictionary (Ensembl ID:token). gene_name_id_dictionary_file : Path @@ -736,9 +741,15 @@ self.anchor_gene = anchor_gene self.cell_states_to_model = cell_states_to_model self.pickle_suffix = pickle_suffix + self.model_version = model_version self.validate_options() + if self.model_version == "V1": + from . import ENSEMBL_DICTIONARY_FILE_30M, TOKEN_DICTIONARY_FILE_30M + token_dictionary_file = TOKEN_DICTIONARY_FILE_30M + gene_name_id_dictionary_file = ENSEMBL_DICTIONARY_FILE_30M + # load token dictionary (Ensembl IDs:token) with open(token_dictionary_file, "rb") as f: self.gene_token_dict = pickle.load(f) diff --git a/geneformer/perturber_utils.py b/geneformer/perturber_utils.py index 26190fb4a89b59ef362d93eaedfa65934b269356..0c260d313899f3228339e6bcd4147fee98f3a791 100644 --- a/geneformer/perturber_utils.py +++ b/geneformer/perturber_utils.py @@ -17,11 +17,6 @@ from transformers import ( BitsAndBytesConfig, ) -from . 
import ( - TOKEN_DICTIONARY_FILE, - ENSEMBL_DICTIONARY_FILE, -) - logger = logging.getLogger(__name__) @@ -127,7 +122,10 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False): output_hidden_states = (mode == "eval") # Quantization logic - if quantize: + if isinstance(quantize, dict): + quantize_config = quantize.get("bnb_config", None) + peft_config = quantize.get("peft_config", None) + elif quantize: if inference_only: quantize_config = BitsAndBytesConfig(load_in_8bit=True) peft_config = None @@ -138,19 +136,22 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False): bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, ) - lora_config_params = { - "lora_alpha": 128, - "lora_dropout": 0.1, - "r": 64, - "bias": "none" - } - - # Try with TokenClassification first, fallback to TOKEN_CLS if needed try: - peft_config = LoraConfig(**lora_config_params, task_type="TokenClassification") - except ValueError: - # Some versions use TOKEN_CLS instead of TokenClassification - peft_config = LoraConfig(**lora_config_params, task_type="TOKEN_CLS") + peft_config = LoraConfig( + lora_alpha=128, + lora_dropout=0.1, + r=64, + bias="none", + task_type="TokenClassification", + ) + except ValueError: + peft_config = LoraConfig( + lora_alpha=128, + lora_dropout=0.1, + r=64, + bias="none", + task_type="TOKEN_CLS", + ) else: quantize_config = None peft_config = None @@ -187,14 +188,22 @@ model.eval() # Handle device placement and PEFT + adapter_config_path = os.path.join(model_directory, "adapter_config.json") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if not quantize: # Only move non-quantized models device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) + elif os.path.exists(adapter_config_path): + # If adapter files exist, load them into the model using PEFT's from_pretrained + model = PeftModel.from_pretrained(model, model_directory) + model = model.to(device) + print("loading lora weights") elif peft_config: # Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf) model.enable_input_require_grads() model = get_peft_model(model, peft_config) + model = model.to(device) return model @@ -883,50 +892,4 @@ def validate_cell_states_to_model(cell_states_to_model): "'goal_state': 'nf', " "'alt_states': ['hcm', 'other1', 'other2']}" ) - raise - - -class GeneIdHandler: - def __init__(self, raise_errors=False): - def invert_dict(dict_obj): - return {v: k for k, v in dict_obj.items()} - - self.raise_errors = raise_errors - - with open(TOKEN_DICTIONARY_FILE, "rb") as f: - self.gene_token_dict = pickle.load(f) - self.token_gene_dict = invert_dict(self.gene_token_dict) - - with open(ENSEMBL_DICTIONARY_FILE, "rb") as f: - self.id_gene_dict = pickle.load(f) - self.gene_id_dict = invert_dict(self.id_gene_dict) - - def ens_to_token(self, ens_id): - if not self.raise_errors: - return self.gene_token_dict.get(ens_id, ens_id) - else: - return self.gene_token_dict[ens_id] - - def token_to_ens(self, token): - if not self.raise_errors: - return self.token_gene_dict.get(token, token) - else: - return self.token_gene_dict[token] - - def ens_to_symbol(self, ens_id): - if not self.raise_errors: - return self.gene_id_dict.get(ens_id, ens_id) - else: - return self.gene_id_dict[ens_id] - - def symbol_to_ens(self, symbol): - if not self.raise_errors: - return self.id_gene_dict.get(symbol, 
symbol) - else: - return self.id_gene_dict[symbol] - - def token_to_symbol(self, token): - return self.ens_to_symbol(self.token_to_ens(token)) - - def symbol_to_token(self, symbol): - return self.ens_to_token(self.symbol_to_ens(symbol)) + raise \ No newline at end of file diff --git a/geneformer/token_dictionary_gc95M.pkl b/geneformer/token_dictionary_gc95M.pkl deleted file mode 100644 index b56e406e79c255328f84d9ca00c5c3da2dd04811..0000000000000000000000000000000000000000 --- a/geneformer/token_dictionary_gc95M.pkl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:67c445f4385127adfc48dcc072320cd65d6822829bf27dd38070e6e787bc597f -size 425590 diff --git a/geneformer/tokenizer.py b/geneformer/tokenizer.py index b460f028c9d85630b34722a290df6dd40f8908aa..98040b0179196b93b44d4c1924dc455973de4e87 100644 --- a/geneformer/tokenizer.py +++ b/geneformer/tokenizer.py @@ -32,9 +32,7 @@ Geneformer tokenizer. | If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer. -| OF NOTE: Take care that the correct token dictionary and gene median file is used for the correct model. - -| OF NOTE: For 95M model series, special_token should be True and model_input_size should be 4096. For 30M model series, special_token should be False and model_input_size should be 2048. +| OF NOTE: Use model_version to auto-select settings for model version other than current default. For V1 model series (original Geneformer pretrained in 2021 on ~30M cells), one must use the correct corresponding token dictionary and gene median file, set special_token to False, and set model_input_size to 2048. This argument enables auto-selection of these settings. (For V2 model series, special_token must be True and model_input_size is 4096.) """ @@ -299,6 +297,7 @@ class TranscriptomeTokenizer: model_input_size=4096, special_token=True, collapse_gene_ids=True, + model_version="V2", gene_median_file=GENE_MEDIAN_FILE, token_dictionary_file=TOKEN_DICTIONARY_FILE, gene_mapping_file=ENSEMBL_MAPPING_FILE, @@ -318,15 +317,18 @@ | Chunk size for anndata tokenizer. model_input_size : int = 4096 | Max input size of model to truncate input to. - | For the 30M model series, should be 2048. For the 95M model series, should be 4096. + | For the V1 model series, should be 2048. For the V2 model series, should be 4096. special_token : bool = True | Adds CLS token before and EOS token after rank value encoding. - | For the 30M model series, should be False. For the 95M model series, should be True. + | For the V1 model series, should be False. For the V2 model series, should be True. collapse_gene_ids : bool = True | Whether to collapse gene IDs based on gene mapping dictionary. + model_version : str + | To auto-select settings for model version other than current default. + | Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells gene_median_file : Path | Path to pickle file containing dictionary of non-zero median - | gene expression values across Genecorpus-30M. + | gene expression values across Genecorpus. token_dictionary_file : Path | Path to pickle file containing token dictionary (Ensembl IDs:token). 
diff --git a/generation_config.json b/generation_config.json
index 6f690c1f39b5b262e6b898b8891afd9d44978f11..0786a9f4dc0de68ee18cbf78399931b05fbefee7 100644
--- a/generation_config.json
+++ b/generation_config.json
@@ -1,5 +1,5 @@
 {
   "_from_model_config": true,
   "pad_token_id": 0,
-  "transformers_version": "4.37.1"
+  "transformers_version": "4.44.2"
 }
diff --git a/gf-12L-30M-i2048/config.json b/gf-12L-30M-i2048/config.json
deleted file mode 100644
index 52a12424cea85facdf0ca0c507908506daae7ea7..0000000000000000000000000000000000000000
--- a/gf-12L-30M-i2048/config.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "gradient_checkpointing": false,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 512,
-  "initializer_range": 0.02,
-  "intermediate_size": 1024,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 2048,
-  "model_type": "bert",
-  "num_attention_heads": 8,
-  "num_hidden_layers": 12,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "transformers_version": "4.6.0",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 25426
-}
diff --git a/gf-12L-30M-i2048/pytorch_model.bin b/gf-12L-30M-i2048/pytorch_model.bin
deleted file mode 100644
index d706ef2ff77fd6809a91034e6ed24af0e1b33999..0000000000000000000000000000000000000000
--- a/gf-12L-30M-i2048/pytorch_model.bin
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:812f8d85e5ecf9d64c268f052f6ece2c1906bc4f1aecf70d5144b2598386b615
-size 158467410
diff --git a/gf-12L-30M-i2048/training_args.bin b/gf-12L-30M-i2048/training_args.bin
deleted file mode 100644
index 346383caaaa3b555cb6fcd8de8e4982ebf2a50d5..0000000000000000000000000000000000000000
--- a/gf-12L-30M-i2048/training_args.bin
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:259cf6067211e24e198690d00f0a222ee5550ad57e23d04ced0d0ca2e1b3738e
-size 2607
diff --git a/gf-12L-95M-i4096/config.json b/gf-12L-95M-i4096/config.json
deleted file mode 100755
index 86e20c35e6f257f0daeb00ebb92a0751d12d8fff..0000000000000000000000000000000000000000
--- a/gf-12L-95M-i4096/config.json
+++ /dev/null
@@ -1,24 +0,0 @@
-{
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "classifier_dropout": null,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 512,
-  "initializer_range": 0.02,
-  "intermediate_size": 1024,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 4096,
-  "model_type": "bert",
-  "num_attention_heads": 8,
-  "num_hidden_layers": 12,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "torch_dtype": "float32",
-  "transformers_version": "4.37.1",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 20275
-}
diff --git a/gf-12L-95M-i4096/generation_config.json b/gf-12L-95M-i4096/generation_config.json
deleted file mode 100755
index 6f690c1f39b5b262e6b898b8891afd9d44978f11..0000000000000000000000000000000000000000
--- a/gf-12L-95M-i4096/generation_config.json
+++ /dev/null
@@ -1,5 +0,0 @@
-{
-  "_from_model_config": true,
-  "pad_token_id": 0,
-  "transformers_version": "4.37.1"
-}
diff --git a/gf-12L-95M-i4096/model.safetensors b/gf-12L-95M-i4096/model.safetensors
deleted file mode 100755
index 1069352219a29bed65fa8e13feb77004128174fa..0000000000000000000000000000000000000000
--- a/gf-12L-95M-i4096/model.safetensors
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4365ba23e393fcfa0e65a94ac64a0983cd788bd23a8d4914f4ab66f85cfe043c
-size 152012980
diff --git a/gf-12L-95M-i4096/training_args.bin b/gf-12L-95M-i4096/training_args.bin
deleted file mode 100755
index 18802f485a03e0262866d1ef7a3e4748a3b14ed3..0000000000000000000000000000000000000000
--- a/gf-12L-95M-i4096/training_args.bin
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:21a45980734b138029422e95a5601def858821a9ec02cd473938b9f525ac108d
-size 4920
diff --git a/gf-12L-95M-i4096_CLcancer/config.json b/gf-12L-95M-i4096_CLcancer/config.json
deleted file mode 100755
index a7793eb2ea27b28f1f4c5b9974d30c98b4afe8a6..0000000000000000000000000000000000000000
--- a/gf-12L-95M-i4096_CLcancer/config.json
+++ /dev/null
@@ -1,25 +0,0 @@
-{
-  "_name_or_path": "/gladstone/theodoris/lab/pretrained_models/encoder/240402_194213_geneformer_94M_L12_emb512_SL4096_E3_B4_LR0.0005_LScosine_WU5000_Oadamw_DS8/models",
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "classifier_dropout": null,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 512,
-  "initializer_range": 0.02,
-  "intermediate_size": 1024,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 4096,
-  "model_type": "bert",
-  "num_attention_heads": 8,
-  "num_hidden_layers": 12,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "torch_dtype": "float32",
-  "transformers_version": "4.37.1",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 20275
-}
diff --git a/gf-12L-95M-i4096_CLcancer/generation_config.json b/gf-12L-95M-i4096_CLcancer/generation_config.json
deleted file mode 100755
index 6f690c1f39b5b262e6b898b8891afd9d44978f11..0000000000000000000000000000000000000000
--- a/gf-12L-95M-i4096_CLcancer/generation_config.json
+++ /dev/null
@@ -1,5 +0,0 @@
-{
-  "_from_model_config": true,
-  "pad_token_id": 0,
-  "transformers_version": "4.37.1"
-}
diff --git a/gf-12L-95M-i4096_CLcancer/model.safetensors b/gf-12L-95M-i4096_CLcancer/model.safetensors
deleted file mode 100755
index cc620ee4b4243b7ab6d83ad518563e1425eab45b..0000000000000000000000000000000000000000
--- a/gf-12L-95M-i4096_CLcancer/model.safetensors
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:2451adeed240c165634fea60ccba17063da8a2843ea9fcdcc0ce185720bf0dc2
-size 152012980
diff --git a/gf-12L-95M-i4096_CLcancer/training_args.bin b/gf-12L-95M-i4096_CLcancer/training_args.bin
deleted file mode 100755
index 1669f5848710ca4a53db6e118e50b816f85381b7..0000000000000000000000000000000000000000
--- a/gf-12L-95M-i4096_CLcancer/training_args.bin
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:37074f3ea62a6ba0a312c38526c20c2dccbb068a2c7ee8c7c73b435dd90ab7b1
-size 5048
diff --git a/gf-20L-95M-i4096/config.json b/gf-20L-95M-i4096/config.json
deleted file mode 100755
index db949ba1ae442ad3b9e52fd8b7922c6b936ef98c..0000000000000000000000000000000000000000
--- a/gf-20L-95M-i4096/config.json
+++ /dev/null
@@ -1,24 +0,0 @@
-{
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "classifier_dropout": null,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 896,
-  "initializer_range": 0.02,
-  "intermediate_size": 1792,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 4096,
-  "model_type": "bert",
-  "num_attention_heads": 14,
-  "num_hidden_layers": 20,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "torch_dtype": "float32",
-  "transformers_version": "4.37.1",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 20275
-}
diff --git a/gf-20L-95M-i4096/generation_config.json b/gf-20L-95M-i4096/generation_config.json
deleted file mode 100755
index 6f690c1f39b5b262e6b898b8891afd9d44978f11..0000000000000000000000000000000000000000
--- a/gf-20L-95M-i4096/generation_config.json
+++ /dev/null
@@ -1,5 +0,0 @@
-{
-  "_from_model_config": true,
-  "pad_token_id": 0,
-  "transformers_version": "4.37.1"
-}
diff --git a/gf-20L-95M-i4096/model.safetensors b/gf-20L-95M-i4096/model.safetensors
deleted file mode 100755
index 37212863afb501a17425dd48766d71d534537d24..0000000000000000000000000000000000000000
--- a/gf-20L-95M-i4096/model.safetensors
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:db85c081a6d392448955c7d0185e26aba74507518df991ca8c69ee9108ce8bbf
-size 605292732
diff --git a/gf-20L-95M-i4096/training_args.bin b/gf-20L-95M-i4096/training_args.bin
deleted file mode 100755
index 3db61b0b99d299afb7c4a237d2b531baa253e5d3..0000000000000000000000000000000000000000
--- a/gf-20L-95M-i4096/training_args.bin
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5afed602918d6f0c4916c1b9335bcdb619bca2c6fd6c7e0dd2a86d195264b8cc
-size 5048
diff --git a/gf-6L-30M-i2048/config.json b/gf-6L-30M-i2048/config.json
deleted file mode 100644
index d131b7026d684013f988cc9e3dcae2e5a284bc0e..0000000000000000000000000000000000000000
--- a/gf-6L-30M-i2048/config.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
-  "architectures": [
-    "BertForMaskedLM"
-  ],
-  "attention_probs_dropout_prob": 0.02,
-  "gradient_checkpointing": false,
-  "hidden_act": "relu",
-  "hidden_dropout_prob": 0.02,
-  "hidden_size": 256,
-  "initializer_range": 0.02,
-  "intermediate_size": 512,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 2048,
-  "model_type": "bert",
-  "num_attention_heads": 4,
-  "num_hidden_layers": 6,
-  "pad_token_id": 0,
-  "position_embedding_type": "absolute",
-  "transformers_version": "4.6.0",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 25426
-}
diff --git a/gf-6L-30M-i2048/model.safetensors b/gf-6L-30M-i2048/model.safetensors
deleted file mode 100644
index c06bc0c9f7517d5db759187f65d27bacc76eb631..0000000000000000000000000000000000000000
--- a/gf-6L-30M-i2048/model.safetensors
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a5e33a757431643b3697de7ef6127950cdc49e06e58d4266b3a3ab191b683f14
-size 41183536
diff --git a/gf-6L-30M-i2048/pytorch_model.bin b/gf-6L-30M-i2048/pytorch_model.bin
deleted file mode 100644
index 2406c11ac74c12b711b542fe9981affccb7ec75c..0000000000000000000000000000000000000000
--- a/gf-6L-30M-i2048/pytorch_model.bin
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8d860e2125884475dd42bc2cd9a0e60c60808a7351241e08f2154931ffc142da
-size 41216562
diff --git a/gf-6L-30M-i2048/training_args.bin b/gf-6L-30M-i2048/training_args.bin
deleted file mode 100644
index 3e03ccc99722f70224937e7b2e46f8faab774e23..0000000000000000000000000000000000000000
--- a/gf-6L-30M-i2048/training_args.bin
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f0ec3459454205174c9d2e4d6c6930f6b0fbf3364fc03a6f4d99c4d3add2012b
-size 2607
diff --git a/model.safetensors b/model.safetensors
index 1069352219a29bed65fa8e13feb77004128174fa..f9c0c25c4a1df80ddc1455def715a9856152882c 100644
--- a/model.safetensors
+++ b/model.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4365ba23e393fcfa0e65a94ac64a0983cd788bd23a8d4914f4ab66f85cfe043c
-size 152012980
+oid sha256:965ceccea81953d362081ef3843560a0e4fef88d396c28017881f1e94b1246f3
+size 1265455076
diff --git a/requirements.txt b/requirements.txt
index 0cb09a2593f3a727090f7cf9f7eacd36edd8ddbd..1f33548292d196172309dea765edf89df067360b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,4 +22,4 @@ tdigest>=0.5.2
 tensorboard>=2.15
 torch>=2.0.1
 tqdm>=4.65
-transformers>=4.40
+transformers==4.46
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 6dde9eefad8c76e3d1e41ae187f2215bdbc93db5..81d947cea5933da94d1ea315dd5eee59f18c6577 100644
--- a/setup.py
+++ b/setup.py
@@ -15,6 +15,7 @@ setup(
     include_package_data=True,
     install_requires=[
         "anndata",
+        "bitsandbytes",
         "datasets",
         "loompy",
         "matplotlib",
diff --git a/training_args.bin b/training_args.bin
index 18802f485a03e0262866d1ef7a3e4748a3b14ed3..630bab56b199325b337a9969d30167f5b73b7815 100644
--- a/training_args.bin
+++ b/training_args.bin
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:21a45980734b138029422e95a5601def858821a9ec02cd473938b9f525ac108d
-size 4920
+oid sha256:e45150f9a4ca34cb4e91ce79f65f3d99d9d66df9f66a37517a352d291008e0b8
+size 5432