DocMap Leads Classifier (UK healthcare, social media)
Multi-head encoder classifier built on microsoft/deberta-v3-base to identify:
- Intent (single label)
- Symptoms (multi-label)
- Specialties (multi-label)
Trained on the DocMap leads_v1 JSONL dataset (Supabase), using simple weak labels derived from keywords/zero-shot prompts. See label_config.json for the canonical label spaces.
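To make the keyword side of the weak labelling concrete, here is a minimal sketch of keyword-based multi-label tagging. The keyword map and the `weak_label` helper are hypothetical illustrations, not the rules actually used to build leads_v1; the real label names live in label_config.json.

```python
import re

# Hypothetical keyword map (label -> trigger phrases), for illustration only.
SYMPTOM_KEYWORDS = {
    "fever": ["fever", "high temperature"],
    "headache": ["headache", "migraine"],
}


def weak_label(text, keyword_map):
    """Return every label with a keyword that appears as a whole word or phrase."""
    lowered = text.lower()
    labels = []
    for label, keywords in keyword_map.items():
        # \b keeps "fever" from matching inside longer words such as "feverish".
        if any(re.search(r"\b" + re.escape(kw) + r"\b", lowered) for kw in keywords):
            labels.append(label)
    return labels


print(weak_label("Any advice on fever? Started 3 days ago.", SYMPTOM_KEYWORDS))
```

Labels produced this way are noisy by construction, which is why the model card lists label noise among its limitations.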
What’s in this repo
- model.safetensors: classifier heads + encoder weights
- label_config.json: lists of intents, symptoms, specialties
- tokenizer.json, tokenizer_config.json, special_tokens_map.json, spm.model
- README.md, .gitattributes
- (Checkpoints may be removed for smaller repo size)
Intended use
- Lead identification from public social media text in a UK healthcare context.
- Outputs: intent label and multi-label sets for symptoms/specialties.
Not a medical device. Do not use it for diagnosis; it is intended for triage or marketing pre-filtering only.
Training
- Base: microsoft/deberta-v3-base
- Epochs: 3, batch size: 16, lr: 2e-5, max_len: 256
- Split: 10% validation
- A threshold sweep on the validation set suggested 0.3 as the default for the multi-label heads.
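A threshold sweep of this kind can be sketched as follows. This is an illustration under assumptions, not the training script: it assumes micro-averaged F1 as the selection metric (the model card does not say which metric was used), and the candidate grid, helper names, and toy data are all hypothetical.

```python
def micro_f1(y_true, y_prob, thr):
    """Micro-averaged F1 over all (example, label) pairs at one threshold."""
    tp = fp = fn = 0
    for true_row, prob_row in zip(y_true, y_prob):
        for t, p in zip(true_row, prob_row):
            pred = p >= thr
            if pred and t:
                tp += 1
            elif pred and not t:
                fp += 1
            elif t:
                fn += 1
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def sweep_thresholds(y_true, y_prob, candidates=(0.2, 0.3, 0.4, 0.5)):
    """Return (best_threshold, best_f1); ties go to the earliest candidate."""
    scored = [(micro_f1(y_true, y_prob, thr), thr) for thr in candidates]
    best_f1, best_thr = max(scored, key=lambda pair: pair[0])
    return best_thr, best_f1


# Toy validation data (hypothetical): 2 examples, 3 labels each.
y_true = [[1, 0, 1], [0, 1, 0]]
y_prob = [[0.9, 0.1, 0.35], [0.25, 0.6, 0.05]]
print(sweep_thresholds(y_true, y_prob))
```

In practice `y_prob` would hold the sigmoid outputs of a multi-label head on the 10% validation split, and the sweep could be run separately for the symptom and specialty heads.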
Quick start (Python)
This model uses a lightweight custom head. Load with the snippet below (no HF widget).
import json

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModel

repo_id = "YOUR_USER_OR_ORG/docmap-leads-classifier-v1"  # change to your repo
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load labels
cfg_path = hf_hub_download(repo_id, "label_config.json")
with open(cfg_path, "r") as f:
    cfg = json.load(f)

tokenizer = AutoTokenizer.from_pretrained(repo_id)

class LeadsClassifier(nn.Module):
    def __init__(self, base_model_name, num_intents, num_symptoms, num_specialties):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        hidden = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        # One linear head per task, all sharing the encoder output
        self.intent_head = nn.Linear(hidden, num_intents)
        self.sym_head = nn.Linear(hidden, num_symptoms)
        self.spec_head = nn.Linear(hidden, num_specialties)

    def forward(self, input_ids=None, attention_mask=None):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first token's hidden state as the sequence representation
        cls = self.dropout(out.last_hidden_state[:, 0, :])
        return {
            "intent_logits": self.intent_head(cls),
            "sym_logits": self.sym_head(cls),
            "spec_logits": self.spec_head(cls),
        }

model = LeadsClassifier(
    base_model_name="microsoft/deberta-v3-base",
    num_intents=len(cfg["intents"]),
    num_symptoms=len(cfg["symptoms"]),
    num_specialties=len(cfg["specialties"]),
).to(device)

# Load weights. The file is in safetensors format, so read it with
# safetensors' load_file; torch.load cannot parse it.
from safetensors.torch import load_file
sd_path = hf_hub_download(repo_id, "model.safetensors")
model.load_state_dict(load_file(sd_path, device=device))
model.eval()

def predict(texts, thr=0.3):
    batch = tokenizer(texts, padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model(**batch)
    # Intent is single-label: take the most probable class
    intent = torch.softmax(out["intent_logits"], dim=-1).argmax(dim=-1).tolist()
    # Symptoms and specialties are multi-label: independent sigmoid per label
    sym_prob = torch.sigmoid(out["sym_logits"])
    spec_prob = torch.sigmoid(out["spec_logits"])
    intents = [cfg["intents"][i] for i in intent]
    symptoms = [[cfg["symptoms"][j] for j, p in enumerate(row) if p >= thr] for row in sym_prob.tolist()]
    specialties = [[cfg["specialties"][j] for j, p in enumerate(row) if p >= thr] for row in spec_prob.tolist()]
    return [{"intent": i, "symptoms": s, "specialties": sp} for i, s, sp in zip(intents, symptoms, specialties)]

print(predict(["Any advice on fever? Based in Glasgow, started 3 days ago."], thr=0.3))
Inference thresholds
- Default multi-label threshold: 0.3 (from validation sweep).
- Tune per use-case; 0.5 is stricter, 0.2 more sensitive.
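Because the two multi-label heads need not share a threshold, tuning per use-case can also mean tuning per head. The following is a minimal sketch of that idea; `apply_thresholds`, the label names, and the probabilities are hypothetical and only illustrate how a stricter or more sensitive cut-off changes the selected labels.

```python
def apply_thresholds(probs_by_head, labels_by_head, thr_by_head, default_thr=0.3):
    """Select labels per head, using a head-specific threshold when one is given."""
    selected = {}
    for head, probs in probs_by_head.items():
        thr = thr_by_head.get(head, default_thr)
        selected[head] = [
            label for label, p in zip(labels_by_head[head], probs) if p >= thr
        ]
    return selected


# Hypothetical label spaces and sigmoid outputs for a single post.
labels_by_head = {
    "symptoms": ["fever", "cough", "rash"],
    "specialties": ["general practice", "dermatology"],
}
probs_by_head = {
    "symptoms": [0.62, 0.34, 0.21],
    "specialties": [0.55, 0.31],
}

# More sensitive on symptoms, stricter on specialties.
print(apply_thresholds(probs_by_head, labels_by_head, {"symptoms": 0.2, "specialties": 0.5}))
# No overrides: every head falls back to the 0.3 default.
print(apply_thresholds(probs_by_head, labels_by_head, {}))
```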
Limitations and risks
- Weakly supervised labels; potential label noise and leakage.
- Social media domain; may not generalize to clinical text.
- Not for medical diagnosis or emergency advice.
License
- MIT (inherits base model’s MIT license).
Citation
Please cite the base model and this repository if you use it in research or production.