Text Classification
Adapters
biology
diagnosis-adapter / load_data.py
naifenn's picture
Upload load_data.py with huggingface_hub
845f390 verified
raw
history blame
4.45 kB
from datasets import load_dataset
from collections import defaultdict
import json
import re
import random
# Output directory for the generated JSONL splits.
folder = "data/"

# System prompt attached to every training example.
system_message = (
    "You are a medical diagnosis classifier. Given a description of symptoms, "
    "provide ONLY the name of the most likely diagnosis. Do not include "
    "explanations, reasoning, or additional text."
)

# Load the dataset and shuffle it with a fixed seed for reproducibility.
dataset = load_dataset("sajjadhadi/disease-diagnosis-dataset", split="train").shuffle(seed=42)
# Function to clean symptom text into a standardized format
# Compiled once at module load instead of recompiled on every call.
# Optional lead-in phrase, lazily-captured symptom list, then a required
# trailing question/conclusion phrase (everything after it is discarded).
_SYMPTOM_PATTERN = re.compile(
    r'(?:patient reported the following symptoms:|symptoms include:?)?\s*'
    r'(.*?)'
    r'(?:\s*(?:may indicate|based on these symptoms|what disease may the patient have\?|what is the most likely diagnosis\?).*)',
    re.IGNORECASE,
)

def clean_symptom_text(text):
    """Extract a normalized symptom list from raw dataset text.

    Strips a known lead-in phrase and the trailing question/conclusion,
    then normalizes comma spacing to ", ". If none of the known trailing
    phrases is present, the text is returned unchanged.
    """
    match = _SYMPTOM_PATTERN.search(text)
    if match:
        symptoms = match.group(1).strip()
        symptoms = re.sub(r'\s*,\s*', ', ', symptoms).rstrip(',')
        return symptoms  # was f"{symptoms}" — redundant f-string around a str
    return text
# Group sample indices by diagnosis label.
# Iterating the "diagnosis" column directly avoids decoding every full row
# (HF datasets materialize a dict per row when iterated row-by-row), which is
# considerably faster for a large dataset; the resulting mapping is identical.
diagnosis_to_samples = defaultdict(list)
for i, diagnosis in enumerate(dataset["diagnosis"]):
    diagnosis_to_samples[diagnosis].append(i)
# Select the diagnosis classes to keep: the most frequent labels that have at
# least MIN_SAMPLES real examples, capped at TOP_N classes.
# BUG FIX: the original sliced with [:MIN_SAMPLES] (75 classes), but the TODO
# said "Select top 50 diagnoses"; TOP_N makes the intended cap explicit.
TARGET_SAMPLES = 300  # examples per class after balancing
MIN_SAMPLES = 75      # minimum real examples a class needs to be kept
TOP_N = 50            # number of diagnosis classes to keep
top_diagnoses = [
    diag
    for diag, indices in sorted(
        diagnosis_to_samples.items(), key=lambda x: len(x[1]), reverse=True
    )
    if len(indices) >= MIN_SAMPLES
][:TOP_N]
print(top_diagnoses)
# Balance the dataset: exactly TARGET_SAMPLES examples per kept diagnosis.
balanced_indices = []
for diagnosis_name in top_diagnoses:
    pool = diagnosis_to_samples[diagnosis_name]
    if len(pool) >= TARGET_SAMPLES:
        # Enough real examples: keep the first TARGET_SAMPLES (dataset is shuffled).
        chosen = pool[:TARGET_SAMPLES]
    else:
        # Oversample: repeat the full pool, then top up with a random draw.
        repeats, leftover = divmod(TARGET_SAMPLES, len(pool))
        chosen = pool * repeats
        chosen.extend(random.sample(pool, leftover))
    balanced_indices.extend(chosen)

# Materialize the balanced subset.
balanced_dataset = dataset.select(balanced_indices)
print(f"Original dataset size: {len(dataset)}, Balanced dataset size: {len(balanced_indices)}")
print(f"Number of unique diagnoses: {len(top_diagnoses)}")
# Create train/test/validation splits.
# First split: 80% train / 20% held out; the second split halves the held-out
# portion, yielding an overall 80/10/10 train/test/validation split.
splits = balanced_dataset.train_test_split(test_size=0.2, seed=42)
test_valid_splits = splits['test'].train_test_split(test_size=0.5, seed=42)
# Function to convert samples to required format and save as JSONL
def save_as_jsonl(dataset, filename):
    """Write *dataset* to *filename* in chat-format JSONL.

    Each sample becomes one line containing a system/user/assistant message
    triple: the system prompt, the cleaned symptom text, and the gold
    diagnosis label.

    Note: the parameter shadows the module-level `dataset`; the name is kept
    for backward compatibility with existing callers.
    """
    # Explicit encoding so output does not depend on the platform default.
    with open(filename, 'w', encoding='utf-8') as file:
        for sample in dataset:
            cleaned_text = clean_symptom_text(sample["text"])
            conversation = {
                "messages": [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": cleaned_text},
                    {"role": "assistant", "content": sample["diagnosis"]},
                ]
            }
            file.write(json.dumps(conversation) + '\n')
# Save the three splits to JSONL files under `folder`.
save_as_jsonl(splits["train"], folder + "train.jsonl")
# NOTE: of the second split, the "train" half is used as the test set and
# the "test" half as the validation set (each is half of the 20% hold-out).
save_as_jsonl(test_valid_splits["train"], folder + "test.jsonl")
save_as_jsonl(test_valid_splits["test"], folder + "valid.jsonl")
# Print split sizes.
print("Dataset splits:")
print(f" Train: {len(splits['train'])}")
print(f" Test: {len(test_valid_splits['train'])}")
print(f" Validation: {len(test_valid_splits['test'])}")

# Spot-check the first three serialized training examples.
print("\nSample validation:")
with open(folder + "train.jsonl", 'r') as fh:
    for idx, raw in enumerate(fh):
        if idx >= 3:
            break
        record = json.loads(raw)
        sys_msg, user_msg, asst_msg = record['messages']
        print(f"Example {idx+1}:")
        print(f" System: {sys_msg['content']}")
        print(f" User: {user_msg['content']}")
        print(f" Assistant: {asst_msg['content']}")
        print()
# Check class distribution in the training set by re-reading the saved file.
class_counts = defaultdict(int)
with open(folder + "train.jsonl", 'r') as fh:
    for raw in fh:
        record = json.loads(raw)
        label = record['messages'][2]['content']  # assistant message = diagnosis
        class_counts[label] += 1

print("\nClass distribution in training set:")
top_ten = sorted(class_counts.items(), key=lambda item: item[1], reverse=True)[:10]
for diagnosis, count in top_ten:
    print(f" {diagnosis}: {count}")