ProtoPatient Model for Multi-Label Classification

Paper Reference

van Aken, Betty, Jens-Michalis Papaioannou, Marcel G. Naik, Georgios Eleftheriadis, Wolfgang Nejdl, Felix A. Gers, and Alexander Löser. 2022.
This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text.
arXiv:2210.08500

ProtoPatient is a transformer-based architecture that uses prototypical networks and label-wise attention to perform multi-label classification on clinical admission notes. Unlike standard black-box models, ProtoPatient offers inherent interpretability by:

  • Highlighting Relevant Tokens: Shows the most important words for each possible diagnosis.
  • Retrieving Prototypical Patients: Finds training examples with similar textual patterns to provide intuitive justifications for clinicians—essentially answering, “This patient looks like that patient.”

Model Overview

Prototype-Based Classification

  • The model learns a prototype vector u_c for each diagnosis c.
  • A patient’s admission note is encoded via a PubMedBERT encoder and a linear compression layer into a diagnosis-specific representation v_{p,c}, generated by a label-wise attention mechanism.
  • Classification scores are computed as the negative Euclidean distance between v_{p,c} and u_c, which directly measures the note’s similarity to the learned prototype (see the sketch below).
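
As a concrete sketch of the scoring rule (the tensor shapes and the function name are illustrative assumptions, not the repository's actual API):

import torch

def prototype_scores(v: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Score a batch of notes against each diagnosis prototype.

    v: (batch, num_labels, dim) diagnosis-specific representations v_{p,c}
    u: (num_labels, dim) learned prototype vectors u_c
    Returns (batch, num_labels) logits: negative Euclidean distances, so a
    note close to a prototype receives a high score for that diagnosis.
    """
    return -torch.linalg.vector_norm(v - u.unsqueeze(0), dim=-1)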

Label-Wise Attention

  • For each diagnosis, a separate attention vector identifies relevant tokens in the admission note.
  • This mechanism provides interpretability by indicating which tokens most strongly drive each prediction (see the sketch below).
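
A minimal PyTorch sketch of one way to implement label-wise attention (module and parameter names are illustrative; the repository's implementation may differ):

import torch
import torch.nn as nn

class LabelWiseAttention(nn.Module):
    """One learned attention query per diagnosis label.

    Each label computes its own weighted average over token embeddings,
    so different diagnoses can focus on different parts of the note.
    """
    def __init__(self, hidden_dim: int, num_labels: int):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_labels, hidden_dim))

    def forward(self, hidden: torch.Tensor, mask: torch.Tensor):
        # hidden: (batch, seq_len, hidden_dim); mask: (batch, seq_len), 1 for real tokens
        scores = torch.einsum("bsh,ch->bcs", hidden, self.queries)
        scores = scores.masked_fill(mask.unsqueeze(1) == 0, float("-inf"))
        attn = scores.softmax(dim=-1)                      # (batch, num_labels, seq_len)
        reps = torch.einsum("bcs,bsh->bch", attn, hidden)  # per-label note representations
        return reps, attn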

Interpretable Output

  • Token Highlights: The top attended words (often correlating with symptoms, risk factors, or diagnostic descriptors).
  • Prototypical Patients: Examples from the training set that are closest to each prototype, representing typical presentations of a diagnosis.

Key Features and Benefits

  • Improved Performance on Rare Diagnoses:
    Prototype-based learning has strong few-shot capabilities, which is especially beneficial for diagnoses with very few samples.

  • Faithful Interpretations:
    Quantitative evaluations (see Section 5 in the paper) indicate that the attention-based highlights are more faithful to the model’s decision process than post-hoc methods such as LIME, occlusion, and gradient-based approaches.

  • Clinical Utility:

    • Provides label-wise explanations to help clinicians assess whether the predictions align with actual risk factors.
    • Points to prototypical patients, allowing for comparison of new cases with typical (or atypical) presentations.

Performance Metrics

Evaluated on MIMIC-III:

  • Admission Notes: 48,745
  • Diagnosis Labels: 1,266

Performance (approximate):

  • Macro ROC AUC: ~87–88%
  • Micro ROC AUC: ~97%
  • Macro PR AUC: ~18–21%

The model shows particularly strong gains for rare diagnoses (fewer than 50 training samples) when compared with baselines such as PubMedBERT alone or hierarchical attention RNNs (e.g., HAN, HA-GRU).

Additionally, the model transfers well across clinical environments, maintaining strong performance on out-of-domain i2b2 data (1,118 admission notes).

Refer to Tables 1, 2, and 3 in the paper for detailed results and ablation studies.
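
The ROC/PR AUC figures above can be reproduced with scikit-learn; a sketch, assuming y_true is a binary indicator matrix and y_score holds the model's per-label scores (labels that never occur in y_true are dropped before macro averaging, since AUC is undefined for them):

import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

def evaluate_multilabel(y_true: np.ndarray, y_score: np.ndarray) -> dict:
    # y_true: (N, num_labels) binary indicators; y_score: (N, num_labels) scores
    present = y_true.sum(axis=0) > 0  # keep only labels with at least one positive
    return {
        "macro_roc_auc": roc_auc_score(y_true[:, present], y_score[:, present], average="macro"),
        "micro_roc_auc": roc_auc_score(y_true, y_score, average="micro"),
        "macro_pr_auc":  average_precision_score(y_true[:, present], y_score[:, present], average="macro"),
    }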


Repository Structure

ProtoPatient/
├── proto_model/
│   ├── proto.py
│   ├── utils.py
│   ├── metrics.py
│   └── __init__.py
├── config.json
├── setup.py
├── model.safetensors
├── tokenizer.json
├── tokenizer_config.json
├── vocab.txt
├── README.md
└── .gitattributes

How to Use the Model

1. Install Dependencies

git clone https://huggingface.co/row56/ProtoPatient
cd ProtoPatient
pip install -e . transformers torch safetensors
export TOKENIZERS_PARALLELISM=false

2. Load the Model via Hugging Face

import os
import warnings

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from transformers import logging as hf_logging
hf_logging.set_verbosity_error() 

warnings.filterwarnings("ignore", category=UserWarning)
import torch
from transformers import AutoTokenizer
from proto_model.configuration_proto import ProtoConfig
from proto_model.modeling_proto import ProtoForMultiLabelClassification

cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
cfg.pretrained_model_name_or_path = "bert-base-uncased"
cfg.use_cuda = torch.cuda.is_available()

device = torch.device("cuda" if cfg.use_cuda else "cpu")

tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name_or_path)
model = ProtoForMultiLabelClassification.from_pretrained(
    "row56/ProtoPatient",
    config=cfg,
)
model.to(device)
model.eval()

def get_proto_logits(texts):
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    # Move tensors to the same device as the model; the raw token strings stay on CPU.
    batch = {
        "input_ids":       enc["input_ids"].to(device),
        "attention_masks": enc["attention_mask"].to(device),
        "token_type_ids":  enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])).to(device),
        "tokens":          [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]]
    }
    with torch.no_grad():
        logits, _ = model.proto_module(batch)
    return logits

texts = [
    "Patient shows elevated heart rate and low oxygen saturation.",
    "No significant findings; patient is healthy."
]
logits = get_proto_logits(texts)
print("Logits shape:", logits.shape)
print("Logits:\n", logits)

3. Training Data & Licenses

This model was trained on the MIMIC-III Clinical Database (v1.4), a large de-identified ICU dataset released under a data use agreement.

To obtain MIMIC-III:

1. Visit https://physionet.org/content/mimiciii/1.4/
2. Register for a free PhysioNet account and complete the CITI “Data or Specimens Only Research” training.
3. Sign the MIMIC-III Data Use Agreement (DUA).
4. Download the raw notes and run the preprocessing scripts from the paper’s repository.

Note: We do not redistribute MIMIC-III itself; users must obtain it directly under its license.

4. Load Precomputed Training Data for Prototype Retrieval

After you have MIMIC-III and have applied the published preprocessing, you should produce:

  • data/train_embeds.npy — a NumPy array of shape (N, d) with per-example, per-class embeddings.
  • data/train_texts.json — a JSON array of length N containing the raw admission-note strings.

Place both files in data/ and then:

import numpy as np
import json

train_embeds = np.load("data/train_embeds.npy")
with open("data/train_texts.json", "r") as f:
    train_texts = json.load(f)

print(f"Loaded {train_embeds.shape[0]} embeddings of dim {train_embeds.shape[1]}")

5. Interpreting Outputs & Retrieving Prototypes

from sklearn.neighbors import NearestNeighbors

text = "Patient has chest pain and shortness of breath."
enc = tokenizer([text], padding=True, truncation=True, return_tensors="pt")
# As above, move tensors to the model's device; the token strings stay on CPU.
batch = {
    "input_ids":       enc["input_ids"].to(device),
    "attention_masks": enc["attention_mask"].to(device),
    "token_type_ids":  enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])).to(device),
    "tokens":          [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]]
}

with torch.no_grad():
    logits, metadata = model.proto_module(batch)

attn_scores = metadata["attentions"][0]  # label-wise attention weights for the single note
for label_id, scores in enumerate(attn_scores):
    # Five highest-attended tokens for this diagnosis label
    topk = sorted(zip(batch["tokens"][0], scores.tolist()),
                  key=lambda x: -x[1])[:5]
    print(f"Label {label_id} top tokens:", topk)

# Learned prototype vectors u_c, one per diagnosis label
proto_vecs = model.proto_module.prototype_vectors.detach().cpu().numpy()
nn = NearestNeighbors(n_neighbors=1, metric="euclidean").fit(train_embeds)

for label_id, u_c in enumerate(proto_vecs):
    # The training note closest to each prototype is its "prototypical patient"
    dist, idx = nn.kneighbors(u_c.reshape(1, -1))
    print(f"\nLabel {label_id} prototype (distance={dist[0][0]:.3f}):")
    print(train_texts[idx[0][0]])
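
Since a single diagnosis can have several typical presentations, the same index can return more than one neighbor per prototype (a small variation on the loop above; the choice of three neighbors is arbitrary):

dists, idxs = nn.kneighbors(proto_vecs, n_neighbors=3)  # 3 closest training notes per prototype
for label_id in range(proto_vecs.shape[0]):
    print(f"\nLabel {label_id} nearest prototypical patients:")
    for dist, idx in zip(dists[label_id], idxs[label_id]):
        print(f"  (distance={dist:.3f}) {train_texts[idx][:120]}...")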

Intended Use, Limitations & Ethical Considerations

Intended Use

  • Research & Education:
    ProtoPatient is designed primarily for academic research and educational purposes in clinical NLP.

  • Interpretability Demonstration:
    The model demonstrates how prototype-based methods can provide interpretable multi-label classification on clinical admission notes.


Limitations

  • Generalization:
    The model was trained on public ICU datasets (MIMIC-III, i2b2) and may not generalize to other patient populations.

  • Prototype Scope:
    The current version uses a single prototype per diagnosis, though some diagnoses might have multiple typical presentations—this is an area for future improvement.

  • Inter-diagnosis Relationships:
    The model does not explicitly model relationships (e.g., conflicts or comorbidities) between different diagnoses.


Ethical & Regulatory Considerations

  • Not for Direct Clinical Use:
    This model is not intended for direct clinical decision-making. Always consult healthcare professionals.

  • Bias and Fairness:
    Users should be aware of potential biases in the training data; rare conditions might still be misclassified.

  • Patient Privacy:
    When applying the model to real clinical data, patient privacy must be strictly maintained.


Example Interpretability Output

Based on the approach described in the paper (see Section 5 and Table 5):

  • Highlighted Tokens:
    Tokens such as “worst headache of her life,” “vomiting,” “fever,” and “infiltrate” strongly indicate specific diagnoses.

  • Prototypical Sample:
    A snippet from a training patient with similar text segments provides a rationale for the prediction.

This interpretability output helps clinicians understand the model's reasoning, for example: "The system suggests intracerebral hemorrhage because the patient's note closely resembles typical cases with that diagnosis."


Recommended Citation

If you use ProtoPatient in your research, please cite:

@misc{vanaken2022this,
  title={This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text},
  author={van Aken, Betty and Papaioannou, Jens-Michalis and Naik, Marcel G. and Eleftheriadis, Georgios and Nejdl, Wolfgang and Gers, Felix A. and L{\"o}ser, Alexander},
  year={2022},
  eprint={2210.08500},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
