Spaces:
Sleeping
Sleeping
Upload 29 files
Browse files- .gitattributes +1 -0
- app.py +224 -0
- emotion_analysis_framework.py +546 -0
- example/example.wav +3 -0
- model_weights/binary/model_fold_0.joblib +3 -0
- model_weights/binary/model_fold_1.joblib +3 -0
- model_weights/binary/model_fold_2.joblib +3 -0
- model_weights/binary/model_fold_3.joblib +3 -0
- model_weights/binary/model_fold_4.joblib +3 -0
- model_weights/multiclass/model_fold_0.joblib +3 -0
- model_weights/multiclass/model_fold_1.joblib +3 -0
- model_weights/multiclass/model_fold_2.joblib +3 -0
- model_weights/multiclass/model_fold_3.joblib +3 -0
- model_weights/multiclass/model_fold_4.joblib +3 -0
- model_weights/regression/model_fold_0.joblib +3 -0
- model_weights/regression/model_fold_1.joblib +3 -0
- model_weights/regression/model_fold_2.joblib +3 -0
- model_weights/regression/model_fold_3.joblib +3 -0
- model_weights/regression/model_fold_4.joblib +3 -0
- models/__init__.py +17 -0
- models/inference_wav2vec.py +127 -0
- preprocessing/flattening_base.py +16 -0
- preprocessing/flattening_categorical.py +108 -0
- preprocessing/flattening_minirocket.py +155 -0
- preprocessing/flattening_statistical.py +50 -0
- utils/__init__.py +5 -0
- utils/config.py +43 -0
- utils/config_types.py +98 -0
- utils/logger.py +6 -0
- utils/tabular_transformation.py +165 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
example/example.wav filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app.py
|
2 |
+
import gradio as gr
|
3 |
+
import os
|
4 |
+
import warnings
|
5 |
+
from pathlib import Path
|
6 |
+
import tempfile
|
7 |
+
import librosa
|
8 |
+
import soundfile as sf # Add this import
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
# Keep the Space logs and UI quiet: drop Python warnings and lower the
# verbosity of TensorFlow's C++ backend and of the transformers library.
warnings.filterwarnings('ignore')
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '2',
    'TRANSFORMERS_VERBOSITY': 'error',
})
|
15 |
+
|
16 |
+
# Import your framework's components
|
17 |
+
from emotion_analysis_framework import EmotionAnalysisFramework, PatientData
|
18 |
+
|
19 |
+
# --- Initialization ---
# Build the framework once, at import time. The weights are looked up in
# './model_weights' (a relative path, so it resolves against the working
# directory of the process). A failure is reported on stdout and leaves the
# placeholders below in place so the UI can still start and show an error.
framework = None
FRAMEWORK_INITIALIZED = False

print("Initializing Emotion Analysis Framework...")
try:
    framework = EmotionAnalysisFramework(model_dir="./model_weights", verbose=False)
except Exception as e:
    print(f"FATAL: Error initializing framework: {e}")
    print("Please ensure the 'model_weights' directory is present and contains the model files.")
else:
    print("Framework initialized successfully!")
    FRAMEWORK_INITIALIZED = True
|
35 |
+
|
36 |
+
|
37 |
+
# --- Prediction Function ---
|
38 |
+
def analyze_emotion(audio_file, sex, race, education, age):
    """Run every prediction task on one recording and format the results.

    Args:
        audio_file: Path of the uploaded audio (the UI uses
            ``gr.Audio(type="filepath")``), or a readable file-like object
            whose bytes are copied to a temporary ``.wav`` file.
        sex: ``"Female"`` is encoded as 0, any other value as 1.
        race: ``"White"`` is encoded as 1, any other value as 0.
        education: Years of education, an integer in [0, 30]. An integral
            float such as ``16.0`` is accepted and converted.
        age: Age at entry; converted with ``int()``.

    Returns:
        A 6-tuple of strings: binary label, binary confidence, multiclass
        label, multiclass confidence, MMSE score, MMSE standard deviation.
        On failure the first element carries the message and the remaining
        five are empty strings.
    """
    def failure(message):
        # Every error path fills the first output box and blanks the others.
        return message, "", "", "", "", ""

    def classification_lines(result, failure_label):
        # Shared formatting for the binary and multiclass results.
        if result and 'error' not in result.predictions:
            return (
                f"Prediction: {result.predictions.get('label', 'N/A')}",
                f"Confidence: {result.confidence:.2%}",
            )
        detail = str(result.predictions.get('error', '')) if result else "Unknown error."
        return failure_label, detail

    if not FRAMEWORK_INITIALIZED or framework is None:
        return failure("Error: Framework not initialized. Check logs.")

    if audio_file is None:
        return failure("Please upload an audio file.")

    # Temporary files are created with delete=False so other libraries can
    # reopen them by name; that makes this function responsible for removal.
    scratch_files = []
    try:
        audio_path = audio_file

        if isinstance(audio_path, str):
            if not Path(audio_path).is_file():
                return failure("Invalid audio file path.")
        else:
            # Not a path: treat it as a file-like object and spool it to disk.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
                scratch_files.append(temp_file.name)
                temp_file.write(audio_file.read())
            audio_path = temp_file.name

        # Case-insensitive so that e.g. 'RECORDING.WAV' is accepted too.
        if not audio_path.lower().endswith('.wav'):
            return failure("Please upload a valid .wav audio file.")

        # A numeric widget may deliver 16.0 instead of 16; accept integral floats.
        if isinstance(education, float) and education.is_integer():
            education = int(education)
        if not isinstance(education, int) or not (0 <= education <= 30):
            return failure("Education must be an integer between 0 and 30.")

        # Resample to 16000 Hz when the file uses another rate.
        try:
            audio_data, sr = librosa.load(audio_path, sr=None)
            if sr != 16000:
                audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
                with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
                    scratch_files.append(temp_file.name)
                # Write after the handle is closed so the file is not open twice.
                sf.write(temp_file.name, audio_data, 16000)
                audio_path = temp_file.name
        except Exception as e:
            return failure(f"Error processing audio file: {e}")

        patient_demographics = {
            'sex': 0 if sex == "Female" else 1,
            'race': 1 if race == "White" else 0,  # Assuming 1 for White, 0 for others as an example
            'educ': int(education),
            'entryage': int(age),
        }
        patient = PatientData(
            patient_id="gradio_user",
            audio_path=audio_path,
            demographics=patient_demographics,
        )

        results = framework.predict(patient)

        binary_label, binary_confidence = classification_lines(
            results.get('binary'), "Binary analysis failed.")
        multiclass_label, multiclass_confidence = classification_lines(
            results.get('multiclass'), "Multiclass analysis failed.")

        regression_res = results.get('regression')
        if regression_res and 'error' not in regression_res.predictions:
            mmse_score = f"Predicted MMSE Score: {regression_res.predictions.get('mmse_score', 0):.2f}"
            mmse_std = f"Standard Deviation: ±{regression_res.predictions.get('std', 0):.2f}"
        else:
            mmse_score = "MMSE prediction failed."
            mmse_std = str(regression_res.predictions.get('error', '')) if regression_res else "Unknown error."

        return binary_label, binary_confidence, multiclass_label, multiclass_confidence, mmse_score, mmse_std

    except Exception as e:
        print(f"An error occurred during prediction: {e}")
        return failure(f"An error occurred: {e}")
    finally:
        for scratch in scratch_files:
            try:
                os.remove(scratch)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not mask the result.
                pass
|
139 |
+
|
140 |
+
|
141 |
+
# --- Gradio Interface Definition ---
|
142 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header and short description of the three analyses.
    gr.Markdown("# 🧠 Emotion and Cognitive Health Analysis from Speech")
    intro_text = (
        "Upload a patient's audio recording and provide their demographic information to get an analysis. "
        "This tool provides predictions for Alzheimer's Disease (AD) vs. Healthy Control (HC), a multiclass "
        "prediction (HC/MCI/AD), and an estimated MMSE score."
    )
    gr.Markdown(intro_text)

    with gr.Row():
        # Left column: the recording and the demographic covariates.
        with gr.Column(scale=1):
            gr.Markdown("### Inputs")
            audio_input = gr.Audio(type="filepath", label="Upload Patient Audio (.wav)")

            sex_input = gr.Radio(["Female", "Male"], label="Sex")
            race_input = gr.Radio(["White", "Other"], label="Race")  # Adjust as needed
            education_input = gr.Slider(minimum=0, maximum=30, step=1, value=16, label="Years of Education")
            age_input = gr.Slider(minimum=40, maximum=100, step=1, value=65, label="Age at Entry")

            analyze_btn = gr.Button("Analyze", variant="primary")

        # Right column: one group of two text boxes per prediction task.
        with gr.Column(scale=2):
            gr.Markdown("### Analysis Results")

            with gr.Group():
                gr.Label("Binary Classification (AD vs. HC)")
                binary_output_label = gr.Textbox(label="Result")
                binary_output_confidence = gr.Textbox(label="Confidence")

            with gr.Group():
                gr.Label("Multiclass Classification (HC vs. MCI vs. AD)")
                multiclass_output_label = gr.Textbox(label="Result")
                multiclass_output_confidence = gr.Textbox(label="Confidence")

            with gr.Group():
                gr.Label("MMSE Score Regression")
                regression_output_score = gr.Textbox(label="MMSE Score")
                regression_output_std = gr.Textbox(label="Standard Deviation")

    # The button and the example row feed the same components, in the
    # positional order analyze_emotion takes and returns them.
    analysis_inputs = [audio_input, sex_input, race_input, education_input, age_input]
    analysis_outputs = [
        binary_output_label,
        binary_output_confidence,
        multiclass_output_label,
        multiclass_output_confidence,
        regression_output_score,
        regression_output_std,
    ]

    analyze_btn.click(fn=analyze_emotion, inputs=analysis_inputs, outputs=analysis_outputs)

    gr.Markdown("---")
    gr.Markdown(
        "**Disclaimer:** This tool is for research purposes only and is not a substitute for professional medical advice, diagnosis, or treatment."
    )
    gr.Examples(
        examples=[
            ["./example/example.wav", "Female", "White", 16, 58],
        ],
        inputs=analysis_inputs,
        fn=analyze_emotion,
        outputs=analysis_outputs,
        cache_examples=True,
    )

