SwinUNETR Checkpoint

This is a PyTorch Lightning .ckpt checkpoint for a SwinUNETR model trained on chest CT images with a time-to-event (TTE) objective.

Usage

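The quickstart script below builds the model and loads the checkpoint weights. It assumes the project's GitHub repository is on the Python path so that src.networks can be imported.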

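import torch
from src.networks import SwinUNETRForClassification  # defined in the project's GitHub repo

# Placeholder: set this to the local path of the downloaded .ckpt file.
loadmodel_path = "path/to/checkpoint.ckpt"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture settings; these must match the ones used to train the checkpoint.
swin_unetr_params = {
    "img_size": (224, 224, 224),
    "in_channels": 1,
    "out_channels": 2,
    "feature_size": 48,
    "drop_rate": 0.0,
    "attn_drop_rate": 0.0,
    "dropout_path_rate": 0.0,
    "use_checkpoint": True,
}
model = SwinUNETRForClassification(
    swin_unetr_params=swin_unetr_params, num_classes=2
).to(device)

# Load the saved weights onto the same device as the model.
state_dict = torch.load(loadmodel_path, map_location=device)
model.load_state_dict(state_dict)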
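
The quickstart passes the loaded object straight to load_state_dict. PyTorch Lightning checkpoints commonly store the weights under a "state_dict" key, sometimes with a module prefix on every parameter name. If loading fails with missing or unexpected keys, a small helper along these lines may help. This is only a sketch: extract_state_dict is a hypothetical helper, not part of the repository, and the "model." prefix in the usage note is an assumption that should be checked against the keys actually present in the file.

```python
def extract_state_dict(checkpoint, prefix=""):
    """Return plain weights from a bare state dict or a Lightning-style checkpoint.

    Hypothetical helper for illustration; not part of the project's code.
    """
    # Lightning-style checkpoints keep the weights under "state_dict";
    # a bare state dict is used as-is.
    state_dict = checkpoint.get("state_dict", checkpoint)
    if not prefix:
        return dict(state_dict)
    # Strip the prefix only from keys that actually start with it.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
```

With this in place, the last two lines of the quickstart could become model.load_state_dict(extract_state_dict(torch.load(loadmodel_path, map_location=device), prefix="model.")), where the prefix value is a guess to be adjusted after inspecting the checkpoint keys.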

For detailed instructions, see the README in the GitHub repository.

