ProtoPatient Model for Multi-Label Classification
Paper Reference
van Aken, Betty, Jens-Michalis Papaioannou, Marcel G. Naik, Georgios Eleftheriadis, Wolfgang Nejdl, Felix A. Gers, and Alexander Löser. 2022.
This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text.
arXiv:2210.08500
ProtoPatient is a transformer-based architecture that uses prototypical networks and label-wise attention to perform multi-label diagnosis classification on clinical admission notes. Unlike standard black-box models, ProtoPatient offers inherent interpretability by:
- Highlighting Relevant Tokens: Shows the most important words for each possible diagnosis.
- Retrieving Prototypical Patients: Finds training examples with similar textual patterns to provide intuitive justifications for clinicians—essentially answering, “This patient looks like that patient.”
Model Overview
Prototype-Based Classification
- The model learns a prototype vector $u_c$ for each diagnosis $c$.
- A patient's admission note $p$ is encoded with a PubMedBERT encoder and a linear compression layer into a diagnosis-specific representation $v_{p,c}$. This representation is produced by a label-wise attention mechanism.
- Classification scores are computed as the negative Euclidean distance between $v_{p,c}$ and $u_c$, which directly measures how similar the note is to the learned prototype.
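The scoring rule above, $s_{p,c} = -\lVert v_{p,c} - u_c \rVert_2$, can be sketched in a few lines of NumPy. This is a minimal illustration, not the repository's implementation: the function name `prototype_scores` and the array layout (patients × diagnoses × dimensions) are assumptions.

```python
import numpy as np

def prototype_scores(v, u):
    """Negative Euclidean distance between each v_{p,c} and its prototype u_c.

    v: array of shape (P, C, d), diagnosis-specific patient representations.
    u: array of shape (C, d), one prototype vector per diagnosis.
    Returns an array of shape (P, C); a higher (less negative) score means
    the note is closer to that diagnosis' prototype.
    """
    # Broadcast the prototypes over the patient axis, then take the L2 norm over d
    return -np.linalg.norm(v - u[None, :, :], axis=-1)

# Toy example: 1 patient, 2 diagnoses, 2-dimensional space
u = np.array([[0.0, 0.0], [3.0, 4.0]])
v = np.array([[[1.0, 0.0], [3.0, 0.0]]])
print(prototype_scores(v, u))
```

Because each diagnosis is compared only with its own prototype, the representation for diagnosis $c$ never has to compete with prototypes of other diagnoses, which fits the independent multi-label setting.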
Label-Wise Attention
- For each diagnosis, a separate attention vector identifies relevant tokens in the admission note.
- This mechanism provides interpretability by indicating which tokens are most influential in driving each prediction.
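A minimal sketch of label-wise attention pooling, assuming token embeddings `H` from the encoder and one learned query vector per diagnosis in `W` (both names are hypothetical). Each diagnosis gets its own softmax distribution over the tokens, and the weighted sum of token embeddings becomes that diagnosis' note representation. The linear compression layer mentioned above is omitted, and the exact parameterization in ProtoPatient may differ.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def label_wise_attention(H, W, mask=None):
    """Label-wise attention pooling (sketch).

    H: (T, d) contextual token embeddings of one note.
    W: (C, d) one attention (query) vector per diagnosis (hypothetical parameter).
    mask: optional (T,) boolean array, False for padding tokens.
    Returns (A, V): attention weights (C, T) and pooled representations (C, d).
    """
    logits = W @ H.T  # (C, T): relevance of every token for every diagnosis
    if mask is not None:
        # Padding tokens get -inf so they receive zero attention weight
        logits = np.where(mask[None, :], logits, -np.inf)
    A = softmax(logits, axis=-1)  # each row sums to 1 over the tokens
    V = A @ H  # (C, d): one note representation per diagnosis
    return A, V

rng = np.random.default_rng(0)
H = rng.normal(size=(6, 4))  # 6 tokens, 4-dimensional embeddings
W = rng.normal(size=(3, 4))  # 3 diagnoses
A, V = label_wise_attention(H, W)
print(A.shape, V.shape)
```

The rows of `A` are what token highlights are read from: sorting a row gives the most influential tokens for that diagnosis.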
Interpretable Output
- Token Highlights: The top attended words (often correlating with symptoms, risk factors, or diagnostic descriptors).
- Prototypical Patients: Examples from the training set that are closest to each prototype, representing typical presentations of a diagnosis.
Key Features and Benefits
Improved Performance on Rare Diagnoses:
Prototype-based learning has strong few-shot capabilities, which is especially beneficial for diagnoses with very few training samples.
Faithful Interpretations:
Quantitative evaluations (see Section 5 in the paper) indicate that the attention-based highlights are more faithful to the model's decision process than post-hoc methods such as LIME, Occlusion, and gradient-based approaches.
Clinical Utility:
- Provides label-wise explanations to help clinicians assess whether the predictions align with actual risk factors.
- Points to prototypical patients, allowing for comparison of new cases with typical (or atypical) presentations.
Performance Metrics
Evaluated on MIMIC-III:
- Admission Notes: 48,745
- Diagnosis Labels: 1,266
Performance (approximate):
- Macro ROC AUC: ~87–88%
- Micro ROC AUC: ~97%
- Macro PR AUC: ~18–21%
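The card does not include evaluation code. As a hedged sketch, assuming a binary label matrix and real-valued scores for the test set, the three aggregate metrics listed above can be computed with scikit-learn. Here average precision serves as the PR AUC estimate, and labels that do not have both classes in the test set are skipped because ROC AUC is undefined for them; both choices are assumptions and may differ from the paper's exact protocol.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

def multilabel_aucs(y_true, y_score):
    """Macro/micro ROC AUC and macro PR AUC for multi-label predictions.

    y_true: (N, C) binary indicator matrix; y_score: (N, C) real-valued scores.
    Labels with only positives or only negatives in y_true are dropped.
    """
    pos = y_true.sum(axis=0)
    keep = (pos > 0) & (pos < y_true.shape[0])
    yt, ys = y_true[:, keep], y_score[:, keep]
    return {
        "macro_roc_auc": roc_auc_score(yt, ys, average="macro"),
        "micro_roc_auc": roc_auc_score(yt, ys, average="micro"),
        # Average precision: a standard step-wise estimate of PR AUC
        "macro_pr_auc": average_precision_score(yt, ys, average="macro"),
    }
```

Macro averaging weights every diagnosis equally, so the many rare labels dominate it; this is one reason macro PR AUC is much lower than micro ROC AUC on a 1,266-label task.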
The model shows particularly strong gains on rare diagnoses (fewer than 50 samples) compared with baselines such as PubMedBERT alone or hierarchical attention RNNs (e.g., HAN, HA-GRU).
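A rare-versus-frequent comparison like this one can be sketched by bucketing labels by their number of positive training samples and computing macro ROC AUC per bucket. This is a minimal illustration, assuming binary label matrices for the training and test sets and the 50-sample cut-off mentioned above; the function name and bucketing are illustrative, not the paper's evaluation code.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def macro_auc_by_frequency(y_train, y_true, y_score, threshold=50):
    """Macro ROC AUC for rare vs. frequent labels (sketch).

    y_train: (N_train, C) binary training labels, used only to count positives.
    y_true, y_score: (N, C) test labels and scores.
    A label is "rare" if it has fewer than `threshold` positive training samples.
    """
    train_counts = y_train.sum(axis=0)
    pos = y_true.sum(axis=0)
    # ROC AUC needs both classes present among the test samples
    valid = (pos > 0) & (pos < y_true.shape[0])
    buckets = {"rare": train_counts < threshold, "frequent": train_counts >= threshold}
    results = {}
    for name, in_bucket in buckets.items():
        cols = in_bucket & valid
        if cols.any():
            results[name] = roc_auc_score(y_true[:, cols], y_score[:, cols], average="macro")
    return results
```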
Additionally, the model transfers well to i2b2 data (1,118 admission notes), which come from a different clinical environment than the training data.
Refer to Tables 1, 2, and 3 in the paper for detailed results and ablation studies.
Repository Structure
ProtoPatient/
├── proto_model/
│ ├── proto.py
│ ├── utils.py
│ ├── metrics.py
│ └── __init__.py
├── config.json
├── setup.py
├── model.safetensors
├── tokenizer.json
├── tokenizer_config.json
├── vocab.txt
├── README.md
└── .gitattributes
How to Use the Model
1. Install Dependencies
git clone https://huggingface.co/row56/ProtoPatient
cd ProtoPatient
pip install -e . transformers torch safetensors numpy scikit-learn
export TOKENIZERS_PARALLELISM=false
2. Load the Model via Hugging Face
import os
import warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from transformers import logging as hf_logging
hf_logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=UserWarning)
import torch
from transformers import AutoTokenizer
from proto_model.configuration_proto import ProtoConfig
from proto_model.modeling_proto import ProtoForMultiLabelClassification
cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
cfg.pretrained_model_name_or_path = "bert-base-uncased"
cfg.use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if cfg.use_cuda else "cpu")
tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name_or_path)
model = ProtoForMultiLabelClassification.from_pretrained(
    "row56/ProtoPatient",
    config=cfg,
)
model.to(device)
model.eval()
def get_proto_logits(texts):
    # Tokenize a batch of notes, padding/truncating to a common length
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    # Batch dict in the format passed to model.proto_module (note the plural "attention_masks" key)
    batch = {
        "input_ids": enc["input_ids"],
        "attention_masks": enc["attention_mask"],
        "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])),
        "tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]],
    }
    # Inference only: no gradients needed
    with torch.no_grad():
        logits, _ = model.proto_module(batch)
    return logits
texts = [
"Patient shows elevated heart rate and low oxygen saturation.",
"No significant findings; patient is healthy."
]
logits = get_proto_logits(texts)
print("Logits shape:", logits.shape)
print("Logits:\n", logits)
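With 1,266 labels, the raw logits are hard to read. A small helper can list the highest-scoring diagnoses per note, assuming higher logits mean a more likely diagnosis (as with the negative prototype distances described in Model Overview). The helper name `top_k_labels` is hypothetical. If `cfg.id2label` is populated in the checkpoint's config, it may map the returned ids to diagnosis codes; check its contents before relying on it.

```python
import torch

def top_k_labels(logits, k=5):
    """Return (label_ids, scores) of the k highest-scoring labels per note.

    logits: (batch, num_labels) tensor, e.g. the output of get_proto_logits.
    Assumes higher values indicate a more likely label.
    """
    k = min(k, logits.shape[-1])  # guard against k larger than the label count
    scores, ids = torch.topk(logits, k=k, dim=-1)
    return ids.tolist(), scores.tolist()
```

For example, `ids, scores = top_k_labels(logits, k=5)` returns one list of five label ids per input text, sorted from highest to lowest score.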
3. Training Data & Licenses
This model was trained on the MIMIC-III Clinical Database (v1.4), a large de-identified ICU dataset released under a data use agreement.
To obtain MIMIC-III:
- Visit https://physionet.org/content/mimiciii/1.4/
- Register for a free PhysioNet account and complete the CITI “Data or Specimens Only Research” training.
- Sign the MIMIC-III Data Use Agreement (DUA).
- Download the raw notes and run the preprocessing scripts from the paper's repository.
Note: We do not redistribute MIMIC-III itself; users must obtain it directly under its license.
4. Load Precomputed Training Data for Prototype Retrieval
After you have MIMIC-III and have applied the published preprocessing, you should produce:
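- data/train_embeds.npy: NumPy array of shape (N, d) with one embedding per training example, in the same d-dimensional space as the prototype vectors (the retrieval code in step 5 compares the two directly).
- data/train_texts.json: JSON array of length N with the raw admission-note strings, in the same order as the rows of the embedding array.
Place those in data/ and then: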
import numpy as np
import json
train_embeds = np.load("data/train_embeds.npy")
with open("data/train_texts.json", "r") as f:
train_texts = json.load(f)
print(f"Loaded {train_embeds.shape[0]} embeddings of dim {train_embeds.shape[1]}")
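If you still need to write the two files, a minimal sketch that saves them in the layout described above and checks that embeddings and texts line up. The helper name `save_retrieval_data` is hypothetical, and computing the embeddings themselves is out of scope here: this assumes you already produced them with the published preprocessing and the model.

```python
import json
from pathlib import Path

import numpy as np

def save_retrieval_data(embeds, texts, out_dir="data"):
    """Write train_embeds.npy and train_texts.json for prototype retrieval.

    embeds: array-like of shape (N, d); texts: list of N admission-note strings,
    in the same order as the embedding rows.
    """
    embeds = np.asarray(embeds, dtype=np.float32)
    # Row i of the embeddings must belong to texts[i]
    if embeds.ndim != 2 or embeds.shape[0] != len(texts):
        raise ValueError(f"expected ({len(texts)}, d) embeddings, got {embeds.shape}")
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    np.save(out / "train_embeds.npy", embeds)
    with open(out / "train_texts.json", "w", encoding="utf-8") as f:
        json.dump(list(texts), f)
    return out
```

Before retrieval, `train_embeds.shape[1]` should also equal the dimension of the prototype vectors; otherwise the nearest-neighbor query in step 5 fails.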
5. Interpreting Outputs & Retrieving Prototypes
from sklearn.neighbors import NearestNeighbors
text = "Patient has chest pain and shortness of breath."
enc = tokenizer([text], padding=True, truncation=True, return_tensors="pt")
batch = {
    "input_ids": enc["input_ids"],
    "attention_masks": enc["attention_mask"],
    "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])),
    "tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]],
}
with torch.no_grad():
    logits, metadata = model.proto_module(batch)
# Label-wise attention for the first (only) note: one score vector per label
attn_scores = metadata["attentions"][0]
for label_id, scores in enumerate(attn_scores):
    # Five tokens with the highest attention weight for this label
    topk = sorted(zip(batch["tokens"][0], scores.tolist()),
                  key=lambda x: -x[1])[:5]
    print(f"Label {label_id} top tokens:", topk)
# detach() is required before .numpy() if the prototypes are trainable parameters
proto_vecs = model.proto_module.prototype_vectors.detach().cpu().numpy()
knn = NearestNeighbors(n_neighbors=1, metric="euclidean").fit(train_embeds)
for label_id, u_c in enumerate(proto_vecs):
    # Closest training note to the prototype u_c
    dist, idx = knn.kneighbors(u_c.reshape(1, -1))
    print(f"\nLabel {label_id} prototype (distance={dist[0][0]:.3f}):")
    print(train_texts[idx[0][0]])
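The retrieval above compares every prototype $u_c$ with a single embedding per training note. Since the model scores note $p$ for diagnosis $c$ by comparing $v_{p,c}$ with $u_c$, a per-label search may match that setup more closely if you store the diagnosis-specific representations instead, as an array of shape (N, C, d). This layout and the helper `nearest_per_label` are assumptions, not files this repository provides; note that (N, C, d) for all 1,266 labels is very large, so in practice you would restrict it to the labels you want to inspect.

```python
import numpy as np

def nearest_per_label(train_embeds, proto_vecs, k=1):
    """k nearest training notes for each prototype, using per-label embeddings.

    train_embeds: (N, C, d) array holding v_{p,c} for every training note p
                  and diagnosis c (assumed layout).
    proto_vecs:   (C, d) prototype vectors u_c.
    Returns (idx, dist), both (C, k): note indices and Euclidean distances,
    ordered from closest to farthest.
    """
    # (N, C): distance between each note's v_{p,c} and the matching u_c
    dists = np.linalg.norm(train_embeds - proto_vecs[None, :, :], axis=-1)
    order = np.argsort(dists, axis=0)[:k].T  # (C, k) note indices per label
    return order, np.take_along_axis(dists.T, order, axis=1)
```

The returned indices can be used directly with `train_texts`, e.g. `train_texts[idx[c][0]]` for the most prototypical note of label `c`.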
Intended Use, Limitations & Ethical Considerations
Intended Use
Research & Education:
ProtoPatient is designed primarily for academic research and educational purposes in clinical NLP.
Interpretability Demonstration:
The model demonstrates how prototype-based methods can provide interpretable multi-label classification on clinical admission notes.
Limitations
Generalization:
The model was trained on MIMIC-III ICU notes (with transfer evaluated on i2b2) and may not generalize to other patient populations.
Prototype Scope:
The current version uses a single prototype per diagnosis, although some diagnoses may have several typical presentations; supporting this is an area for future improvement.
Inter-diagnosis Relationships:
The model does not explicitly model relationships (e.g., conflicts or comorbidities) between different diagnoses.
Ethical & Regulatory Considerations
Not for Direct Clinical Use:
This model is not intended for direct clinical decision-making. Always consult healthcare professionals.
Bias and Fairness:
Users should be aware of potential biases in the training data; rare conditions might still be misclassified.
Patient Privacy:
When applying the model to real clinical data, patient privacy must be strictly maintained.
Example Interpretability Output
Based on the approach described in the paper (see Section 5 and Table 5):
Highlighted Tokens:
Tokens such as “worst headache of her life,” “vomiting,” “fever,” and “infiltrate” strongly indicate specific diagnoses.
Prototypical Sample:
A snippet from a training patient with similar text segments provides a rationale for the prediction.
This interpretability output aids clinicians in understanding the model's reasoning – for example: "The system suggests intracerebral hemorrhage because the patient's note closely resembles typical cases with that diagnosis."
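Highlights like “worst headache of her life” are whole words and phrases, while the attention scores printed in step 5 are per WordPiece sub-token (continuations are prefixed with `##` by BERT-style tokenizers such as the one loaded in step 2). A minimal sketch for merging sub-tokens back into words before display; the helper name and the choice of `max` to combine sub-token scores are assumptions (sum or mean are equally plausible).

```python
def merge_wordpieces(tokens, scores, reduce=max):
    """Merge BERT WordPiece sub-tokens ("##"-prefixed) into whole words.

    tokens: list of token strings, e.g. from tokenizer.convert_ids_to_tokens.
    scores: attention weights aligned with tokens.
    reduce: how to combine sub-token scores into one word score.
    Special tokens [CLS], [SEP] and [PAD] are skipped.
    """
    words, word_scores = [], []
    for tok, s in zip(tokens, scores):
        if tok in ("[CLS]", "[SEP]", "[PAD]"):
            continue
        if tok.startswith("##") and words:
            # Continuation piece: append to the previous word
            words[-1] += tok[2:]
            word_scores[-1].append(s)
        else:
            words.append(tok)
            word_scores.append([s])
    return [(w, reduce(ws)) for w, ws in zip(words, word_scores)]
```

For one label, `merge_wordpieces(batch["tokens"][0], scores.tolist())` yields word-level scores that can then be sorted exactly like the token-level ones.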
Recommended Citation
If you use ProtoPatient in your research, please cite:
@misc{vanaken2022this,
title={This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text},
author={van Aken, Betty and Papaioannou, Jens-Michalis and Naik, Marcel G. and Eleftheriadis, Georgios and Nejdl, Wolfgang and Gers, Felix A. and L{\"o}ser, Alexander},
year={2022},
eprint={2210.08500},
archivePrefix={arXiv},
primaryClass={cs.CL}
}