if __name__ == "__main__":
    # Local run: install the dependencies, keep the 'model_weights' and
    # 'example' folders next to this script, then execute `python app.py`.
    demo.launch()
|
emotion_analysis_framework.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Emotion Analysis Framework
|
3 |
+
A framework for analyzing emotions from patient audio recordings using wav2vec2 models.
|
4 |
+
|
5 |
+
Author: Marek Sviderski
|
6 |
+
|
7 |
+
This framework supports three main tasks:
|
8 |
+
- Binary classification: Distinguishing between Alzheimer's Disease (AD) and Healthy Control (HC)
|
9 |
+
- Multiclass classification: Classifying into HC, Mild Cognitive Impairment (MCI), and AD
|
10 |
+
- Regression: Predicting the Mini-Mental State Examination (MMSE) score
|
11 |
+
|
12 |
+
This code is designed to be modular and extensible, allowing for easy integration of new models and strategies.
|
13 |
+
|
14 |
+
It uses dataclasses for structured data representation and provides methods for feature extraction, model loading, and predictions.
|
15 |
+
"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
from pathlib import Path
|
19 |
+
from typing import Dict, List, Optional, Union, Any
|
20 |
+
from dataclasses import dataclass, field
|
21 |
+
import numpy as np
|
22 |
+
import pandas as pd
|
23 |
+
import torch
|
24 |
+
import joblib
|
25 |
+
import warnings
|
26 |
+
|
27 |
+
from utils.config import ProjectConfig
|
28 |
+
from utils.config_types import ChunkLength
|
29 |
+
from models.inference_wav2vec import Wav2VecInference
|
30 |
+
from preprocessing.flattening_statistical import StatisticalFlattening
|
31 |
+
from preprocessing.flattening_categorical import CategoricalFlattening
|
32 |
+
from preprocessing.flattening_minirocket import MiniRocketFlattening
|
33 |
+
|
34 |
+
warnings.filterwarnings('ignore', category=UserWarning, module='xgboost')
|
35 |
+
warnings.filterwarnings('ignore', message='Some weights of the model checkpoint')
|
36 |
+
warnings.filterwarnings('ignore', message='Some weights of Wav2Vec2ForSequenceClassification')
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass
class PatientData:
    """One patient's recording together with optional demographic values."""

    # Identifier supplied by the caller for this patient.
    patient_id: str
    # Path of the audio recording to analyse.
    audio_path: str
    # Demographic values keyed by name; every instance gets its own dict.
    demographics: Dict[str, Any] = field(default_factory=dict)
|
46 |
+
|
47 |
+
|
48 |
+
@dataclass
class PredictionResult:
    """Outcome of one prediction task for one patient."""

    # Identifier of the patient the prediction belongs to.
    patient_id: str
    # Task name; summary() recognises 'binary', 'multiclass' and 'regression'.
    task: str
    # Task-specific outputs, e.g. 'label' or 'mmse_score' / 'std'.
    predictions: Dict[str, Any]
    # Optional per-class probabilities.
    probabilities: Optional[Dict[str, Any]] = None
    # Optional confidence value; may legitimately be absent.
    confidence: Optional[float] = None
    # Free-form extra information about how the prediction was produced.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def summary(self) -> str:
        """Return a one-line, human-readable description of the prediction.

        Falls back to ``str(self.predictions)`` for unknown tasks and for
        results that lack the keys the task normally provides (for example
        an error payload), instead of raising ``KeyError``. A missing
        confidence is shown as ``N/A`` rather than raising ``TypeError``.
        """
        preds = self.predictions

        if self.task in ('binary', 'multiclass') and 'label' in preds:
            shown = 'N/A' if self.confidence is None else f"{self.confidence:.2f}"
            heading = 'Binary' if self.task == 'binary' else 'Multiclass'
            return f"{heading}: {preds['label']} (confidence: {shown})"

        if self.task == 'regression' and 'mmse_score' in preds and 'std' in preds:
            return f"MMSE Score: {preds['mmse_score']:.1f} ± {preds['std']:.1f}"

        return str(preds)
|
68 |
+
|
69 |
+
|
70 |
+
class EmotionAnalysisFramework:
|
71 |
+
"""
|
72 |
+
End-to-end framework for emotion analysis from patient recordings.
|
73 |
+
|
74 |
+
This framework provides three prediction tasks:
|
75 |
+
- Binary classification: AD vs Healthy Control (HC)
|
76 |
+
- Multiclass classification: HC vs MCI vs AD
|
77 |
+
- Regression: MMSE score prediction
|
78 |
+
|
79 |
+
Args:
|
80 |
+
config_path: Optional path to custom configuration
|
81 |
+
model_dir: Path to directory containing model weights
|
82 |
+
verbose: Whether to print detailed progress information
|
83 |
+
"""
|
84 |
+
|
85 |
+
def __init__(self, config_path: Optional[str] = None,
             model_dir: Optional[str] = None,
             verbose: bool = False):
    """Set up configuration, flattening strategies and the fold models.

    Args:
        config_path: Passed to ``ProjectConfig`` when truthy; otherwise the
            default ``ProjectConfig()`` is used.
        model_dir: Directory holding the model weights; when falsy,
            ``_load_models`` falls back to a package-relative default.
        verbose: Enables the progress messages printed by ``_log``.
    """
    if config_path:
        self.config = ProjectConfig(config_path)
    else:
        self.config = ProjectConfig()

    self.model_dir = model_dir
    self.verbose = verbose

    # The wav2vec model is created lazily on first feature extraction.
    self.wav2vec_model = None
    use_cuda = torch.cuda.is_available()
    self.device = torch.device('cuda' if use_cuda else 'cpu')

    # One empty registry per task; filled by _load_models below.
    self.models = {task: {} for task in ('binary', 'multiclass', 'regression')}
    self.strategies = {}

    self._initialize_strategies()
    self._load_models()
|
101 |
+
|
102 |
+
def _log(self, message: str):
|
103 |
+
"""Print message only if verbose mode is enabled"""
|
104 |
+
if self.verbose:
|
105 |
+
print(message)
|
106 |
+
|
107 |
+
def _initialize_strategies(self):
    """Create one flattener per strategy name and store them on ``self.strategies``."""
    self.strategies = dict(
        statistical=StatisticalFlattening(),
        categorical=CategoricalFlattening(),
        minirocket=MiniRocketFlattening(),
    )
