from collections import defaultdict
from datasets import load_dataset
import json
import os
import random
import re

# Seed the stdlib RNG so the oversampling step below is reproducible.
random.seed(42)

folder = "data/"
os.makedirs(folder, exist_ok=True)  # ensure the output folder exists

system_message = (
    "You are a medical diagnosis classifier. Given a description of symptoms, "
    "provide ONLY the name of the most likely diagnosis. Do not include "
    "explanations, reasoning, or additional text."
)

# Load the raw dataset and shuffle once so index-based selection below is random.
dataset = load_dataset("sajjadhadi/disease-diagnosis-dataset", split="train")
dataset = dataset.shuffle(seed=42)
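
# Quick schema check: each record is expected to carry a free-text "text" field
# and a string "diagnosis" label, the only two fields this script relies on.
print(dataset[0])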


def clean_symptom_text(text):
    """Strip boilerplate framing from a raw sample, keeping only the symptom list."""
    # Capture everything between an optional lead-in phrase and the trailing
    # question/prompt that the raw records append after the symptoms.
    pattern = r'(?:patient reported the following symptoms:|symptoms include:?)?\s*(.*?)(?:\s*(?:may indicate|based on these symptoms|what disease may the patient have\?|what is the most likely diagnosis\?).*)'
    match = re.search(pattern, text, re.IGNORECASE)
    if match:
        symptoms = match.group(1).strip()
        # Normalize comma spacing and drop any trailing comma.
        symptoms = re.sub(r'\s*,\s*', ', ', symptoms).rstrip(',')
        return symptoms
    return text
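
# Illustrative call (the raw phrasing here is an assumption shaped to fit the
# regex above, not a verbatim record from the dataset):
#   clean_symptom_text("Patient reported the following symptoms: fever, cough, "
#                      "fatigue, what disease may the patient have?")
#   -> "fever, cough, fatigue"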

# Group sample indices by diagnosis label so classes can be balanced below.
diagnosis_to_samples = defaultdict(list)
for i, sample in enumerate(dataset):
    diagnosis_to_samples[sample["diagnosis"]].append(i)

TARGET_SAMPLES = 300  # samples per class after balancing
MIN_SAMPLES = 75      # minimum class size required to keep a diagnosis
NUM_CLASSES = 75      # number of most frequent qualifying diagnoses to keep

# Keep the NUM_CLASSES most frequent diagnoses that each have at least
# MIN_SAMPLES examples.
top_diagnoses = [
    diag
    for diag, indices in sorted(
        diagnosis_to_samples.items(), key=lambda x: len(x[1]), reverse=True
    )
    if len(indices) >= MIN_SAMPLES
][:NUM_CLASSES]

print(top_diagnoses)

# Build a balanced index list: downsample frequent classes, oversample rare ones.
balanced_indices = []
for diag in top_diagnoses:
    indices = diagnosis_to_samples[diag]
    if len(indices) >= TARGET_SAMPLES:
        # Downsample: the dataset is already shuffled, so the first
        # TARGET_SAMPLES indices are a random subset.
        selected_indices = indices[:TARGET_SAMPLES]
    else:
        # Oversample: repeat the whole class, then top up with a random sample.
        selected_indices = indices * (TARGET_SAMPLES // len(indices))
        remaining = TARGET_SAMPLES % len(indices)
        selected_indices.extend(random.sample(indices, remaining))
    balanced_indices.extend(selected_indices)
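
# Optional sanity check: after the loop above, every kept diagnosis should
# contribute exactly TARGET_SAMPLES indices.
per_class = defaultdict(int)
for idx in balanced_indices:
    per_class[dataset[idx]["diagnosis"]] += 1
assert all(count == TARGET_SAMPLES for count in per_class.values())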

balanced_dataset = dataset.select(balanced_indices)
print(f"Original dataset size: {len(dataset)}, Balanced dataset size: {len(balanced_indices)}")
print(f"Number of unique diagnoses: {len(top_diagnoses)}")

# 80/10/10 split: carve off 20%, then halve that into test and validation sets.
splits = balanced_dataset.train_test_split(test_size=0.2, seed=42)
test_valid_splits = splits['test'].train_test_split(test_size=0.5, seed=42)


def save_as_jsonl(split, filename):
    """Write a dataset split as chat-format JSONL, one conversation per line."""
    with open(filename, 'w') as file:
        for sample in split:
            cleaned_text = clean_symptom_text(sample["text"])
            conversation = {
                "messages": [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": cleaned_text},
                    {"role": "assistant", "content": sample["diagnosis"]}
                ]
            }
            file.write(json.dumps(conversation) + '\n')
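
# Each written line looks roughly like this (symptom text and diagnosis are
# illustrative values, not actual dataset content):
# {"messages": [{"role": "system", "content": "You are a medical diagnosis classifier. ..."},
#               {"role": "user", "content": "fever, cough, fatigue"},
#               {"role": "assistant", "content": "Influenza"}]}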

# Note the naming: the first half of the 20% carve-off becomes the test set and
# the second half the validation set.
save_as_jsonl(splits["train"], folder + "train.jsonl")
save_as_jsonl(test_valid_splits["train"], folder + "test.jsonl")
save_as_jsonl(test_valid_splits["test"], folder + "valid.jsonl")
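
# Optional round-trip check: the files can be reloaded with the datasets
# library's JSON loader.
reloaded = load_dataset("json", data_files=folder + "train.jsonl", split="train")
print(f"Reloaded train records: {len(reloaded)}")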

print("Dataset splits:")
print(f" Train: {len(splits['train'])}")
print(f" Test: {len(test_valid_splits['train'])}")
print(f" Validation: {len(test_valid_splits['test'])}")

print("\nSample validation:")
with open(folder + "train.jsonl", 'r') as file:
    for i, line in enumerate(file):
        if i >= 3:
            break
        example = json.loads(line)
        print(f"Example {i+1}:")
        print(f" System: {example['messages'][0]['content']}")
        print(f" User: {example['messages'][1]['content']}")
        print(f" Assistant: {example['messages'][2]['content']}")
        print()

# Count labels in the saved training file to confirm the balance held up.
class_counts = defaultdict(int)
with open(folder + "train.jsonl", 'r') as file:
    for line in file:
        example = json.loads(line)
        diagnosis = example['messages'][2]['content']
        class_counts[diagnosis] += 1

print("\nClass distribution in training set (top 10):")
for diagnosis, count in sorted(class_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
    print(f" {diagnosis}: {count}")