---
language:
  - en
license: apache-2.0
tags:
  - clinical
  - biomedical
  - transformer
  - bert
  - fine-tuning
  - neurology
  - autoimmune
  - medical-nlp
datasets:
  - internal-ms-autoimmune
metrics:
  - accuracy
  - precision
  - recall
  - f1
  - auc
model-index:
  - name: clinicalbert-ms-autoimmune-neuro
    results:
      - task:
          name: Binary Text Classification (Autoimmune Neurology)
          type: text-classification
        dataset:
          name: Internal MS-Autoimmune Corpus
          type: custom
          split: test
          size: 493
        metrics:
          - name: Precision
            type: precision
            value: 0.9618
          - name: Recall
            type: recall
            value: 0.9207
          - name: F1
            type: f1
            value: 0.9408
          - name: ROC-AUC
            type: auc
            value: 0.9786
          - name: Accuracy
            type: accuracy
            value: 0.9615
base_model: emilyalsentzer/Bio_ClinicalBERT
library_name: transformers
pipeline_tag: text-classification
inference: false
---

# 🧠 ClinicalBERT-MS-Autoimmune-Neuro

A fine-tuned version of [emilyalsentzer/Bio_ClinicalBERT](https://huggingface.co/emilyalsentzer/Bio_ClinicalBERT) for detecting autoimmune neurological disease signals in clinical text notes.

- **Maintainer:** Vahid Mahmoudian
- **Repository:** vhdm/clinicalbert-ms-autoimmune-neuro
- **Status:** Research / proof-of-concept; not for standalone clinical use


## 🔍 Model Summary

| Property | Value |
|---|---|
| Base model | emilyalsentzer/Bio_ClinicalBERT |
| Task | Binary text classification (autoimmune-neurological vs. non-autoimmune) |
| Language | English clinical notes |
| Domain | Neurology / autoimmune disorders |
| Dataset | Internal “MS-autoimmune” corpus (split into train, valid, test) |
| Sequence length | 512 tokens per chunk |
| Hardware | NVIDIA H100 |
| Mixed precision | bf16 |
| Trainer seed | 42 |
| Epochs | 4 |
| Learning rate | 2 × 10⁻⁵ |
| Optimizer | adamw_torch |
| Batch sizes | train = 24, eval = 48 |
| Warmup ratio | 0.1 |
| Best metric | recall |
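
The model scores 512-token chunks, so longer notes must be split before inference. The original chunking code is not included here; a minimal sketch using the tokenizer's overflowing-tokens feature is shown below (the 64-token stride is an illustrative assumption).

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("vhdm/clinicalbert-ms-autoimmune-neuro")

def chunk_note(note_text, max_length=512, stride=64):
    """Split one clinical note into overlapping 512-token chunks.

    The 64-token stride is an assumption; the original preprocessing may have
    used a different overlap or none at all.
    """
    enc = tok(
        note_text,
        truncation=True,
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
    )
    # Decode each chunk back to text so it can be re-tokenized at inference time.
    return [tok.decode(ids, skip_special_tokens=True) for ids in enc["input_ids"]]
```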

## ⚙️ Training Log (chunk-level)

The metrics below are auto-generated by the Hugging Face Trainer on the raw validation set (before note-level aggregation, calibration, or threshold tuning).

| Step | Train Loss | Val Loss | Accuracy | Precision | Recall | F1 | AUC |
|---|---|---|---|---|---|---|---|
| 200 | 0.655 | 0.627 | 0.660 | 0.800 | 0.0377 | 0.0720 | 0.6259 |
| 800 | 0.331 | 0.373 | 0.840 | 0.882 | 0.628 | 0.733 | 0.8856 |
| 1600 | 0.275 | 0.360 | 0.851 | 0.854 | 0.691 | 0.764 | 0.9040 |
| 2400 | 0.291 | 0.333 | 0.858 | 0.878 | 0.689 | 0.772 | 0.9121 |
| 3400 | 0.220 | 0.358 | 0.861 | 0.851 | 0.731 | 0.786 | 0.9169 |
| 4600 | 0.169 | 0.432 | 0.856 | 0.828 | 0.744 | 0.784 | 0.9172 |

**Final Trainer metrics (chunk-level):** these are raw chunk-level metrics used to monitor training; they are not the final evaluation used for deployment.


## 🧩 Note-Level Aggregated Evaluation (final tuned results)

After post-processing with:

- **Aggregation:** logit_topk (k = 3)
- **Calibration:** temperature scaling (T = 1.372)
- **Threshold:** tuned on the validation set (Fβ = 1.5 → thr ≈ 1.000); see the tuning sketch below
- **Inference logic:** per-note probability = sigmoid(mean of the top-k chunk logits, temperature-scaled)
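
The tuning code for the temperature and threshold is not part of this repository. A minimal sketch of how such values could be obtained on validation data is shown below: T is fit by minimizing negative log-likelihood (temperature scaling) and the threshold is chosen by sweeping for the best Fβ with β = 1.5. The synthetic arrays, variable names, and use of SciPy/scikit-learn are assumptions for illustration only.

```python
import numpy as np
from scipy.optimize import minimize_scalar
from sklearn.metrics import fbeta_score

rng = np.random.default_rng(0)

# Synthetic stand-ins for aggregated per-note validation logits and labels;
# in practice these come from the model's validation predictions.
val_labels = rng.integers(0, 2, size=500)
val_logits = rng.normal(loc=2.0, scale=1.5, size=500) * (2 * val_labels - 1)

def nll(T):
    """Negative log-likelihood of the validation labels under temperature T."""
    p = 1.0 / (1.0 + np.exp(-val_logits / T))
    p = np.clip(p, 1e-6, 1 - 1e-6)
    return -np.mean(val_labels * np.log(p) + (1 - val_labels) * np.log(1 - p))

# Fit the temperature by minimizing NLL (temperature scaling).
T = minimize_scalar(nll, bounds=(0.05, 10.0), method="bounded").x

# Sweep candidate thresholds and keep the one that maximizes F-beta
# (beta = 1.5 weights recall more heavily than precision).
probs = 1.0 / (1.0 + np.exp(-val_logits / T))
thresholds = np.linspace(0.0, 1.0, 1001)
scores = [fbeta_score(val_labels, (probs >= t).astype(int), beta=1.5, zero_division=0)
          for t in thresholds]
best_thr = thresholds[int(np.argmax(scores))]
print(f"T = {T:.3f}, threshold = {best_thr:.3f}")
```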

### Validation (n = 493)

| Metric | Value |
|---|---|
| Precision | 0.9118 |
| Recall | 0.9394 |
| F1 | 0.9254 |
| Accuracy | 0.9493 |
| ROC-AUC | 0.9688 |

### Test (n = 493)

| Metric | Value |
|---|---|
| Precision | 0.9618 |
| Recall | 0.9207 |
| F1 | 0.9408 |
| Accuracy | 0.9615 |
| ROC-AUC | 0.9786 |

✅ **Final configuration:** Aggregation = logit_topk(k = 3), Temperature = 1.372, Threshold ≈ 1.0


## 🧠 Inference Example

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import torch

repo = "vhdm/clinicalbert-ms-autoimmune-neuro"
tok = AutoTokenizer.from_pretrained(repo)
model = AutoModelForSequenceClassification.from_pretrained(repo)
model.eval()

# Treat the texts as chunks belonging to a single clinical note.
texts = ["Patient reports numbness in lower limbs...", "MRI shows demyelination consistent with MS."]
inputs = tok(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits
probs = torch.softmax(logits, dim=-1).numpy()[:, 1]  # P(autoimmune) per chunk

# logit_topk aggregation (k = 3): average the logits of the k highest-scoring chunks.
probs = np.clip(probs, 1e-6, 1 - 1e-6)
logits_ = np.log(probs) - np.log(1 - probs)
k = 3
idx = np.argsort(logits_)[-min(k, len(logits_)):]
mean_logit = logits_[idx].mean()

# Uncalibrated note-level score (for reference).
note_score = 1.0 / (1.0 + np.exp(-mean_logit))

# Calibrated note-level probability: temperature scaling with T tuned on validation.
T = 1.372
note_score_cal = 1.0 / (1.0 + np.exp(-mean_logit / T))

# Decision threshold tuned on validation (Fβ = 1.5).
thr = 1.0
pred = int(note_score_cal >= thr)
print({"score": float(note_score_cal), "prediction": pred})
```

## 🧪 Reproducibility

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./runs/clinicalbert_ms",
  learning_rate=2e-5,
  per_device_train_batch_size=24,
  per_device_eval_batch_size=48,
  num_train_epochs=5,
  weight_decay=0.01,
  bf16=True,
  optim="adamw_torch",
  warmup_ratio=0.1,
  seed=42,
  evaluation_strategy="steps",
  save_strategy="steps",
  logging_steps=50,
  eval_steps=200,
  save_steps=200,
  save_total_limit=3,
  load_best_model_at_end=True,
  metric_for_best_model="recall",  # best checkpoint selected by chunk-level recall
  greater_is_better=True,
)
```
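
The training log above reports accuracy, precision, recall, F1, and AUC, but the metric function itself is not included here. A minimal `compute_metrics` sketch that would produce comparable chunk-level metrics is shown below; the function name, scikit-learn usage, and 0.5 decision threshold are assumptions.

```python
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
)

def compute_metrics(eval_pred):
    """Chunk-level metrics comparable to the training log above."""
    logits, labels = eval_pred
    # Probability of the positive (autoimmune) class via a numerically stable softmax.
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = exp[:, 1] / exp.sum(axis=-1)
    preds = (probs >= 0.5).astype(int)
    return {
        "accuracy": accuracy_score(labels, preds),
        "precision": precision_score(labels, preds, zero_division=0),
        "recall": recall_score(labels, preds, zero_division=0),
        "f1": f1_score(labels, preds, zero_division=0),
        "auc": roc_auc_score(labels, probs),
    }
```

Such a function would be passed to the Trainer as `compute_metrics=compute_metrics` alongside the `TrainingArguments` above.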