---
tags:
- text-classification
- medical
- prototypical-networks
- transformers
library_name: transformers
language: en
license: mit
datasets:
- your_dataset_name_here
model-index:
- name: ProtoPatient
  results:
  - task:
      type: multi-label-classification
    dataset:
      name: your_dataset_name_here
      type: text
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.XX # Update with real value
    - name: F1-score
      type: f1
      value: 0.XX # Update with real value
---

# ProtoPatient Model for Multi-Label Classification

## Paper Reference

**van Aken, Betty, Jens-Michalis Papaioannou, Marcel G. Naik, Georgios Eleftheriadis, Wolfgang Nejdl, Felix A. Gers, and Alexander Löser. 2022.**
*This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text.*
[arXiv:2210.08500](https://arxiv.org/abs/2210.08500)

ProtoPatient is a transformer-based architecture that uses prototypical networks and label-wise attention to perform multi-label classification on clinical admission notes. Unlike standard black-box models, ProtoPatient offers inherent interpretability by:

- **Highlighting Relevant Tokens:** Shows the most important words for each possible diagnosis.
- **Retrieving Prototypical Patients:** Finds training examples with similar textual patterns to provide intuitive justifications for clinicians—essentially answering, “This patient looks like that patient.”

---

## Model Overview

### Prototype-Based Classification

- The model learns a **prototypical vector** \(u_c\) for each diagnosis \(c\).
- A patient’s admission note is encoded by a PubMedBERT encoder and a linear compression layer into a diagnosis-specific representation \(v_{p,c}\), produced by a label-wise attention mechanism.
- The classification score is the **negative Euclidean distance** between \(v_{p,c}\) and \(u_c\), which directly measures the note’s similarity to the learned prototype (see the sketch after the feature list below).

### Label-Wise Attention

- For each diagnosis, a separate attention vector identifies relevant tokens in the admission note.
- This mechanism provides interpretability by indicating which tokens are most influential in driving each prediction.

### Interpretable Output

- **Token Highlights:** The top attended words (often correlating with symptoms, risk factors, or diagnostic descriptors).
- **Prototypical Patients:** Examples from the training set that are closest to each prototype, representing typical presentations of a diagnosis.

---

## Key Features and Benefits

- **Improved Performance on Rare Diagnoses:** Prototype-based learning has strong few-shot capabilities, which is especially beneficial for diagnoses with very few training samples.
- **Faithful Interpretations:** Quantitative evaluations (see Section 5 in the paper) indicate that the attention-based highlights are more faithful to the model’s decision process than post-hoc methods such as LIME, occlusion, and gradient-based approaches.
- **Clinical Utility:**
  - Provides label-wise explanations that help clinicians assess whether predictions align with actual risk factors.
  - Points to prototypical patients, allowing new cases to be compared with typical (or atypical) presentations.
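The scoring step referenced above can be illustrated with a minimal PyTorch sketch. This is not the repository's implementation; the tensor names (`token_embeds`, `attention_vectors`, `prototypes`), the single shared compression layer, and the dimensions are illustrative assumptions based on the description in the Model Overview.

```python
# Minimal sketch of label-wise attention + prototype scoring (illustrative, not the repo code).
# Assumptions: token_embeds come from a PubMedBERT-style encoder, there is one attention vector
# and one prototype per diagnosis, and a single linear compression layer is shared across labels.
import torch
import torch.nn.functional as F

batch, seq_len, hidden, n_labels, proto_dim = 2, 128, 768, 1266, 256

token_embeds = torch.randn(batch, seq_len, hidden)            # encoder output for the note
attention_vectors = torch.randn(n_labels, hidden)             # one attention vector per diagnosis
compress = torch.nn.Linear(hidden, proto_dim)                 # linear compression layer
prototypes = torch.randn(n_labels, proto_dim)                 # learned prototype u_c per diagnosis

# Label-wise attention: score every token against every diagnosis, softmax over tokens.
attn_logits = torch.einsum("bsh,ch->bcs", token_embeds, attention_vectors)
attn_weights = F.softmax(attn_logits, dim=-1)                 # (batch, n_labels, seq_len)

# Diagnosis-specific note representation v_{p,c}: attention-weighted sum of compressed tokens.
compressed = compress(token_embeds)                           # (batch, seq_len, proto_dim)
v = torch.einsum("bcs,bsd->bcd", attn_weights, compressed)    # (batch, n_labels, proto_dim)

# Classification score: negative Euclidean distance to each diagnosis prototype.
scores = -torch.norm(v - prototypes.unsqueeze(0), dim=-1)     # (batch, n_labels)
print(scores.shape)  # torch.Size([2, 1266])
```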
---

## Performance Metrics

Evaluated on **MIMIC-III**:

- **Admission Notes:** 48,745
- **Diagnosis Labels:** 1,266

Performance (approximate):

- **Macro ROC AUC:** ~87–88%
- **Micro ROC AUC:** ~97%
- **Macro PR AUC:** ~18–21%

The model shows particularly strong gains for rare diagnoses (fewer than 50 training samples) compared with baselines such as PubMedBERT alone or hierarchical attention RNNs (e.g., HAN, HA-GRU). It also transfers well to **i2b2** data (1,118 admission notes), indicating robustness across different clinical environments.

*Refer to Tables 1, 2, and 3 in the paper for detailed results and ablation studies.*

---

## Repository Structure

```plaintext
ProtoPatient/
├── proto_model/
│   ├── proto.py
│   ├── utils.py
│   ├── metrics.py
│   └── __init__.py
├── config.json
├── setup.py
├── model.safetensors
├── tokenizer.json
├── tokenizer_config.json
├── vocab.txt
├── README.md
└── .gitattributes
```

---

## How to Use the Model

### 1. Install Dependencies

```bash
git clone https://huggingface.co/row56/ProtoPatient
cd ProtoPatient
pip install -e . transformers torch safetensors
export TOKENIZERS_PARALLELISM=false
```

### 2. Load the Model via Hugging Face

```python
import os
import warnings

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from transformers import logging as hf_logging

hf_logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=UserWarning)

import torch
from transformers import AutoTokenizer

from proto_model.configuration_proto import ProtoConfig
from proto_model.modeling_proto import ProtoForMultiLabelClassification

cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
cfg.pretrained_model_name_or_path = "bert-base-uncased"
cfg.use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if cfg.use_cuda else "cpu")

tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name_or_path)
model = ProtoForMultiLabelClassification.from_pretrained(
    "row56/ProtoPatient",
    config=cfg,
)
model.to(device)
model.eval()


def get_proto_logits(texts):
    """Tokenize a list of note strings and return the raw per-label logits."""
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    batch = {
        # Move the tensors to the same device as the model.
        "input_ids": enc["input_ids"].to(device),
        "attention_masks": enc["attention_mask"].to(device),
        "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])).to(device),
        "tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]],
    }
    with torch.no_grad():
        logits, _ = model.proto_module(batch)
    return logits


texts = [
    "Patient shows elevated heart rate and low oxygen saturation.",
    "No significant findings; patient is healthy.",
]
logits = get_proto_logits(texts)
print("Logits shape:", logits.shape)
print("Logits:\n", logits)
```

### 3. Training Data & Licenses

This model was trained on the MIMIC-III Clinical Database (v1.4), a large de-identified ICU dataset released under a data use agreement. To obtain MIMIC-III:

1. Visit https://physionet.org/content/mimiciii/1.4/
2. Register for a free PhysioNet account and complete the CITI “Data or Specimens Only Research” training.
3. Sign the MIMIC-III Data Use Agreement (DUA).
4. Download the raw notes and run the preprocessing scripts from the paper’s repository.

Note: We do not redistribute MIMIC-III itself; users must obtain it directly under its license.

### 4. Load Precomputed Training Data for Prototype Retrieval

After you have MIMIC-III and have applied the published preprocessing, you should produce the following files (a rough sketch of one way to generate them appears below):

- `data/train_embeds.npy` — NumPy array of shape `(N, d)` with per-example, per-class embeddings.
- `data/train_texts.json` — JSON array of length `N` containing the raw admission-note strings.
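The published preprocessing scripts are the reference way to create these files; what follows is only a hedged sketch of how they might be assembled, reusing the `tokenizer`, `model`, and `device` from step 2. The metadata key `"embeddings"` and the mean-pooling over the class dimension are assumptions, not confirmed by the repository; check `proto_model/proto.py` for the actual outputs of `model.proto_module`.

```python
# Illustrative sketch only -- prefer the paper's preprocessing scripts.
# ASSUMPTION: metadata returned by model.proto_module exposes the diagnosis-specific
# representations under a key such as "embeddings"; verify the real key in proto_model/proto.py.
import json

import numpy as np
import torch

train_notes = ["<admission note 1>", "<admission note 2>"]  # replace with preprocessed MIMIC-III notes

embeds = []
for note in train_notes:
    enc = tokenizer([note], padding=True, truncation=True, return_tensors="pt")
    batch = {
        "input_ids": enc["input_ids"].to(device),
        "attention_masks": enc["attention_mask"].to(device),
        "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])).to(device),
        "tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]],
    }
    with torch.no_grad():
        _, metadata = model.proto_module(batch)
    # Hypothetical key: one (n_labels, d) matrix per note; mean-pool it to a single (d,) vector.
    per_class = metadata["embeddings"][0]
    embeds.append(per_class.mean(dim=0).cpu().numpy())

np.save("data/train_embeds.npy", np.stack(embeds))
with open("data/train_texts.json", "w") as f:
    json.dump(train_notes, f)
```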
Place those files in `data/` and then:

```python
import json

import numpy as np

train_embeds = np.load("data/train_embeds.npy")
with open("data/train_texts.json", "r") as f:
    train_texts = json.load(f)

print(f"Loaded {train_embeds.shape[0]} embeddings of dim {train_embeds.shape[1]}")
```

### 5. Interpreting Outputs & Retrieving Prototypes

```python
from sklearn.neighbors import NearestNeighbors

text = "Patient has chest pain and shortness of breath."
enc = tokenizer([text], padding=True, truncation=True, return_tensors="pt")
batch = {
    "input_ids": enc["input_ids"].to(device),
    "attention_masks": enc["attention_mask"].to(device),
    "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])).to(device),
    "tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]],
}

with torch.no_grad():
    logits, metadata = model.proto_module(batch)

# Label-wise attention weights for the first (and only) example in the batch.
attn_scores = metadata["attentions"][0]
for label_id, scores in enumerate(attn_scores):
    # Top-5 attended tokens for this diagnosis.
    topk = sorted(zip(batch["tokens"][0], scores.tolist()), key=lambda x: -x[1])[:5]
    print(f"Label {label_id} top tokens:", topk)

# Retrieve the training example closest to each learned prototype vector u_c.
proto_vecs = model.proto_module.prototype_vectors.detach().cpu().numpy()
nn = NearestNeighbors(n_neighbors=1, metric="euclidean").fit(train_embeds)
for label_id, u_c in enumerate(proto_vecs):
    dist, idx = nn.kneighbors(u_c.reshape(1, -1))
    print(f"\nLabel {label_id} prototype (distance={dist[0][0]:.3f}):")
    print(train_texts[idx[0][0]])
```

---

# Intended Use, Limitations & Ethical Considerations

## Intended Use

- **Research & Education:** ProtoPatient is designed primarily for academic research and educational purposes in clinical NLP.
- **Interpretability Demonstration:** The model demonstrates how prototype-based methods can provide interpretable multi-label classification on clinical admission notes.

---

## Limitations

- **Generalization:** The model was trained on public ICU datasets (MIMIC-III, i2b2) and may not generalize to other patient populations.
- **Prototype Scope:** The current version uses a single prototype per diagnosis, although some diagnoses have multiple typical presentations; this is an area for future improvement.
- **Inter-diagnosis Relationships:** The model does not explicitly model relationships (e.g., conflicts or comorbidities) between different diagnoses.

---

## Ethical & Regulatory Considerations

- **Not for Direct Clinical Use:** This model is not intended for direct clinical decision-making. Always consult healthcare professionals.
- **Bias and Fairness:** Users should be aware of potential biases in the training data; rare conditions may still be misclassified.
- **Patient Privacy:** When applying the model to real clinical data, patient privacy must be strictly maintained.

---

# Example Interpretability Output

Based on the approach described in the paper (see Section 5 and Table 5):

- **Highlighted Tokens:** Tokens such as “worst headache of her life,” “vomiting,” “fever,” and “infiltrate” strongly indicate specific diagnoses.
- **Prototypical Sample:** A snippet from a training patient with similar text segments provides a rationale for the prediction.
*This interpretability output aids clinicians in understanding the model's reasoning, for example: "The system suggests intracerebral hemorrhage because the patient's note closely resembles typical cases with that diagnosis."*

---

# Recommended Citation

If you use ProtoPatient in your research, please cite:

```bibtex
@misc{vanaken2022this,
      title={This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text},
      author={van Aken, Betty and Papaioannou, Jens-Michalis and Naik, Marcel G. and Eleftheriadis, Georgios and Nejdl, Wolfgang and Gers, Felix A. and L{\"o}ser, Alexander},
      year={2022},
      eprint={2210.08500},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```