πŸ₯ Clinical Contrastive ModernBERT with [ENTITY] Token Support

This is a custom contrastive-learning model designed for clinical text, with built-in support for an [ENTITY] token used to anonymize sensitive patient information.

🎯 Key Features

  • βœ… [ENTITY] Token Support: Anonymize patient names, IDs, and locations
  • βœ… Contrastive Learning: Trained with triplet loss on clinical text
  • βœ… Clinical Domain: Optimized for medical/clinical language
  • βœ… Custom Architecture: Specialized contrastive model class
  • βœ… Attention-Masked Pooling: Proper handling of special tokens

πŸ“Š Model Details

  • Base Model: Simonlee711/Clinical_ModernBERT
  • Architecture: ContrastiveClinicalModel with triplet loss
  • Training: Triplet loss with margin=1.0
  • Vocabulary Size: 50,370 tokens
  • [ENTITY] Token ID: 50368
  • Max Sequence Length: 8192 tokens
  • Hidden Size: 768
  • Layers: 22
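
The [ENTITY] token ID and vocabulary size listed above can be checked directly against the tokenizer. A quick sanity check (loading the tokenizer as in the Quick Start below):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("nikhil061307/contrastive-learning-bert-added-token-v5")

# The added [ENTITY] token should map to ID 50368
assert tokenizer.convert_tokens_to_ids("[ENTITY]") == 50368
print(f"Vocabulary size: {len(tokenizer)}")  # expected: 50370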

πŸš€ Quick Start

from transformers import AutoTokenizer, AutoModel
import torch

# Load the model and tokenizer (trust_remote_code=True is required for the custom model class)
tokenizer = AutoTokenizer.from_pretrained("nikhil061307/contrastive-learning-bert-added-token-v5")
model = AutoModel.from_pretrained("nikhil061307/contrastive-learning-bert-added-token-v5", trust_remote_code=True)

def get_clinical_embeddings(texts, max_length=256):
    """Get embeddings for clinical texts with [ENTITY] support."""
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )
    
    # Use the model's custom encode method
    with torch.no_grad():
        embeddings = model.encode(inputs['input_ids'], inputs['attention_mask'])
    
    return embeddings

# Example with [ENTITY] token for anonymization
clinical_texts = [
    "Patient [ENTITY] presents with chest pain and shortness of breath.",
    "Patient [ENTITY] reports severe headache lasting 3 days.",
    "Patient [ENTITY] diagnosed with acute myocardial infarction."
]

embeddings = get_clinical_embeddings(clinical_texts)
print(f"Embeddings shape: {embeddings.shape}")

# Calculate similarities (embeddings are L2-normalized, so the dot product is cosine similarity)
similarity_matrix = torch.mm(embeddings, embeddings.t())
print(f"Similarity between first two texts: {similarity_matrix[0, 1]:.4f}")
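
In practice, raw clinical notes do not arrive with [ENTITY] markers; a de-identification step must insert them before encoding. Below is a minimal, hypothetical sketch using regex patterns for titled names; a production pipeline would use a dedicated clinical NER/de-identification tool instead.

import re

def anonymize(text):
    """Replace simple titled-name patterns with the [ENTITY] token.

    These regexes are illustrative only; real de-identification should
    rely on a dedicated clinical NER model, not hand-written patterns.
    """
    for pattern in (r"\bMr\.\s+\w+", r"\bMrs\.\s+\w+", r"\bDr\.\s+\w+"):
        text = re.sub(pattern, "[ENTITY]", text)
    return text

note = "Mr. Smith presents with chest pain; evaluated by Dr. Jones."
embedding = get_clinical_embeddings([anonymize(note)])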

⚠️ Important Usage Notes

  1. Trust Remote Code: Always use trust_remote_code=True when loading
  2. Custom Architecture: This uses a specialized ContrastiveClinicalModel class
  3. [ENTITY] Token: Token ID 50368 is preserved from training
  4. L2 Normalization: Embeddings are automatically L2-normalized
  5. Attention Masking: Padding and special tokens are properly masked out (see the pooling sketch after this list)
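
For reference, attention-masked mean pooling followed by L2 normalization (notes 4 and 5) can be written as below. This is a sketch of the general technique, assuming the base encoder returns last_hidden_state; it is not the model's internal code.

import torch
import torch.nn.functional as F

def masked_mean_pool(last_hidden_state, attention_mask):
    # Zero out padded positions, then average over the real tokens only
    mask = attention_mask.unsqueeze(-1).float()      # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(dim=1)   # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)         # avoid division by zero
    return F.normalize(summed / counts, p=2, dim=1)  # L2-normalize the embeddings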

🎯 Training Details

  • Training Method: Contrastive learning on clinical text triplets
  • Loss Function: Triplet loss with margin=1.0 (sketched after this list)
  • Pooling Strategy: Attention-masked mean pooling
  • Dropout Rate: 0.15 (training only)
  • Normalization: L2 normalization on embeddings
  • Special Tokens: Handles [ENTITY], [PAD], [CLS], [SEP]
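
Triplet loss with this margin matches PyTorch's built-in TripletMarginLoss; the schematic below shows how such a loss applies to anchor/positive/negative embeddings. It is a sketch of the objective, not the actual training script.

import torch

triplet_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2)

# In training these would be embeddings of (note, similar note, dissimilar note) triplets
anchor   = torch.randn(8, 768)
positive = torch.randn(8, 768)
negative = torch.randn(8, 768)
print(triplet_loss(anchor, positive, negative))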

πŸ”’ Privacy & Compliance

This model is designed to help with healthcare data privacy by:

  • Supporting entity anonymization with [ENTITY] tokens
  • Maintaining semantic similarity despite anonymization
  • Enabling analysis of de-identified clinical text
  • Preserving medical meaning while protecting patient privacy

Note: Always ensure compliance with relevant healthcare privacy regulations (HIPAA, GDPR, etc.) when processing medical data.