|
114 |
+
|
115 |
+
def _load_models(self):
|
116 |
+
"""Load all trained models"""
|
117 |
+
if not self.model_dir:
|
118 |
+
# Try to find models in package directory
|
119 |
+
package_dir = os.path.dirname(os.path.abspath(__file__))
|
120 |
+
self.model_dir = os.path.join(package_dir, "model_weights")
|
121 |
+
|
122 |
+
if not os.path.exists(self.model_dir):
|
123 |
+
raise ValueError(f"Model directory not found: {self.model_dir}")
|
124 |
+
|
125 |
+
# Load models for each task
|
126 |
+
self._load_task_models('binary', os.path.join(self.model_dir, "binary"))
|
127 |
+
self._load_task_models('multiclass', os.path.join(self.model_dir, "multiclass"))
|
128 |
+
self._load_task_models('regression', os.path.join(self.model_dir, "regression"))
|
129 |
+
|
130 |
+
# Verify models were loaded
|
131 |
+
for task in ['binary', 'multiclass', 'regression']:
|
132 |
+
if not self.models[task]:
|
133 |
+
raise ValueError(f"No {task} models found in {self.model_dir}")
|
134 |
+
|
135 |
+
def _load_task_models(self, task: str, path: str):
|
136 |
+
"""Load models for a specific task"""
|
137 |
+
if not os.path.exists(path):
|
138 |
+
self._log(f"Warning: {task} model path not found: {path}")
|
139 |
+
return
|
140 |
+
|
141 |
+
model_type = 'simple' if task in ['binary', 'regression'] else 'fusion'
|
142 |
+
self.models[task][model_type] = {}
|
143 |
+
|
144 |
+
model_files = [f for f in os.listdir(path)
|
145 |
+
if f.startswith('model_fold_') and f.endswith('.joblib')]
|
146 |
+
|
147 |
+
for file in model_files:
|
148 |
+
fold_num = file.split('_')[-1].replace('.joblib', '')
|
149 |
+
model_path = os.path.join(path, file)
|
150 |
+
try:
|
151 |
+
self.models[task][model_type][fold_num] = joblib.load(model_path)
|
152 |
+
self._log(f"Loaded {task} {model_type} model fold {fold_num}")
|
153 |
+
except Exception as e:
|
154 |
+
self._log(f"Error loading model {model_path}: {e}")
|
155 |
+
|
156 |
+
def _extract_wav2vec_features(self, audio_path: str, chunk_length: ChunkLength) -> pd.DataFrame:
    """Run the wav2vec emotion model over *audio_path* and tabulate the output.

    The model is created and loaded on first use and then cached on
    ``self.wav2vec_model``. Segment and overlap durations come from
    ``self.config.get_chunk_params(chunk_length)``.

    Returns:
        A DataFrame with one row per analysed segment: 'filename' (the
        stem of *audio_path*), 'start', 'end', plus one column per key of
        the segment's emotions mapping.
    """
    if self.wav2vec_model is None:
        # Cache the model only after load_model() succeeded; otherwise a
        # failed load would leave an unloaded model that later calls reuse.
        model = Wav2VecInference(self.config, verbose=self.verbose)
        model.load_model()
        self.wav2vec_model = model

    chunk_config = self.config.get_chunk_params(chunk_length)

    emotions_over_time = self.wav2vec_model.analyze_emotions_over_time(
        audio_path,
        segment_duration=chunk_config.segment_duration,
        overlap_duration=chunk_config.overlap_duration,
    )

    stem = str(Path(audio_path).stem)
    return pd.DataFrame([
        {'filename': stem, 'start': start, 'end': end, **emotions}
        for start, end, emotions in emotions_over_time
    ])
|
180 |
+
|
181 |
+
def _prepare_features(self, patient_data: PatientData, task: str) -> Dict[str, pd.DataFrame]:
    """Build the feature frames that the models of *task* consume.

    Per task the returned dict contains:
        * 'binary': 'simple' - statistical flattening of 3.5 s chunks,
          with demographics.
        * 'multiclass': 'chunk1' / 'chunk2' - categorical flattening of
          1.5 s and 4.5 s chunks, both with demographics.
        * 'regression': 'chunk_1_5' / 'chunk_3_5' - minirocket flattening
          of 1.5 s and 3.5 s chunks; 'chunk_4_5_demo' - categorical
          flattening of 4.5 s chunks with demographics.

    Any other task name yields an empty dict.

    Demographics are added as constant columns, broadcast to every row.
    Earlier revisions joined them with ``pd.concat(axis=1)``, which aligns
    on the index and therefore produced NaN-padded rows whenever the
    flattened frame was not indexed 0..n-1.
    """
    audio_path = patient_data.audio_path
    demographics = patient_data.demographics

    def flatten(strategy: str, chunk_length) -> pd.DataFrame:
        # Extract per-segment emotions, then collapse them with the strategy.
        segments = self._extract_wav2vec_features(audio_path, chunk_length)
        return self.strategies[strategy].flatten_dataframe(segments)

    def with_demographics(frame: pd.DataFrame) -> pd.DataFrame:
        # Append each demographic value as a column, independent of the index.
        if not demographics:
            return frame
        enriched = frame.copy()
        for name, value in demographics.items():
            enriched[name] = value
        return enriched

    prepared_data = {}

    if task == 'binary':
        prepared_data['simple'] = with_demographics(
            flatten('statistical', ChunkLength.LENGTH_3_5))

    elif task == 'multiclass':
        prepared_data['chunk1'] = with_demographics(
            flatten('categorical', ChunkLength.LENGTH_1_5))
        prepared_data['chunk2'] = with_demographics(
            flatten('categorical', ChunkLength.LENGTH_4_5))

    elif task == 'regression':
        # Only the 4.5 s categorical frame carries demographics here.
        prepared_data['chunk_1_5'] = flatten('minirocket', ChunkLength.LENGTH_1_5)
        prepared_data['chunk_3_5'] = flatten('minirocket', ChunkLength.LENGTH_3_5)
        prepared_data['chunk_4_5_demo'] = with_demographics(
            flatten('categorical', ChunkLength.LENGTH_4_5))

    return prepared_data
|
235 |
+
|
236 |
+
def _predict_binary(self, features: Dict[str, pd.DataFrame]) -> PredictionResult:
    """Combine the per-fold binary classifiers into one AD/HC prediction.

    Each fold predicts on ``features['simple']``, restricted to the
    columns named in the model's ``feature_names_in_`` when it has that
    attribute. The final class is the majority vote of the folds; an exact
    tie is resolved by the higher mean probability instead of silently
    defaulting to HC. Folds that raise are logged and skipped.

    Returns:
        A ``PredictionResult`` whose probabilities are the fold-averaged
        ``predict_proba`` columns (column 0 reported as 'HC', column 1 as
        'AD') and whose confidence is the larger of the two.

    Raises:
        ValueError: If no binary model is loaded, or if every fold failed
            (chained to the last fold error).
    """
    model_type = 'simple'

    fold_models = self.models['binary'].get(model_type)
    if not fold_models:
        raise ValueError("No binary models loaded")

    frame = features[model_type]

    # Get the patient ID from features
    patient_id = 'unknown'
    if 'filename' in frame.columns:
        patient_id = frame['filename'].iloc[0]

    fold_predictions = []
    fold_probabilities = []
    last_error = None

    for fold_num, model in fold_models.items():
        try:
            X = frame
            if hasattr(model, 'feature_names_in_'):
                # Only use features that the model was trained on
                X = X[[name for name in model.feature_names_in_ if name in X.columns]]

            pred = model.predict(X)
            pred_proba = model.predict_proba(X)

            fold_predictions.append(pred[0])
            fold_probabilities.append(pred_proba[0])
        except Exception as e:
            last_error = e
            self._log(f"Error in fold {fold_num}: {e}")
            continue

    if not fold_predictions:
        raise ValueError("No successful predictions from any fold") from last_error

    mean_probabilities = np.mean(fold_probabilities, axis=0)

    # Majority vote over 0/1 fold predictions. np.round would send an
    # exact 0.5 to 0, so ties are decided by the averaged probabilities.
    vote_share = float(np.mean(fold_predictions))
    if vote_share == 0.5:
        final_prediction = int(np.argmax(mean_probabilities))
    else:
        final_prediction = int(vote_share > 0.5)

    return PredictionResult(
        patient_id=patient_id,
        task='binary',
        predictions={'class': final_prediction, 'label': 'AD' if final_prediction else 'HC'},
        probabilities={'HC': float(mean_probabilities[0]), 'AD': float(mean_probabilities[1])},
        confidence=float(np.max(mean_probabilities)),
        metadata={'model_type': model_type, 'num_folds': len(fold_predictions)}
    )
