PrecollatorForGeneAndCellClassification has no attribute save_pretrained

#528
by zzbb2266 - opened

Hi there, thanks for the great pipeline and maintenance! I am currently trying to replicate cell_classification.ipynb from the example folder, but I got an error from the PrecollatorForGeneAndCellClassification module right after the first epoch finished:

Epoch	Training Loss	Validation Loss	Accuracy	Macro F1
0	0.133900	0.411489	0.883736	0.684715
  0%|          | 0/1 [28:44<?, ?it/s]
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[5], line 6
      1 train_valid_id_split_dict = {"attr_key": "individual",
      2                             "train": train_ids,
      3                             "eval": eval_ids}
      5 # Example 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors
----> 6 all_metrics = cc.validate(model_directory="/home/data1/geneformer/Geneformer/gf-6L-30M-i2048/",
      7                           prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled_train.dataset",
      8                           id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
      9                           output_directory=output_dir,
     10                           output_prefix=output_prefix,
     11                           split_id_dict=train_valid_id_split_dict)
     12                           # to optimize hyperparameters, set n_hyperopt_trials=100 (or alternative desired # of trials)

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/geneformer/classifier.py:791, in Classifier.validate(self, model_directory, prepared_input_data_file, id_class_dict_file, output_directory, output_prefix, split_id_dict, attr_to_split, attr_to_balance, gene_balance, max_trials, pval_threshold, save_eval_output, predict_eval, predict_trainer, n_hyperopt_trials, save_gene_split_datasets, debug_gene_split_datasets)
    789     train_data = data.select(train_indices)
    790 if n_hyperopt_trials == 0:
--> 791     trainer = self.train_classifier(
    792         model_directory,
    793         num_classes,
    794         train_data,
    795         eval_data,
    796         ksplit_output_dir,
    797         predict_trainer,
    798     )
    799 else:
    800     trainer = self.hyperopt_classifier(
    801         model_directory,
    802         num_classes,
   (...)    806         n_trials=n_hyperopt_trials,
    807     )

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/geneformer/classifier.py:1269, in Classifier.train_classifier(self, model_directory, num_classes, train_data, eval_data, output_directory, predict)
   1259 trainer = Trainer(
   1260     model=model,
   1261     args=training_args_init,
   (...)   1265     compute_metrics=cu.compute_metrics,
   1266 )
   1268 # train the classifier
-> 1269 trainer.train()
   1270 trainer.save_model(output_directory)
   1271 if predict is True:
   1272     # make eval predictions and save predictions and metrics

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/trainer.py:2245, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2243         hf_hub_utils.enable_progress_bars()
   2244 else:
-> 2245     return inner_training_loop(
   2246         args=args,
   2247         resume_from_checkpoint=resume_from_checkpoint,
   2248         trial=trial,
   2249         ignore_keys_for_eval=ignore_keys_for_eval,
   2250     )

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/trainer.py:2661, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2658     self.control.should_training_stop = True
   2660 self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-> 2661 self._maybe_log_save_evaluate(
   2662     tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate
   2663 )
   2665 if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
   2666     if is_torch_xla_available():
   2667         # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/trainer.py:3103, in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate)
   3100         self.control.should_save = is_new_best_metric
   3102 if self.control.should_save:
-> 3103     self._save_checkpoint(model, trial)
   3104     self.control = self.callback_handler.on_save(self.args, self.state, self.control)

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/trainer.py:3200, in Trainer._save_checkpoint(self, model, trial)
   3198 run_dir = self._get_output_dir(trial=trial)
   3199 output_dir = os.path.join(run_dir, checkpoint_folder)
-> 3200 self.save_model(output_dir, _internal_call=True)
   3202 if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH] and self.state.best_global_step:
   3203     best_checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}"

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/trainer.py:3902, in Trainer.save_model(self, output_dir, _internal_call)
   3899         self.model_wrapped.save_checkpoint(output_dir)
   3901 elif self.args.should_save:
-> 3902     self._save(output_dir)
   3904 # Push to the Hub when `save_model` is called by the user.
   3905 if self.args.push_to_hub and not _internal_call:

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/trainer.py:4018, in Trainer._save(self, output_dir, state_dict)
   4012 elif (
   4013     self.data_collator is not None
   4014     and hasattr(self.data_collator, "tokenizer")
   4015     and self.data_collator.tokenizer is not None
   4016 ):
   4017     logger.info("Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`")
-> 4018     self.data_collator.tokenizer.save_pretrained(output_dir)
   4020 # Good practice: save your training arguments together with the trained model
   4021 torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

File /opt/miniconda/envs/gnenformer/lib/python3.13/site-packages/transformers/tokenization_utils_base.py:1108, in SpecialTokensMixin.__getattr__(self, key)
   1105         return self.convert_tokens_to_ids(attr_as_tokens) if attr_as_tokens is not None else None
   1107 if key not in self.__dict__:
-> 1108     raise AttributeError(f"{self.__class__.__name__} has no attribute {key}")
   1109 else:
   1110     return super().__getattr__(key)

AttributeError: PrecollatorForGeneAndCellClassification has no attribute save_pretrained

Should I manually add save_pretrained to trainer.py or classifier.py? It seems to me that the issue may be caused by a different version of the transformers package. My environment is:

Package                   Version
------------------------- -----------
absl-py                   2.2.2
accelerate                1.6.0
accumulation_tree         0.6.4
aiohappyeyeballs          2.6.1
aiohttp                   3.11.18
aiosignal                 1.3.2
alembic                   1.15.2
anndata                   0.11.4
annotated-types           0.7.0
array_api_compat          1.11.2
asttokens                 3.0.0
attrs                     25.3.0
certifi                   2025.4.26
charset-normalizer        3.4.1
click                     8.1.8
colorlog                  6.9.0
comm                      0.2.2
contourpy                 1.3.2
cycler                    0.12.1
datasets                  3.5.1
debugpy                   1.8.14
decorator                 5.2.1
dill                      0.3.8
docker-pycreds            0.4.0
exceptiongroup            1.2.2
executing                 2.2.0
filelock                  3.18.0
fonttools                 4.57.0
frozenlist                1.6.0
fsspec                    2025.3.0
geneformer                0.1.0
gitdb                     4.0.12
GitPython                 3.1.44
greenlet                  3.2.1
grpcio                    1.71.0
h5py                      3.13.0
huggingface-hub           0.30.2
idna                      3.10
importlib_metadata        8.6.1
ipykernel                 6.29.5
ipython                   9.2.0
ipython_pygments_lexers   1.1.1
jedi                      0.19.2
Jinja2                    3.1.6
joblib                    1.4.2
jsonschema                4.23.0
jsonschema-specifications 2025.4.1
jupyter_client            8.6.3
jupyter_core              5.7.2
kiwisolver                1.4.8
legacy-api-wrap           1.4.1
llvmlite                  0.44.0
loompy                    3.0.8
Mako                      1.3.10
Markdown                  3.8
MarkupSafe                3.0.2
matplotlib                3.10.1
matplotlib-inline         0.1.7
mpmath                    1.3.0
msgpack                   1.1.0
multidict                 6.4.3
multiprocess              0.70.16
natsort                   8.4.0
nest_asyncio              1.6.0
networkx                  3.4.2
numba                     0.61.2
numpy                     2.2.5
numpy-groupies            0.11.2
nvidia-cublas-cu12        12.6.4.1
nvidia-cuda-cupti-cu12    12.6.80
nvidia-cuda-nvrtc-cu12    12.6.77
nvidia-cuda-runtime-cu12  12.6.77
nvidia-cudnn-cu12         9.5.1.17
nvidia-cufft-cu12         11.3.0.4
nvidia-cufile-cu12        1.11.1.6
nvidia-curand-cu12        10.3.7.77
nvidia-cusolver-cu12      11.7.1.2
nvidia-cusparse-cu12      12.5.4.2
nvidia-cusparselt-cu12    0.6.3
nvidia-nccl-cu12          2.26.2
nvidia-nvjitlink-cu12     12.6.85
nvidia-nvtx-cu12          12.6.77
optuna                    4.3.0
optuna-integration        4.3.0
packaging                 25.0
pandas                    2.2.3
parso                     0.8.4
patsy                     1.0.1
peft                      0.15.2
pexpect                   4.9.0
pickleshare               0.7.5
pillow                    11.2.1
pip                       25.1
platformdirs              4.3.7
prompt_toolkit            3.0.51
propcache                 0.3.1
protobuf                  6.30.2
psutil                    7.0.0
ptyprocess                0.7.0
pure_eval                 0.2.3
pyarrow                   20.0.0
pydantic                  2.11.4
pydantic_core             2.33.2
Pygments                  2.19.1
pynndescent               0.5.13
pyparsing                 3.2.3
python-dateutil           2.9.0.post0
pytz                      2025.2
pyudorandom               1.0.0
PyYAML                    6.0.2
pyzmq                     26.4.0
ray                       2.45.0
referencing               0.36.2
regex                     2024.11.6
requests                  2.32.3
rpds-py                   0.24.0
safetensors               0.5.3
scanpy                    1.11.1
scikit-learn              1.5.2
scikit-misc               0.5.1
scipy                     1.15.2
seaborn                   0.13.2
sentry-sdk                2.27.0
session-info2             0.1.2
setproctitle              1.3.6
setuptools                80.0.1
six                       1.17.0
smmap                     5.0.2
SQLAlchemy                2.0.40
stack_data                0.6.3
statsmodels               0.14.4
sympy                     1.14.0
tdigest                   0.5.2.2
tensorboard               2.19.0
tensorboard-data-server   0.7.2
threadpoolctl             3.6.0
tokenizers                0.21.1
torch                     2.7.0
tornado                   6.4.2
tqdm                      4.67.1
traitlets                 5.14.3
transformers              4.51.3
triton                    3.3.0
typing_extensions         4.13.2
typing-inspection         0.4.0
tzdata                    2025.2
umap-learn                0.5.7
urllib3                   2.4.0
wandb                     0.19.10
wcwidth                   0.2.13
Werkzeug                  3.1.3
xxhash                    3.5.0
yarl                      1.20.0
zipp                      3.21.0

Could you tell me how to solve this issue, or should I just downgrade the transformers package? Thank you!

Update: This issue has been solved by downgrading the transformers package, which also took care of a related evaluation_strategy issue. FYI.

Hello, I am also trying to replicate cell_classification.ipynb. Could you share your environment after the fix?

Hey, I'm facing this same issue with the latest version of the codebase. Which version of transformers are you using to fix this? I can't find a version that avoids this error without triggering other new errors on older package versions, because the latest codebase has been updated for those transformers changes.

Which version should be installed? I have the same issue.

Update: This issue has been solved by downgrading the transformers package, which also took care of a related evaluation_strategy issue. FYI.

Would you tell me which version is working?

With today's version, transformers==4.46 solves this problem.

transformers==4.46 works for me thanks!

The package has been updated to resolve this error for compatibility with the updated transformers. Thank you for pointing it out.

ctheodoris changed discussion status to closed

Sign up or log in to comment