from datasets import load_dataset
from collections import defaultdict
import json
import re
import random

folder = "data/"

system_message = (
    "You are a medical diagnosis classifier. Given a description of symptoms, "
    "provide ONLY the name of the most likely diagnosis. Do not include "
    "explanations, reasoning, or additional text."
)

# Load and shuffle the dataset
dataset = load_dataset("sajjadhadi/disease-diagnosis-dataset", split="train")
dataset = dataset.shuffle(seed=42)

# Clean symptom text into a standardized, comma-separated format by stripping
# the boilerplate prefix and the trailing question; falls back to the raw text
# when the pattern does not match
def clean_symptom_text(text):
    pattern = r'(?:patient reported the following symptoms:|symptoms include:?)?\s*(.*?)(?:\s*(?:may indicate|based on these symptoms|what disease may the patient have\?|what is the most likely diagnosis\?).*)'
    match = re.search(pattern, text, re.IGNORECASE)
    if match:
        symptoms = match.group(1).strip()
        symptoms = re.sub(r'\s*,\s*', ', ', symptoms).rstrip(',')
        return symptoms
    return text

# Group sample indices by diagnosis
diagnosis_to_samples = defaultdict(list)
for i, sample in enumerate(dataset):
    diagnosis_to_samples[sample["diagnosis"]].append(i)

# Select the TOP_N most frequent diagnoses that have at least MIN_SAMPLES examples
TOP_N = 50
TARGET_SAMPLES = 300
MIN_SAMPLES = 75
top_diagnoses = [
    diag
    for diag, indices in sorted(
        diagnosis_to_samples.items(), key=lambda x: len(x[1]), reverse=True
    )
    if len(indices) >= MIN_SAMPLES
][:TOP_N]
print(top_diagnoses)

# Balance the dataset: ensure exactly TARGET_SAMPLES per diagnosis
random.seed(42)  # make the oversampling below reproducible
balanced_indices = []
for diag in top_diagnoses:
    indices = diagnosis_to_samples[diag]
    if len(indices) >= TARGET_SAMPLES:
        # Cap at TARGET_SAMPLES
        selected_indices = indices[:TARGET_SAMPLES]
    else:
        # Oversample: repeat the full set, then top up with a random remainder
        selected_indices = indices * (TARGET_SAMPLES // len(indices))
        remaining = TARGET_SAMPLES % len(indices)
        selected_indices.extend(random.sample(indices, remaining))
    balanced_indices.extend(selected_indices)

# Create the balanced dataset (select() accepts duplicate indices)
balanced_dataset = dataset.select(balanced_indices)
print(f"Original dataset size: {len(dataset)}, Balanced dataset size: {len(balanced_indices)}")
print(f"Number of unique diagnoses: {len(top_diagnoses)}")

# Create train/test/validation splits (80/10/10)
splits = balanced_dataset.train_test_split(test_size=0.2, seed=42)
test_valid_splits = splits['test'].train_test_split(test_size=0.5, seed=42)

# Convert samples to the chat-message format and save as JSONL
def save_as_jsonl(split, filename):
    with open(filename, 'w') as file:
        for sample in split:
            cleaned_text = clean_symptom_text(sample["text"])
            conversation = {
                "messages": [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": cleaned_text},
                    {"role": "assistant", "content": sample["diagnosis"]},
                ]
            }
            file.write(json.dumps(conversation) + '\n')

# Save datasets
save_as_jsonl(splits["train"], folder + "train.jsonl")
save_as_jsonl(test_valid_splits["train"], folder + "test.jsonl")
save_as_jsonl(test_valid_splits["test"], folder + "valid.jsonl")

# Print statistics
print("Dataset splits:")
print(f"  Train: {len(splits['train'])}")
print(f"  Test: {len(test_valid_splits['train'])}")
print(f"  Validation: {len(test_valid_splits['test'])}")

# Sample validation: show the first three training examples
print("\nSample validation:")
with open(folder + "train.jsonl", 'r') as file:
    for i, line in enumerate(file):
        if i >= 3:
            break
        example = json.loads(line)
        print(f"Example {i+1}:")
        print(f"  System: {example['messages'][0]['content']}")
        print(f"  User: {example['messages'][1]['content']}")
        print(f"  Assistant: {example['messages'][2]['content']}")
        print()
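# Caveat: oversampling happens *before* the train/test split, so a duplicated
# example can land in more than one split and inflate evaluation scores. The
# block below is a minimal sanity-check sketch, not part of the original
# pipeline; the helper name _load_user_prompts is ours. It counts verbatim
# user-prompt overlap between train.jsonl and test.jsonl.
def _load_user_prompts(path):
    with open(path, 'r') as f:
        return {json.loads(line)['messages'][1]['content'] for line in f}

train_prompts = _load_user_prompts(folder + "train.jsonl")
test_prompts = _load_user_prompts(folder + "test.jsonl")
overlap = train_prompts & test_prompts
print(f"\nVerbatim train/test prompt overlap: {len(overlap)} unique prompts")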
{example['messages'][2]['content']}") print() # Check class distribution in training set class_counts = defaultdict(int) with open(folder + "train.jsonl", 'r') as file: for line in file: example = json.loads(line) diagnosis = example['messages'][2]['content'] class_counts[diagnosis] += 1 print("\nClass distribution in training set:") for diagnosis, count in sorted(class_counts.items(), key=lambda x: x[1], reverse=True)[:10]: print(f" {diagnosis}: {count}")