|
286 |
+
|
287 |
+
def _predict_multiclass(self, features: Dict[str, pd.DataFrame]) -> PredictionResult:
    """Combine the per-fold fusion models into one HC/MCI/AD prediction.

    Every fold holds two classifiers ('chunk1' and 'chunk2'). Their
    ``predict_proba`` outputs are mixed with the fold's 'weights'
    (default 0.5 / 0.5) and the fold votes for the arg-max class. The
    final class is the majority vote over folds, with ties resolved by
    the higher mean fused probability.

    The previous aggregation rounded the mean of the class indices, which
    can return a class no fold voted for (HC, HC, AD, AD, AD -> MCI).

    Returns:
        A ``PredictionResult`` with the fold-averaged fused probabilities
        keyed by 'HC', 'MCI', 'AD' and the largest of them as confidence.

    Raises:
        ValueError: If no fusion model is loaded, or if every fold failed
            (chained to the last fold error).
    """
    fusion_folds = self.models['multiclass'].get('fusion')
    if not fusion_folds:
        raise ValueError("No multiclass fusion model loaded")

    class_labels = ['HC', 'MCI', 'AD']
    fold_predictions = []
    fold_probabilities = []
    last_error = None

    # Get patient ID
    patient_id = 'unknown'
    if 'filename' in features['chunk1'].columns:
        patient_id = features['chunk1']['filename'].iloc[0]

    for fold_num, model_pack in fusion_folds.items():
        try:
            model1 = model_pack['chunk1']
            model2 = model_pack['chunk2']

            X_chunk1 = features['chunk1']
            X_chunk2 = features['chunk2']

            # Restrict each frame to the columns its model was trained on.
            if hasattr(model1, 'feature_names_in_'):
                X_chunk1 = X_chunk1[[f for f in model1.feature_names_in_ if f in X_chunk1.columns]]
            if hasattr(model2, 'feature_names_in_'):
                X_chunk2 = X_chunk2[[f for f in model2.feature_names_in_ if f in X_chunk2.columns]]

            pred_proba_1 = model1.predict_proba(X_chunk1)
            pred_proba_2 = model2.predict_proba(X_chunk2)

            # Apply fusion weights
            weights = model_pack.get('weights', [0.5, 0.5])
            fusion_proba = weights[0] * pred_proba_1 + weights[1] * pred_proba_2

            pred = np.argmax(fusion_proba, axis=1)

            fold_predictions.append(pred[0])
            fold_probabilities.append(fusion_proba[0])
        except Exception as e:
            last_error = e
            self._log(f"Error in multiclass fold {fold_num}: {e}")
            continue

    if not fold_predictions:
        raise ValueError("No successful multiclass predictions from any fold") from last_error

    mean_probabilities = np.mean(fold_probabilities, axis=0)

    # Majority vote; among classes with equally many votes pick the one
    # with the highest averaged probability.
    votes = np.bincount(np.asarray(fold_predictions, dtype=int), minlength=len(class_labels))
    leaders = np.flatnonzero(votes == votes.max())
    final_prediction = int(leaders[np.argmax(mean_probabilities[leaders])])

    prob_dict = {label: float(prob) for label, prob in zip(class_labels, mean_probabilities)}

    return PredictionResult(
        patient_id=patient_id,
        task='multiclass',
        predictions={'class': final_prediction, 'label': class_labels[final_prediction]},
        probabilities=prob_dict,
        confidence=float(np.max(mean_probabilities)),
        metadata={'num_folds': len(fold_predictions)}
    )
|
352 |
+
|
353 |
+
def _predict_regression(self, features: Dict[str, pd.DataFrame]) -> PredictionResult:
    """Predict the MMSE score as the mean of the per-fold fused regressors.

    Each fold pack provides three models under 'models' and three fusion
    weights under 'weights'. Every model predicts on its own feature
    frame (restricted to ``feature_names_in_`` when present) and the three
    outputs are combined as a weighted sum. Folds that raise are logged
    and skipped.

    Returns:
        A ``PredictionResult`` with the mean fold prediction as
        'mmse_score', the standard deviation across folds as 'std', and
        ``1 / (1 + std)`` as confidence.

    Raises:
        ValueError: If no regression model is loaded or every fold failed.

    NOTE(review): another ``def _predict_regression`` with the same
    signature follows this one in the file; if both sit in the same class
    body, the later definition is the one that ends up bound - confirm
    and remove one copy.
    """
    model_type = 'simple'

    if model_type not in self.models['regression'] or not self.models['regression'][model_type]:
        raise ValueError("No regression models loaded")

    # Get patient ID
    patient_id = 'unknown'
    if 'filename' in features['chunk_1_5'].columns:
        patient_id = features['chunk_1_5']['filename'].iloc[0]

    # (key of the model inside the fold pack, key of its feature frame),
    # listed in the order of the fusion weights.
    branches = (
        ("ridge_1_5_minirocket", 'chunk_1_5'),
        ("ridge_3_5_minirocket", 'chunk_3_5'),
        ("ridge_4_5_categorical", 'chunk_4_5_demo'),
    )

    fold_predictions = []
    for fold_num, model_pack in self.models['regression'][model_type].items():
        try:
            models = model_pack['models']
            weights = model_pack['weights']

            outputs = []
            for model_key, feature_key in branches:
                model = models[model_key]
                X = features[feature_key]
                if hasattr(model, 'feature_names_in_'):
                    X = X[[f for f in model.feature_names_in_ if f in X.columns]]
                outputs.append(model.predict(X))

            final_pred = weights[0] * outputs[0] + weights[1] * outputs[1] + weights[2] * outputs[2]
            fold_predictions.append(final_pred[0])

        except Exception as e:
            self._log(f"Error in regression fold {fold_num}: {e}")
            continue

    if not fold_predictions:
        raise ValueError("No successful regression predictions from any fold")

    # Aggregate predictions
    final_prediction = float(np.mean(fold_predictions))
    std_prediction = float(np.std(fold_predictions))

    return PredictionResult(
        patient_id=patient_id,
        task='regression',
        predictions={'mmse_score': final_prediction, 'std': std_prediction},
        confidence=1.0 / (1.0 + std_prediction),
        metadata={'model_type': model_type, 'num_folds': len(fold_predictions)}
    )
|
416 |
+
|
417 |
+
def _predict_regression(self, features: Dict[str, pd.DataFrame]) -> PredictionResult:
    """Predict the MMSE score with the 'simple' weighted ridge fusion.

    For every loaded fold, three ridge models are evaluated on their own
    feature table, combined with the fold's fusion weights, and the first
    element of the fused prediction is kept. Folds that raise are logged and
    skipped.

    Args:
        features: Feature tables keyed by 'chunk_1_5', 'chunk_3_5' and
            'chunk_4_5_demo'.

    Returns:
        PredictionResult holding the mean fold prediction ('mmse_score'),
        the spread across folds ('std') and a confidence of 1 / (1 + std).

    Raises:
        ValueError: If no 'simple' regression models are loaded, or if every
            fold failed.
    """
    model_type = 'simple'

    regression_models = self.models['regression']
    if not (model_type in regression_models and regression_models[model_type]):
        raise ValueError("No regression models loaded")

    # The patient identifier is read from the first chunk table when present.
    first_chunk = features['chunk_1_5']
    if 'filename' in first_chunk.columns:
        patient_id = first_chunk['filename'].iloc[0]
    else:
        patient_id = 'unknown'

    # (model key, feature-table key) for each fusion member, in weight order.
    fusion_members = (
        ("ridge_1_5_minirocket", 'chunk_1_5'),
        ("ridge_3_5_minirocket", 'chunk_3_5'),
        ("ridge_4_5_categorical", 'chunk_4_5_demo'),
    )

    fold_predictions = []
    for fold_num, model_pack in regression_models[model_type].items():
        try:
            models = model_pack['models']
            weights = model_pack['weights']

            member_preds = []
            for model_key, feature_key in fusion_members:
                model = models[model_key]
                X = features[feature_key]
                if hasattr(model, 'feature_names_in_'):
                    # Keep only the columns the model was fitted on, in the
                    # model's own order.
                    wanted = [name for name in model.feature_names_in_ if name in X.columns]
                    X = X[wanted]
                member_preds.append(model.predict(X))

            fused = (weights[0] * member_preds[0]
                     + weights[1] * member_preds[1]
                     + weights[2] * member_preds[2])
            fold_predictions.append(fused[0])

        except Exception as e:
            # Best effort: one broken fold must not abort the whole ensemble.
            self._log(f"Error in regression fold {fold_num}: {e}")
            continue

    if not fold_predictions:
        raise ValueError("No successful regression predictions from any fold")

    # Aggregate across folds.
    final_prediction = float(np.mean(fold_predictions))
    std_prediction = float(np.std(fold_predictions))

    return PredictionResult(
        patient_id=patient_id,
        task='regression',
        predictions={'mmse_score': final_prediction, 'std': std_prediction},
        confidence=1.0 / (1.0 + std_prediction),
        metadata={'model_type': model_type, 'num_folds': len(fold_predictions)}
    )
|
480 |
+
|
481 |
+
def predict(self, patient_data: Union[PatientData, List[PatientData]],
            task: Optional[str] = None) -> Union[
    PredictionResult, List[PredictionResult], Dict[str, PredictionResult]]:
    """Run one task, or all tasks, for one patient or a list of patients.

    Args:
        patient_data: A single PatientData or a list of PatientData objects.
        task: 'binary', 'multiclass' or 'regression'; a falsy value runs all
            three.

    Returns:
        Per patient, either the PredictionResult of the requested task or a
        dict of results keyed by task name when all tasks are run. A single
        PatientData input returns that one entry; a list input returns a list.
        A task that fails is reported as a PredictionResult whose predictions
        contain an 'error' message rather than raising.
    """
    single_patient = isinstance(patient_data, PatientData)
    patients = [patient_data] if single_patient else patient_data

    tasks_to_run = [task] if task else ['binary', 'multiclass', 'regression']

    # Task name -> name of the method that produces its prediction.
    predictor_names = {
        'binary': '_predict_binary',
        'multiclass': '_predict_multiclass',
        'regression': '_predict_regression',
    }

    results = []
    for patient in patients:
        patient_results = {}

        for current_task in tasks_to_run:
            try:
                # Features are prepared before the task name is validated.
                features = self._prepare_features(patient, current_task)

                predictor_name = predictor_names.get(current_task)
                if predictor_name is None:
                    raise ValueError(f"Unknown task: {current_task}")
                result = getattr(self, predictor_name)(features)

                # The identifier always comes from the patient being processed.
                result.patient_id = patient.patient_id
                patient_results[current_task] = result

            except Exception as e:
                self._log(f"Error predicting {current_task} for patient {patient.patient_id}: {e}")
                patient_results[current_task] = PredictionResult(
                    patient_id=patient.patient_id,
                    task=current_task,
                    predictions={'error': str(e)},
                    metadata={'status': 'failed'}
                )

        # Single task -> bare result; all tasks -> dict keyed by task.
        results.append(patient_results[task] if task else patient_results)

    return results[0] if single_patient else results
