| import os | |
| import monai.networks.nets | |
| import torch | |
| from transformers import AutoConfig, AutoModel, PreTrainedModel | |
| from vista3d_config import VISTA3DConfig | |
class VISTA3DModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around MONAI's VISTA3D network.

    The wrapped network is built with ``monai.networks.nets.vista3d132`` and is
    exposed as ``self.network`` so that raw MONAI checkpoints can be loaded
    straight into it.
    """

    # Configuration class paired with this model.
    config_class = VISTA3DConfig

    def __init__(self, config: VISTA3DConfig):
        """Build the underlying VISTA3D network from ``config``.

        Args:
            config: Configuration providing ``encoder_embed_dim`` and
                ``input_channels`` for the network constructor.
        """
        super().__init__(config)
        self.network = monai.networks.nets.vista3d132(
            encoder_embed_dim=config.encoder_embed_dim,
            in_channels=config.input_channels,
        )

    # NOTE: ``input`` shadows the builtin, but the name is part of the public
    # signature (callers may pass it by keyword), so it is kept as-is.
    def forward(self, input, *args, **kwargs):
        """Run the wrapped network.

        Args:
            input: First positional argument handed to the network
                (presumably the image tensor -- confirm against MONAI docs).
            *args: Additional positional arguments forwarded unchanged.
            **kwargs: Additional keyword arguments forwarded unchanged. This
                lets callers supply the network's optional inputs (for
                example prompt arguments such as ``point_coords`` or
                ``class_vector`` -- see the MONAI VISTA3D documentation).

        Returns:
            Whatever the wrapped network returns.
        """
        return self.network(input, *args, **kwargs)
def register_my_model():
    """Register VISTA3D with the Auto classes.

    Registers ``VISTA3DConfig`` under the ``"VISTA3D"`` model type with
    ``AutoConfig`` and maps that configuration to ``VISTA3DModel`` with
    ``AutoModel``, so the model can be instantiated through ``AutoModel``.
    """
    model_type = "VISTA3D"
    config_cls = VISTA3DConfig
    model_cls = VISTA3DModel

    # Config first, then the config -> model mapping (same order as before).
    AutoConfig.register(model_type, config_cls)
    AutoModel.register(config_cls, model_cls)
if __name__ == "__main__":
    # Convert a raw MONAI checkpoint into a Hugging Face "pretrained" folder.
    # Both paths are resolved relative to the directory containing this script.
    FILE_PATH = os.path.dirname(__file__)
    MODEL_WEIGHT_PATH = os.path.join(FILE_PATH, "models", "model.pt")
    MODEL_PATH = os.path.join(FILE_PATH, "vista3d_pretrained_model")

    config = VISTA3DConfig()
    hugging_face_model = VISTA3DModel(config)

    # map_location="cpu" lets checkpoints saved from a CUDA device load on
    # machines without a GPU; load_state_dict then copies the values into the
    # freshly constructed network's own parameters.
    # SECURITY NOTE: torch.load unpickles the file -- only use checkpoints from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    state_dict = torch.load(MODEL_WEIGHT_PATH, map_location="cpu")
    # The checkpoint holds the bare network's weights, so it is loaded into
    # the wrapped ``network`` attribute rather than the wrapper itself.
    hugging_face_model.network.load_state_dict(state_dict)

    hugging_face_model.save_pretrained(MODEL_PATH)