metadata
license: apache-2.0
tags:
  - pytorch
  - unet
  - chest-ct
  - survival-analysis
  - time-to-event
  - model-3d
model-index:
  - name: UNet-TTE
    results: []

SwinUNETR Checkpoint

This is a PyTorch Lightning `.ckpt` checkpoint for a SwinUNETR model trained on chest CT images with a time-to-event (TTE) objective.

Usage

A quickstart script is below.

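import torch
from src.networks import SwinUNETRForClassification  # assumed to come from the GitHub repo

# Placeholder: set this to your local copy of the .ckpt file.
loadmodel_path = "path/to/checkpoint.ckpt"
# Fall back to CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

swin_unetr_params = {
    "img_size": (224, 224, 224),
    "in_channels": 1,
    "out_channels": 2,
    "feature_size": 48,
    "drop_rate": 0.0,
    "attn_drop_rate": 0.0,
    "dropout_path_rate": 0.0,
    "use_checkpoint": True,
}
model = SwinUNETRForClassification(
    swin_unetr_params=swin_unetr_params, num_classes=2
).to(device)

checkpoint = torch.load(loadmodel_path, map_location=device)
# Lightning checkpoints usually nest the weights under "state_dict";
# fall back to the loaded object if it is already a plain state dict.
state_dict = checkpoint.get("state_dict", checkpoint)
model.load_state_dict(state_dict)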
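
If `load_state_dict` reports missing or unexpected keys, the key names in the checkpoint may not match those of the model. The sketch below is an assumption, not part of the repository. It shows one way to unwrap a Lightning-style checkpoint dictionary and strip a leading prefix from its keys. The helper name `extract_state_dict` and the `"model."` prefix are hypothetical; inspect the keys of your own checkpoint to find the actual prefix, if there is one.

```python
def extract_state_dict(checkpoint, prefix=""):
    """Return a plain state dict from a checkpoint-like mapping.

    Hypothetical helper: unwraps the "state_dict" entry if present and
    removes `prefix` from the start of every key that carries it.
    """
    # Lightning checkpoints usually nest the weights under "state_dict".
    state_dict = checkpoint.get("state_dict", checkpoint)
    cleaned = {}
    for key, value in state_dict.items():
        if prefix and key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Toy example: plain Python values stand in for tensors.
toy_checkpoint = {
    "epoch": 3,
    "state_dict": {"model.encoder.weight": 1, "model.head.bias": 2},
}
cleaned = extract_state_dict(toy_checkpoint, prefix="model.")
print(sorted(cleaned))
```

With a real checkpoint, pass the object returned by `torch.load` to the helper and then call `model.load_state_dict` on the result.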

For detailed instructions, see the README in the GitHub repository.