|
example/example.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:54ff5dd4723f32d833b0631cccb64b1422cb1bdf629c1dc00398a7f65b196c6e
|
3 |
+
size 57602320
|
model_weights/binary/model_fold_0.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:edb9a91485fb5d7c7fe6fe9796b3acdf03dbef1842b97dcfec15e34145b816e1
|
3 |
+
size 4372336
|
model_weights/binary/model_fold_1.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ed1aea5b008ae5f4ec22ecb4123332a54f6fa9ac94777a6c0c17f4934740c3eb
|
3 |
+
size 4408720
|
model_weights/binary/model_fold_2.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a9cd344db067466560b22bbe6a0b3c919b1ab3f62af501aca734049d2242397b
|
3 |
+
size 4328032
|
model_weights/binary/model_fold_3.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8f77fff9a656acc76215308264e06bcefe66feed764973dc6f266c804dd4cb8f
|
3 |
+
size 4408992
|
model_weights/binary/model_fold_4.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ad60480b9c54c419f32d0fc2d2ee2146cf08c4930a9471394f09145e581825aa
|
3 |
+
size 4316144
|
model_weights/multiclass/model_fold_0.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:78d1f25bdf7c00f3269f9e8617bb00e8b3c1a40807809f9e4dcbffd09cc589ff
|
3 |
+
size 3025389
|
model_weights/multiclass/model_fold_1.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:18ebd09065c2064f4badb15b1613049a65bb1adff602344d38f21766166b7c29
|
3 |
+
size 3031085
|
model_weights/multiclass/model_fold_2.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9bc18f007aafa81961d9bb7983d64cffe74e97920179d5f5d3d5252855daa552
|
3 |
+
size 3036701
|
model_weights/multiclass/model_fold_3.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5e39448ac78a7bb0ac9033be9fb4110ea250575b0e2138b17f20740cf7977e31
|
3 |
+
size 3018797
|
model_weights/multiclass/model_fold_4.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ebd6ebea08c3a8d1163298f07411c30fcaf2b66bda47d04366e1efa3bb1a49a6
|
3 |
+
size 3020637
|
model_weights/regression/model_fold_0.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:adb800cd7a00422815621e359a2a268da5269598ad078ddb68c33094eea5d1bb
|
3 |
+
size 101857
|
model_weights/regression/model_fold_1.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1a93d25b3c8159d09b7d63cf95016fbe0272469c08a08bb85deacbde6573b7fb
|
3 |
+
size 101857
|
model_weights/regression/model_fold_2.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:07d48e469318ea8db8a2436208fa96f722706b66acf05766f23db5e5726e2877
|
3 |
+
size 101857
|
model_weights/regression/model_fold_3.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6c3431945f467ab8709980eb4fcd3190acba70a2b7c244405cb264b24c0c48f9
|
3 |
+
size 101857
|
model_weights/regression/model_fold_4.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:049cfb0e17685c959697fc49ab4bb077a549b7b45a6425f41a6b3c8c02349fe8
|
3 |
+
size 101857
|
models/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# models/__init__.py
# Re-exports the wav2vec inference wrapper together with the flattening
# strategies that live in the preprocessing package.
from .inference_wav2vec import Wav2VecInference

from preprocessing.flattening_base import BaseFlattening
from preprocessing.flattening_categorical import CategoricalFlattening
from preprocessing.flattening_minirocket import MiniRocketFlattening
from preprocessing.flattening_statistical import StatisticalFlattening

# One __all__ for the whole module: assigning it a second time would replace
# the first list and silently drop "Wav2VecInference" from star-imports.
__all__ = [
    "Wav2VecInference",
    "BaseFlattening",
    "CategoricalFlattening",
    "MiniRocketFlattening",
    "StatisticalFlattening",
]
|
models/inference_wav2vec.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import warnings
|
3 |
+
from typing import List
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
import torch
|
7 |
+
from pydub import AudioSegment
|
8 |
+
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
|
9 |
+
|
10 |
+
from utils.config_types import ChunkLength
|
11 |
+
|
12 |
+
# Suppress specific warnings
|
13 |
+
warnings.filterwarnings('ignore', message='Passing `gradient_checkpointing` to a config initialization')
|
14 |
+
warnings.filterwarnings('ignore', message='Some weights of the model checkpoint')
|
15 |
+
warnings.filterwarnings('ignore', message='Some weights of Wav2Vec2ForSequenceClassification')
|
16 |
+
|
17 |
+
|
18 |
+
class Wav2VecInference:
    """Emotion recognition on audio with a pretrained wav2vec2 classifier.

    The model and its feature extractor are created lazily by
    :meth:`load_model`; :meth:`analyze_emotions_over_time` slides a window over
    an audio file and returns per-window emotion probabilities.
    """
    name = "wav2vec"

    def __init__(self, config, verbose=False):
        """Store the configuration and resolve the raw/processed data paths.

        Args:
            config: Project configuration; must expose ``get_full_path``,
                ``paths['data']`` and ``settings.device``.
            verbose: When true, progress messages are printed.
        """
        self.config = config
        self.verbose = verbose
        self.input_directory = self.config.get_full_path(self.config.paths['data'].raw)
        self.output_path = self.config.get_full_path(self.config.paths['data'].processed)
        # Class index (as string) -> emotion label.
        self.id2label = {
            "0": "angry", "1": "calm", "2": "disgust", "3": "fearful",
            "4": "happy", "5": "neutral", "6": "sad", "7": "surprised"
        }
        self.model = None
        self.feature_extractor = None

    def get_emotion_labels(self) -> List[str]:
        """Return the emotion labels in class-index order."""
        return list(self.id2label.values())

    def load_model(self):
        """Load the wav2vec model and feature extractor if not loaded yet."""
        if self.model is None or self.feature_extractor is None:
            if self.verbose:
                print("Loading wav2vec2 emotion recognition model...")

            # Suppress warnings during model loading
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore")

                # Load the model without the gradient_checkpointing parameter
                self.model = AutoModelForAudioClassification.from_pretrained(
                    "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition",
                    ignore_mismatched_sizes=True  # This helps with the weights warning
                ).to(self.config.settings.device)

                self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
                    "facebook/wav2vec2-large-xlsr-53"
                )

            if self.verbose:
                print("Model loaded successfully.")

    def _predict_emotion_from_segment(self, segment):
        """Return ``{emotion label: probability}`` for one pydub segment."""
        # Wav2Vec2FeatureExtractor raises ValueError when the declared
        # sampling rate differs from the one it was configured with, so
        # resample the segment first when the rates disagree.
        target_rate = getattr(self.feature_extractor, 'sampling_rate', None)
        if target_rate and segment.frame_rate != target_rate:
            segment = segment.set_frame_rate(target_rate)

        # Samples are interleaved per channel; average the channels to mono.
        waveform = np.array(segment.get_array_of_samples()).astype(np.float32)
        if segment.channels > 1:
            waveform = waveform.reshape((-1, segment.channels)).mean(axis=1)
        waveform = waveform.reshape(-1)

        # Extract features
        inputs = self.feature_extractor(
            waveform,
            sampling_rate=segment.frame_rate,
            return_tensors="pt",
            padding=True
        )
        inputs = {k: v.to(self.config.settings.device) for k, v in inputs.items()}

        # Get predictions
        with torch.no_grad():
            logits = self.model(**inputs).logits
            probabilities = torch.softmax(logits, dim=1).cpu().squeeze().tolist()

        return {self.id2label[str(i)]: float(prob) for i, prob in enumerate(probabilities)}

    def analyze_emotions_over_time(self, audio_file: str, segment_duration: int,
                                   overlap_duration: int) -> list:
        """
        Analyze emotions in chunks over the duration of an audio file.

        Only full-length windows are analysed; a trailing remainder shorter
        than ``segment_duration`` is ignored.

        Args:
            audio_file: Path to the audio file
            segment_duration: Duration of each segment in milliseconds
            overlap_duration: Overlap between segments in milliseconds

        Returns:
            List of tuples (start_time, end_time, emotion_probabilities)

        Raises:
            ValueError: If ``segment_duration`` is not positive or the window
                would not advance (``overlap_duration >= segment_duration``).
        """
        step = segment_duration - overlap_duration
        if segment_duration <= 0 or step <= 0:
            # A non-positive step would make the window loop below spin forever.
            raise ValueError(
                f"segment_duration ({segment_duration}) must be positive and greater "
                f"than overlap_duration ({overlap_duration})"
            )

        sound = AudioSegment.from_file(audio_file)
        duration = len(sound)
        emotions_over_time = []

        # Number of full windows, only needed for progress messages.
        total_segments = 0
        if self.verbose:
            if duration >= segment_duration:
                total_segments = (duration - segment_duration) // step + 1
            print(f"Processing {total_segments} segments from audio file...")

        start = 0
        segment_count = 0
        while start + segment_duration <= duration:
            if segment_count == 0:
                # Make sure the model is available before the first segment.
                self.load_model()

            segment = sound[start:start + segment_duration]
            emotion_probabilities = self._predict_emotion_from_segment(segment)
            emotions_over_time.append((start, start + segment_duration, emotion_probabilities))
            start += step

            segment_count += 1
            if self.verbose and segment_count % 10 == 0:
                print(f"  Processed {segment_count}/{total_segments} segments...")

        if self.verbose:
            print(f"  Completed processing {segment_count} segments.")

        return emotions_over_time
|
preprocessing/flattening_base.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
|
5 |
+
class BaseFlattening(ABC):
    """Abstract base for strategies that flatten per-segment results.

    Subclasses implement :meth:`flatten_dataframe` and set ``name`` to
    identify the strategy.
    """

    # Strategy identifier; subclasses replace it with their own name.
    name = "base"

    @abstractmethod
    def flatten_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Turn the combined results DataFrame into a flattened feature table."""
        ...

    def get_name(self) -> str:
        """Return the strategy identifier."""
        return self.name

    def __str__(self) -> str:
        """Use the strategy identifier as the string form."""
        return self.name
|
preprocessing/flattening_categorical.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from sklearn.preprocessing import LabelEncoder
|
3 |
+
|
4 |
+
from preprocessing.flattening_base import BaseFlattening
|
5 |
+
|
6 |
+
|
7 |
+
class CategoricalFlattening(BaseFlattening):
    """Flatten per-segment emotion scores into dominant-emotion features."""

    def __init__(self):
        """Initialize the categorical flattening strategy."""
        self.name = "categorical"
        self.label_encoder = LabelEncoder()
        # Upper bound on the number of positions encoded individually.
        self.max_positions = 1000

    def flatten_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Build one row of categorical features per ``filename``.

        Each row holds the sequence length, the dominant emotion at each of the
        first ``max_positions`` time points (label-encoded, padded with
        "padding"), per-emotion counts and proportions, and the number and
        ratio of changes between consecutive dominant emotions.

        NOTE: when more than 10 files are present, ``max_positions`` is
        lowered to the 95th percentile of the sequence lengths and stays
        lowered on this instance. The label encoder is refit for every
        position column, so codes are assigned per column.

        Args:
            df: DataFrame containing all inference results

        Returns:
            DataFrame containing categorical features with encoded emotions
        """
        per_patient = df.groupby('filename')
        emotion_columns = [col for col in df.columns
                           if col not in ('start', 'end', 'filename')]

        # With enough files, cap the number of position columns at the 95th
        # percentile of the sequence lengths (never above the current cap).
        sequence_lengths = [len(group) for _, group in per_patient]
        if len(sequence_lengths) > 10:
            p95_length = int(pd.Series(sequence_lengths).quantile(0.95))
            self.max_positions = min(p95_length, self.max_positions)

        rows = []
        for filename, group in per_patient:
            # Dominant emotion (column label with the highest score) per row.
            dominant = group[emotion_columns].idxmax(axis=1)
            n_steps = len(dominant)

            row = {'filename': filename, 'sequence_length': n_steps}

            # Fixed-length position features; short sequences are padded.
            for pos in range(self.max_positions):
                row[f'emotion_pos_{pos + 1}'] = dominant.iloc[pos] if pos < n_steps else "padding"

            # Whole-sequence summary: count and proportion per emotion.
            for emotion in emotion_columns:
                occurrences = (dominant == emotion).sum()
                row[f'count_{emotion}'] = occurrences
                row[f'prop_{emotion}'] = occurrences / n_steps if n_steps > 0 else 0

            # Changes between consecutive dominant emotions.
            if n_steps > 1:
                labels = dominant.tolist()
                transitions = sum(1 for prev, nxt in zip(labels[:-1], labels[1:]) if prev != nxt)
                row['emotion_transitions'] = transitions
                row['emotion_transitions_ratio'] = transitions / (n_steps - 1)
            else:
                row['emotion_transitions'] = 0
                row['emotion_transitions_ratio'] = 0

            rows.append(row)

        result_df = pd.DataFrame(rows)

        # Label-encode every position column (encoder refit per column).
        position_cols = [col for col in result_df.columns if col.startswith('emotion_pos_')]
        for col in position_cols:
            result_df[col] = self.label_encoder.fit_transform(result_df[col].astype(str))

        return result_df

    def get_emotion_mapping(self) -> dict:
        """
        Get the mapping between encoded values and emotion categories.

        The mapping reflects the encoder's most recent fit.

        Returns:
            Dictionary mapping encoded values to emotion categories
        """
        codes = self.label_encoder.transform(self.label_encoder.classes_)
        return dict(zip(codes, self.label_encoder.classes_))
|
preprocessing/flattening_minirocket.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
from sklearn.preprocessing import StandardScaler
|
4 |
+
from sktime.transformations.panel.rocket import MiniRocket
|
5 |
+
|
6 |
+
from preprocessing.flattening_base import BaseFlattening
|
7 |
+
|
8 |
+
class MiniRocketFlattening(BaseFlattening):
    """Flatten per-segment emotion scores with MiniRocket features."""

    def __init__(self):
        """Initialize the MiniRocket flattening strategy."""
        self.name = "minirocket"
        self.minirocket = None
        self.emotion_columns = None
        self.scaler = StandardScaler()
        self.min_sequence_length = 9  # MiniRocket requires at least 9 timepoints
        self.features_per_emotion = None  # Will store number of features per emotion

    def _initialize_minirocket(self, data_3d):
        """
        Initialize and fit the MiniRocket transformer.

        Also records how many output features fall to each emotion column
        (total transformed features divided by the number of emotions).

        Args:
            data_3d: 3D array with shape (n_instances, n_columns, n_timepoints)
        """
        try:
            self.minirocket = MiniRocket(
                random_state=42,
                n_jobs=1,
                num_kernels=84,
                max_dilations_per_kernel=32
            )
            self.minirocket.fit(data_3d)
            # Calculate features per emotion after first transform
            sample_transform = self.minirocket.transform(data_3d)
            total_features = sample_transform.shape[1]
            self.features_per_emotion = total_features // len(self.emotion_columns)

        except Exception as e:
            print(f"Error initializing MiniRocket: {str(e)}")
            raise

    def _prepare_3d_array(self, group_data: pd.DataFrame) -> np.ndarray:
        """
        Convert one patient's rows to the 3D array MiniRocket expects.

        The sequence is zero-padded or truncated to exactly
        ``min_sequence_length`` time points, then each emotion column is
        standardised. If anything fails, the error is printed and an all-zero
        array of the expected shape is returned instead.

        Args:
            group_data: DataFrame containing one patient's data

        Returns:
            3D numpy array with shape (n_instances, n_columns, n_timepoints)
        """
        try:
            if self.emotion_columns is None:
                self.emotion_columns = [col for col in group_data.columns
                                        if col not in ['start', 'end', 'filename']]

            # Make sure we have data
            if group_data.empty:
                raise ValueError("Empty data array")

            data = group_data[self.emotion_columns].values
            if data.size == 0:
                raise ValueError("Empty data array after selecting columns")

            # Force the sequence to exactly min_sequence_length time points.
            current_length = data.shape[0]
            if current_length < self.min_sequence_length:
                pad_length = self.min_sequence_length - current_length
                data = np.pad(data, ((0, pad_length), (0, 0)), mode='constant')
            elif current_length > self.min_sequence_length:
                data = data[:self.min_sequence_length]

            # Standardise per column; the floor on the std avoids division by zero.
            data_mean = data.mean(axis=0)
            data_std = data.std(axis=0)
            data_std = np.where(data_std < 1e-8, 1e-8, data_std)
            data = ((data - data_mean) / data_std).astype(np.float32)

            # (timepoints, emotions) -> (1, emotions, timepoints)
            return np.ascontiguousarray(data.T[np.newaxis, :, :])

        except Exception as e:
            # No `logger` exists in this module; referencing one here raised
            # NameError and defeated the fallback below. Report with print,
            # like the other methods of this class.
            print(f"Error in _prepare_3d_array: {str(e)}")
            # Return a safe default array if processing fails
            n_columns = len(self.emotion_columns) if self.emotion_columns else 1
            return np.zeros((1, n_columns, self.min_sequence_length), dtype=np.float32)

    def flatten_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Implement MiniRocket flattening with emotion-prefixed features.

        MiniRocket is fitted on the first patient that initialises
        successfully, and the StandardScaler is refit on the rows produced by
        this call. Patients whose processing raises are reported and skipped.

        Args:
            df: DataFrame containing all inference results, with a
                ``filename`` column identifying the patient.

        Returns:
            DataFrame with one row per processed patient: ``filename`` plus
            scaled ``<emotion>_feature_<i>`` columns.
        """
        all_features = []
        first_batch = True

        for filename, group in df.groupby('filename'):
            try:
                data_3d = self._prepare_3d_array(group)

                # Initialize MiniRocket with first batch
                if first_batch:
                    self._initialize_minirocket(data_3d)
                    first_batch = False

                features_array = self.minirocket.transform(data_3d).to_numpy()
                if features_array.ndim > 1:
                    features_array = features_array[0]

                # Consecutive slices of the feature vector are attributed to
                # the emotion columns in order.
                features_dict = {'filename': filename}
                for emotion_idx, emotion in enumerate(self.emotion_columns):
                    start_idx = emotion_idx * self.features_per_emotion
                    end_idx = start_idx + self.features_per_emotion
                    features_dict.update({
                        f'{emotion}_feature_{i}': value
                        for i, value in enumerate(features_array[start_idx:end_idx])
                    })

                all_features.append(features_dict)

            except Exception as e:
                print(f"Error processing patient {filename}: {str(e)}")
                continue

        result_df = pd.DataFrame(all_features)

        # Scale the features if we have any
        feature_cols = [col for col in result_df.columns if 'feature_' in col]
        if feature_cols:
            scaled_features = self.scaler.fit_transform(result_df[feature_cols].values)
            for i, col in enumerate(feature_cols):
                result_df[col] = scaled_features[:, i]

        return result_df
|
preprocessing/flattening_statistical.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from preprocessing.flattening_base import BaseFlattening
|
3 |
+
|
4 |
+
|
5 |
+
class StatisticalFlattening(BaseFlattening):
    """Flatten per-segment values into summary statistics per file."""

    def __init__(self):
        super().__init__()
        self.name = "statistical"

    def flatten_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Implement statistical flattening for the combined DataFrame.

        For every column other than 'start', 'end' and 'filename', fifteen
        statistics are computed per file: mean, std, max, min, mode, skewness,
        kurtosis, median, q1, q3, iqr, range, variance, sem and cv (std / mean).

        Args:
            df: DataFrame containing all inference results

        Returns:
            DataFrame containing statistical features for each unique filename
        """
        results = []
        for filename, group in df.groupby('filename'):
            stats = {'filename': filename}

            for column in group.columns:
                if column in ('start', 'end', 'filename'):
                    continue

                series = group[column]
                mean = series.mean()
                std = series.std()
                maximum = series.max()
                minimum = series.min()
                q1 = series.quantile(0.25)
                q3 = series.quantile(0.75)
                # mode() drops NaN, so an all-NaN column yields an empty
                # result; report NaN instead of failing on .iloc[0].
                modes = series.mode()
                mode = modes.iloc[0] if not modes.empty else float('nan')

                stats[f'{column}_mean'] = mean
                stats[f'{column}_std'] = std
                stats[f'{column}_max'] = maximum
                stats[f'{column}_min'] = minimum
                stats[f'{column}_mode'] = mode
                stats[f'{column}_skewness'] = series.skew()
                stats[f'{column}_kurtosis'] = series.kurtosis()
                stats[f'{column}_median'] = series.median()
                stats[f'{column}_q1'] = q1
                stats[f'{column}_q3'] = q3
                stats[f'{column}_iqr'] = q3 - q1
                stats[f'{column}_range'] = maximum - minimum
                stats[f'{column}_variance'] = series.var()
                stats[f'{column}_sem'] = series.sem()
                # Coefficient of variation; not guarded against a zero mean.
                stats[f'{column}_cv'] = std / mean

            results.append(stats)

        return pd.DataFrame(results)
|
utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# utils/__init__.py
|
2 |
+
from utils.config import ProjectConfig
|
3 |
+
from utils.config_types import ChunkLength
|
4 |
+
|
5 |
+
__all__ = ["ProjectConfig", "ChunkLength"]
|
utils/config.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
from .config_types import (
|
4 |
+
MachinePaths, Labels, ModelPaths, Settings,
|
5 |
+
Data, ChunkLength, ChunkConfig
|
6 |
+
)
|
7 |
+
|
8 |
+
|
9 |
+
class ProjectConfig:
    """Bundle of machine paths, labelled path groups and runtime settings."""

    def __init__(self):
        """Create the default path and settings objects."""
        self.machine_paths = MachinePaths()
        # Path groups, looked up by name.
        self.paths = {
            'labels': Labels(),
            'models': ModelPaths(),
            'data': Data()
        }
        self.settings = Settings()

    def get_full_path(self, relative_path: str) -> str:
        """Join ``relative_path`` onto the current machine's base path."""
        return os.path.join(self.machine_paths.get_current_machine_path(), relative_path)

    def get_model_length_path(self, model_name: str, length: ChunkLength) -> str:
        """Return ``<base>/processed/<model_name>-<length>`` for a model/length pair."""
        machine_root = self.machine_paths.get_current_machine_path()
        return os.path.join(machine_root, 'processed', f"{model_name}-{length!s}")

    def get_chunk_params(self, length: ChunkLength) -> ChunkConfig:
        """Return the chunk configuration registered for ``length``.

        Raises:
            ValueError: If ``length`` is not among the configured chunk lengths.
        """
        chunks = self.paths['data'].chunks
        if length in chunks:
            return chunks[length]
        raise ValueError(f"Invalid length {length}. Must be one of {list(self.paths['data'].chunks.keys())}")

    def get_available_lengths(self) -> List[ChunkLength]:
        """Return the configured chunk lengths."""
        return list(self.paths['data'].chunks.keys())

    def get_model_path(self, name, length, task):
        """Placeholder: currently does nothing and returns ``None``."""
        return None
|
utils/config_types.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from enum import Enum
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Dict
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class ChunkLength(str, Enum):
    """Closed set of supported chunk lengths.

    Members are also ``str`` instances, so they compare equal to their
    plain string value.
    """

    LENGTH_1_5 = "1_5"
    LENGTH_3_5 = "3_5"
    LENGTH_4_5 = "4_5"

    def __str__(self) -> str:
        # Render as the bare value instead of the default ``ChunkLength.NAME``.
        return self.value
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
class ChunkConfig:
    """Parameters describing how audio is cut into chunks of one length."""

    # Which chunk length this configuration belongs to.
    length: ChunkLength
    # Duration of one segment, in milliseconds.
    segment_duration: int
    # Overlap duration, in milliseconds.
    overlap_duration: int
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
class MachinePaths:
    """Resolve the base data directory for the current environment."""

    # Fallback used when EMOTION_DATA_PATH is unset or empty.
    default: Path = Path('./data')

    def get_current_machine_path(self) -> Path:
        """Return ``Path($EMOTION_DATA_PATH)`` if set and non-empty, else ``default``."""
        override = os.environ.get('EMOTION_DATA_PATH')
        return Path(override) if override else self.default
|
39 |
+
|
40 |
+
|
41 |
+
@dataclass
class Labels:
    """Relative locations of the label CSV files."""

    # Demographic table.
    demographic: str = 'processed/labels/demographic.csv'
    # Multiclass labels.
    multi: str = 'processed/labels/multiclass.csv'
    # Regression labels.
    regression: str = 'processed/labels/regression.csv'
|
47 |
+
|
48 |
+
|
49 |
+
@dataclass
class ModelPaths:
    """Relative directories used for model storage."""

    # Root directory for models.
    base_dir: str = 'models/'
    # Checkpoints directory.
    checkpoints: str = 'models/checkpoints/'
    # Configs directory.
    configs: str = 'models/configs/'
|
55 |
+
|
56 |
+
|
57 |
+
@dataclass
class Settings:
    """Framework-wide runtime settings."""

    seed: int = 42
    batch_size: int = 16
    num_workers: int = -1
    # ``None`` means "pick automatically" in ``__post_init__``.
    device: str = None

    def __post_init__(self):
        """Fill in ``device`` when the caller did not choose one."""
        if self.device is not None:
            return
        # Use CUDA when torch reports it as available, otherwise the CPU.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
68 |
+
|
69 |
+
|
70 |
+
@dataclass
class Data:
    """Relative data directories plus the per-length chunk configurations."""

    raw: str = 'raw/'
    processed: str = 'processed/inference/'
    table: str = 'processed/tables/'
    splits: str = 'processed/splits/'
    results: str = 'results/'
    # ``None`` is replaced by the built-in table in ``__post_init__``.
    chunks: Dict[ChunkLength, ChunkConfig] = None

    def __post_init__(self):
        """Install the default chunk table when none was supplied."""
        if self.chunks is not None:
            return
        # (length, segment duration in ms); every entry uses a 500 ms overlap.
        defaults = (
            (ChunkLength.LENGTH_1_5, 1500),
            (ChunkLength.LENGTH_3_5, 3500),
            (ChunkLength.LENGTH_4_5, 4500),
        )
        self.chunks = {
            length: ChunkConfig(
                length=length,
                segment_duration=duration,
                overlap_duration=500,
            )
            for length, duration in defaults
        }
|
utils/logger.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
# Configure the root logger once, at import time, with a compact format.
# A more verbose alternative format would be:
#   '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format='%(levelname)s - %(message)s', level=logging.INFO)

# Logger named after this module.
logger = logging.getLogger(__name__)
|
utils/tabular_transformation.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
import pandas as pd
|
3 |
+
from sklearn.model_selection import StratifiedKFold, KFold
|
4 |
+
|
5 |
+
from utils.config_types import LearningTask
|
6 |
+
from utils.logger import logger
|
7 |
+
from preprocessing.flattening_base import BaseFlattening
|
8 |
+
|
9 |
+
|
10 |
+
class TabularTransformer:
    """Stateless helpers for labelling, flattening and splitting tabular data.

    Every method is static; the class only acts as a namespace.
    """

    @staticmethod
    def get_id(filename: str) -> int:
        """Extract the integer id embedded in ``filename``.

        The id is the text after the last ``_`` up to the first ``-`` that
        follows it, e.g. ``"dementia_12-3"`` -> ``12``.

        Raises:
            ValueError: If that text is not an integer literal.
        """
        tail = filename.rsplit('_', 1)[-1]
        return int(tail.split('-', 1)[0])

    @staticmethod
    def attach_demographic(df_data: pd.DataFrame, demographic_path: str) -> pd.DataFrame:
        """Join demographic columns onto ``df_data``.

        Rows are matched (inner join) on the id parsed from ``filename``
        against the ``id`` column of the demographic CSV. ``df_data`` itself
        is left untouched.

        Args:
            df_data: DataFrame with a ``filename`` column.
            demographic_path: Path to the demographic CSV file.

        Returns:
            A new DataFrame with the demographic columns attached and the
            temporary ``id`` join key removed.
        """
        df_demographic = pd.read_csv(demographic_path)

        # ``assign`` works on a copy, so the temporary join key never leaks
        # into the caller's frame.
        keyed = df_data.assign(id=df_data['filename'].apply(TabularTransformer.get_id))
        return pd.merge(keyed, df_demographic, on='id').drop(columns=['id'])

    @staticmethod
    def transform_to_binary(df: pd.DataFrame) -> pd.DataFrame:
        """Add a binary ``target`` column derived from the filename prefix.

        ``target`` is 1 when the text before the first ``_`` in ``filename``
        is ``'dementia'`` and 0 otherwise.

        Args:
            df: DataFrame with a ``filename`` column. It is modified in place.

        Returns:
            The same DataFrame object, now carrying ``target``.
        """
        df['target'] = df['filename'].apply(
            lambda name: int(name.split('_', 1)[0] == 'dementia')
        )
        return df

    @staticmethod
    def transform_to_multiclass(df: pd.DataFrame, path_to_label: str) -> pd.DataFrame:
        """Attach multiclass labels read from a CSV file.

        The CSV must provide ``ID`` and ``target_multi`` columns; rows are
        matched (inner join) on ``filename == ID``. ``df`` is left untouched.

        Args:
            df: DataFrame with a ``filename`` column.
            path_to_label: Path to the multiclass labels CSV file.

        Returns:
            A new DataFrame with a ``target`` column and no ``ID`` column.
        """
        df_labels = pd.read_csv(path_to_label)
        # Non-inplace rename: renaming a column-sliced frame in place can
        # trigger pandas' chained-assignment warning.
        df_labels = df_labels[['ID', 'target_multi']].rename(columns={'target_multi': 'target'})

        keyed = df.assign(ID=df['filename'])
        return pd.merge(keyed, df_labels, on='ID').drop(columns=['ID'])

    @staticmethod
    def transform_to_regression(df: pd.DataFrame, path_to_label: str) -> pd.DataFrame:
        """Attach regression labels read from a CSV file.

        The CSV's ``mms`` column becomes ``target``; rows are matched (inner
        join) on the id parsed from ``filename`` against the CSV's ``id``
        column. ``df`` is left untouched.

        Args:
            df: DataFrame with a ``filename`` column.
            path_to_label: Path to the regression labels CSV file.

        Returns:
            A new DataFrame with the label columns attached and the temporary
            ``id`` join key removed.
        """
        df_labels = pd.read_csv(path_to_label).rename(columns={'mms': 'target'})

        keyed = df.assign(id=df['filename'].apply(TabularTransformer.get_id))
        return pd.merge(keyed, df_labels, on='id').drop(columns=['id'])

    @staticmethod
    def process_and_transform(df: pd.DataFrame,
                              transform_func: callable,
                              label_path_func: Optional[callable]) -> pd.DataFrame:
        """Apply ``transform_func`` to a copy of ``df``.

        Args:
            df: Input DataFrame; only a copy is handed to ``transform_func``.
            transform_func: Called as ``transform_func(copy)`` or, when
                ``label_path_func`` is truthy, ``transform_func(copy, path)``.
            label_path_func: Optional zero-argument callable returning the
                label path passed as the second argument.

        Returns:
            Whatever ``transform_func`` returns.
        """
        working_copy = df.copy()
        if label_path_func:
            return transform_func(working_copy, label_path_func())
        return transform_func(working_copy)

    @staticmethod
    def get_flattened_data(strategy: BaseFlattening, file_path: str, demographic_path: Optional[str] = None) -> pd.DataFrame:
        """Read ``file_path``, flatten it with ``strategy`` and optionally add demographics.

        Args:
            strategy: Flattening strategy; its ``get_name`` and
                ``flatten_dataframe`` methods are used.
            file_path: CSV file holding the inference data.
            demographic_path: When truthy, demographic columns from this CSV
                are attached to the flattened result.

        Returns:
            The flattened DataFrame, with demographics attached if requested.
        """
        logger.info(f"Flattening data with strategy: {strategy.get_name()}")
        df_inference = pd.read_csv(file_path)
        df_flatten = strategy.flatten_dataframe(df_inference)

        if demographic_path:
            return TabularTransformer.attach_demographic(df_flatten, demographic_path)
        return df_flatten

    @staticmethod
    def get_memory_usage(df: pd.DataFrame) -> str:
        """Return the deep memory usage of ``df`` as a human-readable string.

        Values below 1024 are reported in bytes; larger ones in KB, MB or GB
        (powers of 1024) with two decimals.
        """
        size = df.memory_usage(deep=True).sum()
        if size < 1024:
            return f"{size} bytes"
        for exponent, unit in ((2, 'KB'), (3, 'MB')):
            if size < 1024 ** exponent:
                return f"{size / 1024 ** (exponent - 1):.2f} {unit}"
        return f"{size / 1024 ** 3:.2f} GB"

    @staticmethod
    def generate_folds(inference_data, folds, task, tasks):
        """Build cross-validation folds over the distinct filenames.

        Args:
            inference_data: DataFrame with a ``filename`` column.
            folds: Number of splits.
            task: ``LearningTask`` member selecting the splitter: stratified
                for BINARY/MULTICLASS, plain K-fold for REGRESSION.
            tasks: Mapping indexed by ``task``; item 0 of the entry is the
                label transform function and item 2 the optional label-path
                callable, both forwarded to ``process_and_transform``.

        Returns:
            DataFrame with columns ``fold``, ``train`` and ``test``; the last
            two hold lists of filenames.

        Raises:
            ValueError: If ``task`` is none of the supported learning tasks.
        """
        # One row per distinct filename, in groupby (sorted-key) order.
        # ``size()`` avoids materialising a list of every other column just
        # to throw it away.
        df_folds = inference_data.groupby('filename').size().reset_index()[['filename']]

        # Attach labels
        df_task = TabularTransformer.process_and_transform(
            df_folds,
            tasks[task][0],
            tasks[task][2],
        )
        target_column = df_task['target']
        file_column = df_task['filename']

        # NOTE(review): LearningTask is imported from utils.config_types;
        # confirm that module actually defines it.
        if task == LearningTask.BINARY or task == LearningTask.MULTICLASS:
            splitter = StratifiedKFold(n_splits=folds, shuffle=True, random_state=42)
        elif task == LearningTask.REGRESSION:
            splitter = KFold(n_splits=folds, shuffle=True, random_state=42)
        else:
            # Fail clearly instead of hitting ``None.split`` further down.
            raise ValueError(f"Unsupported learning task: {task!r}")

        folds_data = [
            {
                'fold': fold,
                'train': file_column.iloc[train_index].tolist(),
                'test': file_column.iloc[test_index].tolist(),
            }
            for fold, (train_index, test_index)
            in enumerate(splitter.split(file_column, target_column))
        ]

        return pd.DataFrame(folds_data, columns=['fold', 'train', 'test'])
|