sivdma commited on
Commit
61c258f
·
verified ·
1 Parent(s): 96bea8e

Upload 29 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ example/example.wav filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import os
4
+ import warnings
5
+ from pathlib import Path
6
+ import tempfile
7
+ import librosa
8
+ import soundfile as sf # Add this import
9
+ import numpy as np
10
+
11
+ # Suppress warnings for a cleaner interface
12
+ warnings.filterwarnings('ignore')
13
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
14
+ os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
15
+
16
+ # Import your framework's components
17
+ from emotion_analysis_framework import EmotionAnalysisFramework, PatientData
18
+
19
+ # --- Initialization ---
20
+ # Initialize the framework. It will look for the 'model_weights' directory.
21
+ # Make sure this directory is in your Hugging Face Space repository.
22
+ try:
23
+ print("Initializing Emotion Analysis Framework...")
24
+ # The model_dir should point to the directory where your models are stored.
25
+ # In a Hugging Face Space, this will be relative to the app.py file.
26
+ framework = EmotionAnalysisFramework(model_dir="./model_weights", verbose=False)
27
+ print("Framework initialized successfully!")
28
+ FRAMEWORK_INITIALIZED = True
29
+ except Exception as e:
30
+ print(f"FATAL: Error initializing framework: {e}")
31
+ print("Please ensure the 'model_weights' directory is present and contains the model files.")
32
+ FRAMEWORK_INITIALIZED = False
33
+ # Define a placeholder framework to avoid crashing the app
34
+ framework = None
35
+
36
+
37
+ # --- Prediction Function ---
38
+ def analyze_emotion(audio_file, sex, race, education, age):
39
+ """
40
+ This function is the core of the Gradio app. It takes user inputs,
41
+ runs the prediction, and returns the formatted results.
42
+ """
43
+ if not FRAMEWORK_INITIALIZED or framework is None:
44
+ return "Error: Framework not initialized. Check logs.", "", "", "", "", ""
45
+
46
+ if audio_file is None:
47
+ return "Please upload an audio file.", "", "", "", "", ""
48
+
49
+ try:
50
+ # Gradio provides the audio file as a temporary file path
51
+ audio_path = audio_file
52
+
53
+ # Take that audio file and make sure that the sr is 16000 Hz
54
+ if isinstance(audio_path, str):
55
+ # Ensure the audio file is a valid path
56
+ if not Path(audio_path).is_file():
57
+ return "Invalid audio file path.", "", "", "", "", ""
58
+ else:
59
+ # If the audio file is not a string, it might be a temporary file object
60
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
61
+ temp_file.write(audio_file.read())
62
+ audio_path = temp_file.name
63
+
64
+ # Ensure the audio file is in the correct format (e.g., .wav)
65
+ if not audio_path.endswith('.wav'):
66
+ return "Please upload a valid .wav audio file.", "", "", "", "", ""
67
+
68
+ # Validate demographic inputs
69
+ if not isinstance(education, int) or not (0 <= education <= 30):
70
+ return "Education must be an integer between 0 and 30.", "", "", "", "", ""
71
+
72
+ # Check the audio sr and if its not 16000 Hz, resample it
73
+ try:
74
+ audio_data, sr = librosa.load(audio_path, sr=None)
75
+ if sr != 16000:
76
+ # print(f"Resampling audio from {sr}Hz to 16000Hz")
77
+ audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
78
+
79
+ # Create a temporary file for the resampled audio
80
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
81
+ temp_path = temp_file.name
82
+ # Use soundfile to save the resampled audio
83
+ sf.write(temp_path, audio_data, 16000)
84
+ audio_path = temp_path
85
+ except Exception as e:
86
+ return f"Error processing audio file: {e}", "", "", "", "", ""
87
+
88
+ # Create a PatientData object with the inputs
89
+ patient_demographics = {
90
+ 'sex': 0 if sex == "Female" else 1,
91
+ 'race': 1 if race == "White" else 0, # Assuming 1 for White, 0 for others as an example
92
+ 'educ': int(education),
93
+ 'entryage': int(age)
94
+ }
95
+
96
+ patient = PatientData(
97
+ patient_id="gradio_user",
98
+ audio_path=audio_path,
99
+ demographics=patient_demographics
100
+ )
101
+
102
+ # Run prediction for all tasks
103
+ results = framework.predict(patient)
104
+
105
+ # --- Format the results for display ---
106
+ # Binary Classification Results
107
+ binary_res = results.get('binary')
108
+ if binary_res and 'error' not in binary_res.predictions:
109
+ binary_label = f"Prediction: {binary_res.predictions.get('label', 'N/A')}"
110
+ binary_confidence = f"Confidence: {binary_res.confidence:.2%}"
111
+ else:
112
+ binary_label = "Binary analysis failed."
113
+ binary_confidence = str(binary_res.predictions.get('error', '')) if binary_res else "Unknown error."
114
+
115
+ # Multiclass Classification Results
116
+ multiclass_res = results.get('multiclass')
117
+ if multiclass_res and 'error' not in multiclass_res.predictions:
118
+ multiclass_label = f"Prediction: {multiclass_res.predictions.get('label', 'N/A')}"
119
+ multiclass_confidence = f"Confidence: {multiclass_res.confidence:.2%}"
120
+ else:
121
+ multiclass_label = "Multiclass analysis failed."
122
+ multiclass_confidence = str(
123
+ multiclass_res.predictions.get('error', '')) if multiclass_res else "Unknown error."
124
+
125
+ # Regression Results
126
+ regression_res = results.get('regression')
127
+ if regression_res and 'error' not in regression_res.predictions:
128
+ mmse_score = f"Predicted MMSE Score: {regression_res.predictions.get('mmse_score', 0):.2f}"
129
+ mmse_std = f"Standard Deviation: ±{regression_res.predictions.get('std', 0):.2f}"
130
+ else:
131
+ mmse_score = "MMSE prediction failed."
132
+ mmse_std = str(regression_res.predictions.get('error', '')) if regression_res else "Unknown error."
133
+
134
+ return binary_label, binary_confidence, multiclass_label, multiclass_confidence, mmse_score, mmse_std
135
+
136
+ except Exception as e:
137
+ print(f"An error occurred during prediction: {e}")
138
+ return f"An error occurred: {e}", "", "", "", "", ""
139
+
140
+
141
+ # --- Gradio Interface Definition ---
142
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
143
+ gr.Markdown("# 🧠 Emotion and Cognitive Health Analysis from Speech")
144
+ gr.Markdown(
145
+ "Upload a patient's audio recording and provide their demographic information to get an analysis. "
146
+ "This tool provides predictions for Alzheimer's Disease (AD) vs. Healthy Control (HC), a multiclass "
147
+ "prediction (HC/MCI/AD), and an estimated MMSE score."
148
+ )
149
+
150
+ with gr.Row():
151
+ with gr.Column(scale=1):
152
+ gr.Markdown("### Inputs")
153
+ # Audio Input
154
+ audio_input = gr.Audio(type="filepath", label="Upload Patient Audio (.wav)")
155
+
156
+ # Demographics Inputs
157
+ sex_input = gr.Radio(["Female", "Male"], label="Sex")
158
+ race_input = gr.Radio(["White", "Other"], label="Race") # Adjust as needed
159
+ education_input = gr.Slider(minimum=0, maximum=30, step=1, value=16, label="Years of Education")
160
+ age_input = gr.Slider(minimum=40, maximum=100, step=1, value=65, label="Age at Entry")
161
+
162
+ analyze_btn = gr.Button("Analyze", variant="primary")
163
+
164
+ with gr.Column(scale=2):
165
+ gr.Markdown("### Analysis Results")
166
+ # Binary Classification Output
167
+ with gr.Group():
168
+ gr.Label("Binary Classification (AD vs. HC)")
169
+ binary_output_label = gr.Textbox(label="Result")
170
+ binary_output_confidence = gr.Textbox(label="Confidence")
171
+
172
+ # Multiclass Classification Output
173
+ with gr.Group():
174
+ gr.Label("Multiclass Classification (HC vs. MCI vs. AD)")
175
+ multiclass_output_label = gr.Textbox(label="Result")
176
+ multiclass_output_confidence = gr.Textbox(label="Confidence")
177
+
178
+ # Regression Output
179
+ with gr.Group():
180
+ gr.Label("MMSE Score Regression")
181
+ regression_output_score = gr.Textbox(label="MMSE Score")
182
+ regression_output_std = gr.Textbox(label="Standard Deviation")
183
+
184
+ # Connect the button to the prediction function
185
+ analyze_btn.click(
186
+ fn=analyze_emotion,
187
+ inputs=[audio_input, sex_input, race_input, education_input, age_input],
188
+ outputs=[
189
+ binary_output_label,
190
+ binary_output_confidence,
191
+ multiclass_output_label,
192
+ multiclass_output_confidence,
193
+ regression_output_score,
194
+ regression_output_std
195
+ ]
196
+ )
197
+
198
+ gr.Markdown("---")
199
+ gr.Markdown(
200
+ "**Disclaimer:** This tool is for research purposes only and is not a substitute for professional medical advice, diagnosis, or treatment."
201
+ )
202
+ gr.Examples(
203
+ examples=[
204
+ ["./example/example.wav", "Female", "White", 16, 58],
205
+ ],
206
+ inputs=[audio_input, sex_input, race_input, education_input, age_input],
207
+ fn=analyze_emotion,
208
+ outputs=[
209
+ binary_output_label,
210
+ binary_output_confidence,
211
+ multiclass_output_label,
212
+ multiclass_output_confidence,
213
+ regression_output_score,
214
+ regression_output_std
215
+ ],
216
+ cache_examples=True
217
+ )
218
+
219
+ if __name__ == "__main__":
220
+ # To run locally:
221
+ # 1. Make sure you have all dependencies from requirements.txt installed.
222
+ # 2. Place your 'model_weights' and 'example' folders in the same directory as this script.
223
+ # 3. Run 'python app.py' in your terminal.
224
+ demo.launch()
emotion_analysis_framework.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Emotion Analysis Framework
3
+ A framework for analyzing emotions from patient audio recordings using wav2vec2 models.
4
+
5
+ Author: Marek Sviderski
6
+
7
+ This framework supports three main tasks:
8
+ - Binary classification: Distinguishing between Alzheimer's Disease (AD) and Healthy Control (HC)
9
+ - Multiclass classification: Classifying into HC, Mild Cognitive Impairment (MCI), and AD
10
+ - Regression: Predicting the Mini-Mental State Examination (MMSE) score
11
+
12
+ This code is designed to be modular and extensible, allowing for easy integration of new models and strategies.
13
+
14
+ It uses dataclasses for structured data representation and provides methods for feature extraction, model loading, and predictions.
15
+ """
16
+
17
+ import os
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional, Union, Any
20
+ from dataclasses import dataclass, field
21
+ import numpy as np
22
+ import pandas as pd
23
+ import torch
24
+ import joblib
25
+ import warnings
26
+
27
+ from utils.config import ProjectConfig
28
+ from utils.config_types import ChunkLength
29
+ from models.inference_wav2vec import Wav2VecInference
30
+ from preprocessing.flattening_statistical import StatisticalFlattening
31
+ from preprocessing.flattening_categorical import CategoricalFlattening
32
+ from preprocessing.flattening_minirocket import MiniRocketFlattening
33
+
34
+ warnings.filterwarnings('ignore', category=UserWarning, module='xgboost')
35
+ warnings.filterwarnings('ignore', message='Some weights of the model checkpoint')
36
+ warnings.filterwarnings('ignore', message='Some weights of Wav2Vec2ForSequenceClassification')
37
+
38
+
39
+
40
+ @dataclass
41
+ class PatientData:
42
+ """Data class for patient information"""
43
+ patient_id: str
44
+ audio_path: str
45
+ demographics: Dict[str, Any] = field(default_factory=dict)
46
+
47
+
48
+ @dataclass
49
+ class PredictionResult:
50
+ """Data class for prediction results"""
51
+ patient_id: str
52
+ task: str
53
+ predictions: Dict[str, Any]
54
+ probabilities: Optional[Dict[str, Any]] = None
55
+ confidence: Optional[float] = None
56
+ metadata: Dict[str, Any] = field(default_factory=dict)
57
+
58
+ def summary(self) -> str:
59
+ """Return a clean summary of the prediction"""
60
+ if self.task == 'binary':
61
+ return f"Binary: {self.predictions['label']} (confidence: {self.confidence:.2f})"
62
+ elif self.task == 'multiclass':
63
+ return f"Multiclass: {self.predictions['label']} (confidence: {self.confidence:.2f})"
64
+ elif self.task == 'regression':
65
+ return f"MMSE Score: {self.predictions['mmse_score']:.1f} ± {self.predictions['std']:.1f}"
66
+ else:
67
+ return str(self.predictions)
68
+
69
+
70
+ class EmotionAnalysisFramework:
71
+ """
72
+ End-to-end framework for emotion analysis from patient recordings.
73
+
74
+ This framework provides three prediction tasks:
75
+ - Binary classification: AD vs Healthy Control (HC)
76
+ - Multiclass classification: HC vs MCI vs AD
77
+ - Regression: MMSE score prediction
78
+
79
+ Args:
80
+ config_path: Optional path to custom configuration
81
+ model_dir: Path to directory containing model weights
82
+ verbose: Whether to print detailed progress information
83
+ """
84
+
85
+ def __init__(self, config_path: Optional[str] = None,
86
+ model_dir: Optional[str] = None,
87
+ verbose: bool = False):
88
+ self.config = ProjectConfig(config_path) if config_path else ProjectConfig()
89
+ self.model_dir = model_dir
90
+ self.verbose = verbose
91
+ self.wav2vec_model = None
92
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
93
+ self.models = {
94
+ 'binary': {},
95
+ 'multiclass': {},
96
+ 'regression': {}
97
+ }
98
+ self.strategies = {}
99
+ self._initialize_strategies()
100
+ self._load_models()
101
+
102
+ def _log(self, message: str):
103
+ """Print message only if verbose mode is enabled"""
104
+ if self.verbose:
105
+ print(message)
106
+
107
+ def _initialize_strategies(self):
108
+ """Initialize all flattening strategies"""
109
+ self.strategies = {
110
+ 'statistical': StatisticalFlattening(),
111
+ 'categorical': CategoricalFlattening(),
112
+ 'minirocket': MiniRocketFlattening()
113
+ }
114
+
115
+ def _load_models(self):
116
+ """Load all trained models"""
117
+ if not self.model_dir:
118
+ # Try to find models in package directory
119
+ package_dir = os.path.dirname(os.path.abspath(__file__))
120
+ self.model_dir = os.path.join(package_dir, "model_weights")
121
+
122
+ if not os.path.exists(self.model_dir):
123
+ raise ValueError(f"Model directory not found: {self.model_dir}")
124
+
125
+ # Load models for each task
126
+ self._load_task_models('binary', os.path.join(self.model_dir, "binary"))
127
+ self._load_task_models('multiclass', os.path.join(self.model_dir, "multiclass"))
128
+ self._load_task_models('regression', os.path.join(self.model_dir, "regression"))
129
+
130
+ # Verify models were loaded
131
+ for task in ['binary', 'multiclass', 'regression']:
132
+ if not self.models[task]:
133
+ raise ValueError(f"No {task} models found in {self.model_dir}")
134
+
135
+ def _load_task_models(self, task: str, path: str):
136
+ """Load models for a specific task"""
137
+ if not os.path.exists(path):
138
+ self._log(f"Warning: {task} model path not found: {path}")
139
+ return
140
+
141
+ model_type = 'simple' if task in ['binary', 'regression'] else 'fusion'
142
+ self.models[task][model_type] = {}
143
+
144
+ model_files = [f for f in os.listdir(path)
145
+ if f.startswith('model_fold_') and f.endswith('.joblib')]
146
+
147
+ for file in model_files:
148
+ fold_num = file.split('_')[-1].replace('.joblib', '')
149
+ model_path = os.path.join(path, file)
150
+ try:
151
+ self.models[task][model_type][fold_num] = joblib.load(model_path)
152
+ self._log(f"Loaded {task} {model_type} model fold {fold_num}")
153
+ except Exception as e:
154
+ self._log(f"Error loading model {model_path}: {e}")
155
+
156
+ def _extract_wav2vec_features(self, audio_path: str, chunk_length: ChunkLength) -> pd.DataFrame:
157
+ """Extract wav2vec features from audio file"""
158
+ if self.wav2vec_model is None:
159
+ self.wav2vec_model = Wav2VecInference(self.config, verbose=self.verbose)
160
+ self.wav2vec_model.load_model()
161
+
162
+ chunk_config = self.config.get_chunk_params(chunk_length)
163
+
164
+ emotions_over_time = self.wav2vec_model.analyze_emotions_over_time(
165
+ audio_path,
166
+ segment_duration=chunk_config.segment_duration,
167
+ overlap_duration=chunk_config.overlap_duration
168
+ )
169
+
170
+ rows = []
171
+ for start, end, emotions in emotions_over_time:
172
+ rows.append({
173
+ 'filename': str(Path(audio_path).stem),
174
+ 'start': start,
175
+ 'end': end,
176
+ **emotions
177
+ })
178
+
179
+ return pd.DataFrame(rows)
180
+
181
+ def _prepare_features(self, patient_data: PatientData, task: str) -> Dict[str, pd.DataFrame]:
182
+ """Prepare features for a specific task - FIXED VERSION"""
183
+ prepared_data = {}
184
+
185
+ if task == 'binary':
186
+ # Binary uses statistical flattening with 3.5s chunks
187
+ df_3_5 = self._extract_wav2vec_features(patient_data.audio_path, ChunkLength.LENGTH_3_5)
188
+ flattened = self.strategies['statistical'].flatten_dataframe(df_3_5)
189
+
190
+ # Add demographics more efficiently
191
+ if patient_data.demographics:
192
+ # Create a copy and add all demographics at once
193
+ demo_df = pd.DataFrame([patient_data.demographics] * len(flattened))
194
+ flattened = pd.concat([flattened, demo_df], axis=1)
195
+
196
+ prepared_data['simple'] = flattened
197
+
198
+ elif task == 'multiclass':
199
+ # Multiclass uses categorical flattening with different chunk lengths
200
+ df_1_5 = self._extract_wav2vec_features(patient_data.audio_path, ChunkLength.LENGTH_1_5)
201
+ df_4_5 = self._extract_wav2vec_features(patient_data.audio_path, ChunkLength.LENGTH_4_5)
202
+
203
+ flattened_1_5 = self.strategies['categorical'].flatten_dataframe(df_1_5)
204
+ flattened_4_5 = self.strategies['categorical'].flatten_dataframe(df_4_5)
205
+
206
+ if patient_data.demographics:
207
+ # Add demographics efficiently
208
+ demo_df = pd.DataFrame([patient_data.demographics])
209
+ flattened_1_5 = pd.concat([flattened_1_5, demo_df], axis=1)
210
+ flattened_4_5 = pd.concat([flattened_4_5, demo_df], axis=1)
211
+
212
+ prepared_data['chunk1'] = flattened_1_5
213
+ prepared_data['chunk2'] = flattened_4_5
214
+
215
+ elif task == 'regression':
216
+ # Regression uses minirocket for 1.5s and 3.5s, categorical for 4.5s
217
+ df_1_5 = self._extract_wav2vec_features(patient_data.audio_path, ChunkLength.LENGTH_1_5)
218
+ df_3_5 = self._extract_wav2vec_features(patient_data.audio_path, ChunkLength.LENGTH_3_5)
219
+ df_4_5 = self._extract_wav2vec_features(patient_data.audio_path, ChunkLength.LENGTH_4_5)
220
+
221
+ flattened_1_5 = self.strategies['minirocket'].flatten_dataframe(df_1_5)
222
+ flattened_3_5 = self.strategies['minirocket'].flatten_dataframe(df_3_5)
223
+ flattened_4_5 = self.strategies['categorical'].flatten_dataframe(df_4_5)
224
+
225
+ if patient_data.demographics:
226
+ # Add demographics efficiently to 4_5 only
227
+ demo_df = pd.DataFrame([patient_data.demographics])
228
+ flattened_4_5 = pd.concat([flattened_4_5, demo_df], axis=1)
229
+
230
+ prepared_data['chunk_1_5'] = flattened_1_5
231
+ prepared_data['chunk_3_5'] = flattened_3_5
232
+ prepared_data['chunk_4_5_demo'] = flattened_4_5
233
+
234
+ return prepared_data
235
+
236
+ def _predict_binary(self, features: Dict[str, pd.DataFrame]) -> PredictionResult:
237
+ """Make binary classification predictions - FIXED"""
238
+ model_type = 'simple'
239
+
240
+ if model_type not in self.models['binary'] or not self.models['binary'][model_type]:
241
+ raise ValueError(f"No binary models loaded")
242
+
243
+ fold_predictions = []
244
+ fold_probabilities = []
245
+
246
+ # Get the patient ID from features
247
+ patient_id = 'unknown'
248
+ if 'filename' in features[model_type].columns:
249
+ patient_id = features[model_type]['filename'].iloc[0]
250
+
251
+ # Get predictions from all folds
252
+ for fold_num, model in self.models['binary'][model_type].items():
253
+ try:
254
+ X = features[model_type]
255
+
256
+ # Ensure correct feature order
257
+ if hasattr(model, 'feature_names_in_'):
258
+ # Only use features that the model was trained on
259
+ model_features = [f for f in model.feature_names_in_ if f in X.columns]
260
+ X = X[model_features]
261
+
262
+ pred = model.predict(X)
263
+ pred_proba = model.predict_proba(X)
264
+
265
+ fold_predictions.append(pred[0])
266
+ fold_probabilities.append(pred_proba[0])
267
+ except Exception as e:
268
+ self._log(f"Error in fold {fold_num}: {e}")
269
+ continue
270
+
271
+ if not fold_predictions:
272
+ raise ValueError("No successful predictions from any fold")
273
+
274
+ # Aggregate predictions (majority vote)
275
+ final_prediction = int(np.round(np.mean(fold_predictions)))
276
+ mean_probabilities = np.mean(fold_probabilities, axis=0)
277
+
278
+ return PredictionResult(
279
+ patient_id=patient_id,
280
+ task='binary',
281
+ predictions={'class': final_prediction, 'label': 'AD' if final_prediction else 'HC'},
282
+ probabilities={'HC': float(mean_probabilities[0]), 'AD': float(mean_probabilities[1])},
283
+ confidence=float(np.max(mean_probabilities)),
284
+ metadata={'model_type': model_type, 'num_folds': len(fold_predictions)}
285
+ )
286
+
287
+ def _predict_multiclass(self, features: Dict[str, pd.DataFrame]) -> PredictionResult:
288
+ """Make multiclass classification predictions - FIXED"""
289
+ if 'fusion' not in self.models['multiclass'] or not self.models['multiclass']['fusion']:
290
+ raise ValueError("No multiclass fusion model loaded")
291
+
292
+ fold_predictions = []
293
+ fold_probabilities = []
294
+ class_labels = ['HC', 'MCI', 'AD']
295
+
296
+ # Get patient ID
297
+ patient_id = 'unknown'
298
+ if 'filename' in features['chunk1'].columns:
299
+ patient_id = features['chunk1']['filename'].iloc[0]
300
+
301
+ for fold_num, model_pack in self.models['multiclass']['fusion'].items():
302
+ try:
303
+ # Prepare features for each model
304
+ model1 = model_pack['chunk1']
305
+ model2 = model_pack['chunk2']
306
+
307
+ # Get features that the models were trained on
308
+ X_chunk1 = features['chunk1']
309
+ X_chunk2 = features['chunk2']
310
+
311
+ # Ensure we have the right features
312
+ if hasattr(model1, 'feature_names_in_'):
313
+ model1_features = [f for f in model1.feature_names_in_ if f in X_chunk1.columns]
314
+ X_chunk1 = X_chunk1[model1_features]
315
+
316
+ if hasattr(model2, 'feature_names_in_'):
317
+ model2_features = [f for f in model2.feature_names_in_ if f in X_chunk2.columns]
318
+ X_chunk2 = X_chunk2[model2_features]
319
+
320
+ pred_proba_1 = model1.predict_proba(X_chunk1)
321
+ pred_proba_2 = model2.predict_proba(X_chunk2)
322
+
323
+ # Apply fusion weights
324
+ weights = model_pack.get('weights', [0.5, 0.5])
325
+ fusion_proba = weights[0] * pred_proba_1 + weights[1] * pred_proba_2
326
+
327
+ pred = np.argmax(fusion_proba, axis=1)
328
+
329
+ fold_predictions.append(pred[0])
330
+ fold_probabilities.append(fusion_proba[0])
331
+ except Exception as e:
332
+ self._log(f"Error in multiclass fold {fold_num}: {e}")
333
+ continue
334
+
335
+ if not fold_predictions:
336
+ raise ValueError("No successful multiclass predictions from any fold")
337
+
338
+ # Aggregate predictions
339
+ final_prediction = int(np.round(np.mean(fold_predictions)))
340
+ mean_probabilities = np.mean(fold_probabilities, axis=0)
341
+
342
+ prob_dict = {label: float(prob) for label, prob in zip(class_labels, mean_probabilities)}
343
+
344
+ return PredictionResult(
345
+ patient_id=patient_id,
346
+ task='multiclass',
347
+ predictions={'class': final_prediction, 'label': class_labels[final_prediction]},
348
+ probabilities=prob_dict,
349
+ confidence=float(np.max(mean_probabilities)),
350
+ metadata={'num_folds': len(fold_predictions)}
351
+ )
352
+
353
+ def _predict_regression(self, features: Dict[str, pd.DataFrame]) -> PredictionResult:
354
+ """Make regression predictions - FIXED"""
355
+ model_type = 'simple'
356
+
357
+ if model_type not in self.models['regression'] or not self.models['regression'][model_type]:
358
+ raise ValueError("No regression models loaded")
359
+
360
+ fold_predictions = []
361
+
362
+ # Get patient ID
363
+ patient_id = 'unknown'
364
+ if 'filename' in features['chunk_1_5'].columns:
365
+ patient_id = features['chunk_1_5']['filename'].iloc[0]
366
+
367
+ for fold_num, model_pack in self.models['regression'][model_type].items():
368
+ try:
369
+ # Simple fusion prediction
370
+ models = model_pack['models']
371
+ weights = model_pack['weights']
372
+
373
+ # Get predictions from each model with proper feature selection
374
+ model1 = models["ridge_1_5_minirocket"]
375
+ X1 = features['chunk_1_5']
376
+ if hasattr(model1, 'feature_names_in_'):
377
+ model1_features = [f for f in model1.feature_names_in_ if f in X1.columns]
378
+ X1 = X1[model1_features]
379
+ pred1 = model1.predict(X1)
380
+
381
+ model2 = models["ridge_3_5_minirocket"]
382
+ X2 = features['chunk_3_5']
383
+ if hasattr(model2, 'feature_names_in_'):
384
+ model2_features = [f for f in model2.feature_names_in_ if f in X2.columns]
385
+ X2 = X2[model2_features]
386
+ pred2 = model2.predict(X2)
387
+
388
+ model3 = models["ridge_4_5_categorical"]
389
+ X3 = features['chunk_4_5_demo']
390
+ if hasattr(model3, 'feature_names_in_'):
391
+ model3_features = [f for f in model3.feature_names_in_ if f in X3.columns]
392
+ X3 = X3[model3_features]
393
+ pred3 = model3.predict(X3)
394
+
395
+ final_pred = weights[0] * pred1 + weights[1] * pred2 + weights[2] * pred3
396
+ fold_predictions.append(final_pred[0])
397
+
398
+ except Exception as e:
399
+ self._log(f"Error in regression fold {fold_num}: {e}")
400
+ continue
401
+
402
+ if not fold_predictions:
403
+ raise ValueError("No successful regression predictions from any fold")
404
+
405
+ # Aggregate predictions
406
+ final_prediction = float(np.mean(fold_predictions))
407
+ std_prediction = float(np.std(fold_predictions))
408
+
409
+ return PredictionResult(
410
+ patient_id=patient_id,
411
+ task='regression',
412
+ predictions={'mmse_score': final_prediction, 'std': std_prediction},
413
+ confidence=1.0 / (1.0 + std_prediction),
414
+ metadata={'model_type': model_type, 'num_folds': len(fold_predictions)}
415
+ )
416
+
417
+ def _predict_regression(self, features: Dict[str, pd.DataFrame]) -> PredictionResult:
418
+ """Make regression predictions - FIXED"""
419
+ model_type = 'simple'
420
+
421
+ if model_type not in self.models['regression'] or not self.models['regression'][model_type]:
422
+ raise ValueError("No regression models loaded")
423
+
424
+ fold_predictions = []
425
+
426
+ # Get patient ID
427
+ patient_id = 'unknown'
428
+ if 'filename' in features['chunk_1_5'].columns:
429
+ patient_id = features['chunk_1_5']['filename'].iloc[0]
430
+
431
+ for fold_num, model_pack in self.models['regression'][model_type].items():
432
+ try:
433
+ # Simple fusion prediction
434
+ models = model_pack['models']
435
+ weights = model_pack['weights']
436
+
437
+ # Get predictions from each model with proper feature selection
438
+ model1 = models["ridge_1_5_minirocket"]
439
+ X1 = features['chunk_1_5']
440
+ if hasattr(model1, 'feature_names_in_'):
441
+ model1_features = [f for f in model1.feature_names_in_ if f in X1.columns]
442
+ X1 = X1[model1_features]
443
+ pred1 = model1.predict(X1)
444
+
445
+ model2 = models["ridge_3_5_minirocket"]
446
+ X2 = features['chunk_3_5']
447
+ if hasattr(model2, 'feature_names_in_'):
448
+ model2_features = [f for f in model2.feature_names_in_ if f in X2.columns]
449
+ X2 = X2[model2_features]
450
+ pred2 = model2.predict(X2)
451
+
452
+ model3 = models["ridge_4_5_categorical"]
453
+ X3 = features['chunk_4_5_demo']
454
+ if hasattr(model3, 'feature_names_in_'):
455
+ model3_features = [f for f in model3.feature_names_in_ if f in X3.columns]
456
+ X3 = X3[model3_features]
457
+ pred3 = model3.predict(X3)
458
+
459
+ final_pred = weights[0] * pred1 + weights[1] * pred2 + weights[2] * pred3
460
+ fold_predictions.append(final_pred[0])
461
+
462
+ except Exception as e:
463
+ self._log(f"Error in regression fold {fold_num}: {e}")
464
+ continue
465
+
466
+ if not fold_predictions:
467
+ raise ValueError("No successful regression predictions from any fold")
468
+
469
+ # Aggregate predictions
470
+ final_prediction = float(np.mean(fold_predictions))
471
+ std_prediction = float(np.std(fold_predictions))
472
+
473
+ return PredictionResult(
474
+ patient_id=patient_id,
475
+ task='regression',
476
+ predictions={'mmse_score': final_prediction, 'std': std_prediction},
477
+ confidence=1.0 / (1.0 + std_prediction),
478
+ metadata={'model_type': model_type, 'num_folds': len(fold_predictions)}
479
+ )
480
+
481
+ def predict(self, patient_data: Union[PatientData, List[PatientData]],
482
+ task: Optional[str] = None) -> Union[
483
+ PredictionResult, List[PredictionResult], Dict[str, PredictionResult]]:
484
+ """
485
+ Make predictions for one or more patients.
486
+
487
+ Args:
488
+ patient_data: Single PatientData or list of PatientData objects
489
+ task: Specific task ('binary', 'multiclass', 'regression') or None for all
490
+
491
+ Returns:
492
+ Single PredictionResult, list of results, or dict of results by task
493
+ """
494
+ # Handle single patient
495
+ if isinstance(patient_data, PatientData):
496
+ patient_data = [patient_data]
497
+ single_patient = True
498
+ else:
499
+ single_patient = False
500
+
501
+ results = []
502
+
503
+ for patient in patient_data:
504
+ patient_results = {}
505
+
506
+ tasks_to_run = [task] if task else ['binary', 'multiclass', 'regression']
507
+
508
+ for current_task in tasks_to_run:
509
+ try:
510
+ # Prepare features for the task
511
+ features = self._prepare_features(patient, current_task)
512
+
513
+ # Make predictions based on task
514
+ if current_task == 'binary':
515
+ result = self._predict_binary(features)
516
+ elif current_task == 'multiclass':
517
+ result = self._predict_multiclass(features)
518
+ elif current_task == 'regression':
519
+ result = self._predict_regression(features)
520
+ else:
521
+ raise ValueError(f"Unknown task: {current_task}")
522
+
523
+ # Fix: use patient from the loop, not patient_data
524
+ result.patient_id = patient.patient_id
525
+ patient_results[current_task] = result
526
+
527
+ except Exception as e:
528
+ self._log(f"Error predicting {current_task} for patient {patient.patient_id}: {e}")
529
+ patient_results[current_task] = PredictionResult(
530
+ patient_id=patient.patient_id,
531
+ task=current_task,
532
+ predictions={'error': str(e)},
533
+ metadata={'status': 'failed'}
534
+ )
535
+
536
+ # Return appropriate format
537
+ if task: # Single task
538
+ results.append(patient_results[task])
539
+ else: # All tasks
540
+ results.append(patient_results)
541
+
542
+ # Format return based on input
543
+ if single_patient:
544
+ return results[0]
545
+ else:
546
+ return results
example/example.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54ff5dd4723f32d833b0631cccb64b1422cb1bdf629c1dc00398a7f65b196c6e
3
+ size 57602320
model_weights/binary/model_fold_0.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:edb9a91485fb5d7c7fe6fe9796b3acdf03dbef1842b97dcfec15e34145b816e1
3
+ size 4372336
model_weights/binary/model_fold_1.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed1aea5b008ae5f4ec22ecb4123332a54f6fa9ac94777a6c0c17f4934740c3eb
3
+ size 4408720
model_weights/binary/model_fold_2.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9cd344db067466560b22bbe6a0b3c919b1ab3f62af501aca734049d2242397b
3
+ size 4328032
model_weights/binary/model_fold_3.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f77fff9a656acc76215308264e06bcefe66feed764973dc6f266c804dd4cb8f
3
+ size 4408992
model_weights/binary/model_fold_4.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad60480b9c54c419f32d0fc2d2ee2146cf08c4930a9471394f09145e581825aa
3
+ size 4316144
model_weights/multiclass/model_fold_0.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78d1f25bdf7c00f3269f9e8617bb00e8b3c1a40807809f9e4dcbffd09cc589ff
3
+ size 3025389
model_weights/multiclass/model_fold_1.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18ebd09065c2064f4badb15b1613049a65bb1adff602344d38f21766166b7c29
3
+ size 3031085
model_weights/multiclass/model_fold_2.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9bc18f007aafa81961d9bb7983d64cffe74e97920179d5f5d3d5252855daa552
3
+ size 3036701
model_weights/multiclass/model_fold_3.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e39448ac78a7bb0ac9033be9fb4110ea250575b0e2138b17f20740cf7977e31
3
+ size 3018797
model_weights/multiclass/model_fold_4.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ebd6ebea08c3a8d1163298f07411c30fcaf2b66bda47d04366e1efa3bb1a49a6
3
+ size 3020637
model_weights/regression/model_fold_0.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:adb800cd7a00422815621e359a2a268da5269598ad078ddb68c33094eea5d1bb
3
+ size 101857
model_weights/regression/model_fold_1.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a93d25b3c8159d09b7d63cf95016fbe0272469c08a08bb85deacbde6573b7fb
3
+ size 101857
model_weights/regression/model_fold_2.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07d48e469318ea8db8a2436208fa96f722706b66acf05766f23db5e5726e2877
3
+ size 101857
model_weights/regression/model_fold_3.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c3431945f467ab8709980eb4fcd3190acba70a2b7c244405cb264b24c0c48f9
3
+ size 101857
model_weights/regression/model_fold_4.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:049cfb0e17685c959697fc49ab4bb077a549b7b45a6425f41a6b3c8c02349fe8
3
+ size 101857
models/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/__init__.py
2
+ from .inference_wav2vec import Wav2VecInference
3
+
4
+ __all__ = ["Wav2VecInference"]
5
+
6
+ # preprocessing/__init__.py
7
+ from preprocessing.flattening_base import BaseFlattening
8
+ from preprocessing.flattening_categorical import CategoricalFlattening
9
+ from preprocessing.flattening_minirocket import MiniRocketFlattening
10
+ from preprocessing.flattening_statistical import StatisticalFlattening
11
+
12
+ __all__ = [
13
+ "BaseFlattening",
14
+ "CategoricalFlattening",
15
+ "MiniRocketFlattening",
16
+ "StatisticalFlattening"
17
+ ]
models/inference_wav2vec.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ from typing import List
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ from pydub import AudioSegment
8
+ from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
9
+
10
+ from utils.config_types import ChunkLength
11
+
12
+ # Suppress specific warnings
13
+ warnings.filterwarnings('ignore', message='Passing `gradient_checkpointing` to a config initialization')
14
+ warnings.filterwarnings('ignore', message='Some weights of the model checkpoint')
15
+ warnings.filterwarnings('ignore', message='Some weights of Wav2Vec2ForSequenceClassification')
16
+
17
+
18
+ class Wav2VecInference:
19
+ """Implementation for wav2vec emotion recognition model."""
20
+ name = "wav2vec"
21
+
22
+ def __init__(self, config, verbose=False):
23
+ self.config = config
24
+ self.verbose = verbose
25
+ self.input_directory = self.config.get_full_path(self.config.paths['data'].raw)
26
+ self.output_path = self.config.get_full_path(self.config.paths['data'].processed)
27
+ self.id2label = {
28
+ "0": "angry", "1": "calm", "2": "disgust", "3": "fearful",
29
+ "4": "happy", "5": "neutral", "6": "sad", "7": "surprised"
30
+ }
31
+ self.model = None
32
+ self.feature_extractor = None
33
+
34
+ def get_emotion_labels(self) -> List[str]:
35
+ """Return a list of emotion labels used by this model."""
36
+ return list(self.id2label.values())
37
+
38
+ def load_model(self):
39
+ """Load the wav2vec model and feature extractor."""
40
+ if self.model is None or self.feature_extractor is None:
41
+ if self.verbose:
42
+ print("Loading wav2vec2 emotion recognition model...")
43
+
44
+ # Suppress warnings during model loading
45
+ with warnings.catch_warnings():
46
+ warnings.filterwarnings("ignore")
47
+
48
+ # Load the model without the gradient_checkpointing parameter
49
+ self.model = AutoModelForAudioClassification.from_pretrained(
50
+ "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition",
51
+ ignore_mismatched_sizes=True # This helps with the weights warning
52
+ ).to(self.config.settings.device)
53
+
54
+ self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
55
+ "facebook/wav2vec2-large-xlsr-53"
56
+ )
57
+
58
+ if self.verbose:
59
+ print("Model loaded successfully.")
60
+
61
+ def _predict_emotion_from_segment(self, segment):
62
+ """Predict emotions for a single audio segment."""
63
+ # Convert to mono numpy array
64
+ waveform = np.array(segment.get_array_of_samples()).astype(np.float32)
65
+ if segment.channels == 2:
66
+ waveform = waveform.reshape((-1, 2)).mean(axis=1)
67
+ waveform = waveform.reshape(-1)
68
+
69
+ # Extract features
70
+ inputs = self.feature_extractor(
71
+ waveform,
72
+ sampling_rate=segment.frame_rate,
73
+ return_tensors="pt",
74
+ padding=True
75
+ )
76
+ inputs = {k: v.to(self.config.settings.device) for k, v in inputs.items()}
77
+
78
+ # Get predictions
79
+ with torch.no_grad():
80
+ logits = self.model(**inputs).logits
81
+ probabilities = torch.softmax(logits, dim=1).cpu().squeeze().tolist()
82
+
83
+ return {self.id2label[str(i)]: float(prob) for i, prob in enumerate(probabilities)}
84
+
85
+ def analyze_emotions_over_time(self, audio_file: str, segment_duration: int,
86
+ overlap_duration: int) -> list:
87
+ """
88
+ Analyze emotions in chunks over the duration of an audio file.
89
+
90
+ Args:
91
+ audio_file: Path to the audio file
92
+ segment_duration: Duration of each segment in milliseconds
93
+ overlap_duration: Overlap between segments in milliseconds
94
+
95
+ Returns:
96
+ List of tuples (start_time, end_time, emotion_probabilities)
97
+ """
98
+ sound = AudioSegment.from_file(audio_file)
99
+ duration = len(sound)
100
+ emotions_over_time = []
101
+
102
+ start = 0
103
+ total_segments = 0
104
+
105
+ # Calculate total number of segments for progress tracking
106
+ if self.verbose:
107
+ temp_start = 0
108
+ while temp_start + segment_duration <= duration:
109
+ total_segments += 1
110
+ temp_start += segment_duration - overlap_duration
111
+ print(f"Processing {total_segments} segments from audio file...")
112
+
113
+ segment_count = 0
114
+ while start + segment_duration <= duration:
115
+ segment = sound[start:start + segment_duration]
116
+ emotion_probabilities = self._predict_emotion_from_segment(segment)
117
+ emotions_over_time.append((start, start + segment_duration, emotion_probabilities))
118
+ start += segment_duration - overlap_duration
119
+
120
+ segment_count += 1
121
+ if self.verbose and segment_count % 10 == 0:
122
+ print(f" Processed {segment_count}/{total_segments} segments...")
123
+
124
+ if self.verbose:
125
+ print(f" Completed processing {segment_count} segments.")
126
+
127
+ return emotions_over_time
preprocessing/flattening_base.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ import pandas as pd
3
+
4
+
5
+ class BaseFlattening(ABC):
6
+ name = "base"
7
+
8
+ @abstractmethod
9
+ def flatten_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
10
+ pass
11
+
12
+ def get_name(self) -> str:
13
+ return self.name
14
+
15
+ def __str__(self):
16
+ return self.name
preprocessing/flattening_categorical.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from sklearn.preprocessing import LabelEncoder
3
+
4
+ from preprocessing.flattening_base import BaseFlattening
5
+
6
+
7
+ class CategoricalFlattening(BaseFlattening):
8
+ def __init__(self):
9
+ """Initialize the categorical flattening strategy."""
10
+ self.name = "categorical"
11
+ self.label_encoder = LabelEncoder()
12
+ # Maximum number of positions to encode individually
13
+ self.max_positions = 1000
14
+
15
+ def flatten_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
16
+ """
17
+ Implement categorical flattening for the combined DataFrame.
18
+ For each row, identifies the most dominant emotion and creates
19
+ a fixed-length sequence of these dominant emotions for each patient.
20
+ Additionally, adds summary features to capture information from longer sequences.
21
+
22
+ Args:
23
+ df: DataFrame containing all inference results
24
+
25
+ Returns:
26
+ DataFrame containing categorical features with encoded emotions
27
+ """
28
+ # Group by filename (patient ID)
29
+ grouped = df.groupby('filename')
30
+ results = []
31
+
32
+ emotion_columns = [col for col in df.columns
33
+ if col not in ['start', 'end', 'filename']]
34
+
35
+ # Calculate global sequence statistics to determine max positions
36
+ sequence_lengths = []
37
+ for _, group in grouped:
38
+ sequence_lengths.append(len(group))
39
+
40
+ # Get the 95th percentile of sequence lengths to determine max_positions
41
+ if sequence_lengths:
42
+ # Only adjust max_positions if we have enough data
43
+ if len(sequence_lengths) > 10:
44
+ p95_length = int(pd.Series(sequence_lengths).quantile(0.95))
45
+ # Limit to a reasonable maximum (30) to prevent too many columns
46
+ self.max_positions = min(p95_length, self.max_positions)
47
+
48
+ for filename, group in grouped:
49
+ # Get the dominant emotion for each time point
50
+ dominant_emotions = group[emotion_columns].idxmax(axis=1)
51
+
52
+ # Create features dictionary for this patient
53
+ features = {'filename': filename}
54
+
55
+ # Add sequence length as a feature
56
+ features['sequence_length'] = len(dominant_emotions)
57
+
58
+ # Add fixed-length position features (up to max_positions)
59
+ for i in range(1, self.max_positions + 1):
60
+ if i <= len(dominant_emotions):
61
+ features[f'emotion_pos_{i}'] = dominant_emotions.iloc[i - 1]
62
+ else:
63
+ # For shorter sequences, use a consistent padding value
64
+ features[f'emotion_pos_{i}'] = "padding"
65
+
66
+ # Add summary features for the entire sequence
67
+ for emotion in emotion_columns:
68
+ # Count occurrences of each emotion
69
+ features[f'count_{emotion}'] = (dominant_emotions == emotion).sum()
70
+ # Calculate proportion of each emotion
71
+ features[f'prop_{emotion}'] = features[f'count_{emotion}'] / len(dominant_emotions) if len(
72
+ dominant_emotions) > 0 else 0
73
+
74
+ # Add summary of emotion transitions
75
+ if len(dominant_emotions) > 1:
76
+ transitions = 0
77
+ for i in range(len(dominant_emotions) - 1):
78
+ if dominant_emotions.iloc[i] != dominant_emotions.iloc[i + 1]:
79
+ transitions += 1
80
+ features['emotion_transitions'] = transitions
81
+ features['emotion_transitions_ratio'] = transitions / (len(dominant_emotions) - 1)
82
+ else:
83
+ features['emotion_transitions'] = 0
84
+ features['emotion_transitions_ratio'] = 0
85
+
86
+ results.append(features)
87
+
88
+ # Create DataFrame from results
89
+ result_df = pd.DataFrame(results)
90
+
91
+ # Encode all emotion position columns
92
+ emotion_cols = [col for col in result_df.columns if col.startswith('emotion_pos_')]
93
+ for col in emotion_cols:
94
+ result_df[col] = self.label_encoder.fit_transform(result_df[col].astype(str))
95
+
96
+ return result_df
97
+
98
+ def get_emotion_mapping(self) -> dict:
99
+ """
100
+ Get the mapping between encoded values and emotion categories.
101
+
102
+ Returns:
103
+ Dictionary mapping encoded values to emotion categories
104
+ """
105
+ return dict(zip(
106
+ self.label_encoder.transform(self.label_encoder.classes_),
107
+ self.label_encoder.classes_
108
+ ))
preprocessing/flattening_minirocket.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from sklearn.preprocessing import StandardScaler
4
+ from sktime.transformations.panel.rocket import MiniRocket
5
+
6
+ from preprocessing.flattening_base import BaseFlattening
7
+
8
+ class MiniRocketFlattening(BaseFlattening):
9
+ def __init__(self):
10
+ """Initialize the MiniRocket flattening strategy."""
11
+ self.name = "minirocket"
12
+ self.minirocket = None
13
+ self.emotion_columns = None
14
+ self.scaler = StandardScaler()
15
+ self.min_sequence_length = 9 # MiniRocket requires at least 9 timepoints
16
+ self.features_per_emotion = None # Will store number of features per emotion
17
+
18
+ def _initialize_minirocket(self, data_3d):
19
+ """
20
+ Initialize and fit the MiniRocket transformer.
21
+
22
+ Args:
23
+ data_3d: 3D array with shape (n_instances, n_columns, n_timepoints)
24
+ """
25
+ try:
26
+ self.minirocket = MiniRocket(
27
+ random_state=42,
28
+ n_jobs=1,
29
+ num_kernels=84,
30
+ max_dilations_per_kernel=32
31
+ )
32
+ # print(f"Initializing MiniRocket with data shape: {data_3d.shape}")
33
+ self.minirocket.fit(data_3d)
34
+ # Calculate features per emotion after first transform
35
+ sample_transform = self.minirocket.transform(data_3d)
36
+ total_features = sample_transform.shape[1]
37
+ self.features_per_emotion = total_features // len(self.emotion_columns)
38
+
39
+ except Exception as e:
40
+ print(f"Error initializing MiniRocket: {str(e)}")
41
+ raise
42
+
43
+ def _prepare_3d_array(self, group_data: pd.DataFrame) -> np.ndarray:
44
+ """
45
+ Convert patient group data to 3D array format required by MiniRocket with improved error handling.
46
+
47
+ Args:
48
+ group_data: DataFrame containing one patient's data
49
+
50
+ Returns:
51
+ 3D numpy array with shape (n_instances, n_columns, n_timepoints)
52
+ """
53
+ try:
54
+ if self.emotion_columns is None:
55
+ self.emotion_columns = [col for col in group_data.columns
56
+ if col not in ['start', 'end', 'filename']]
57
+
58
+ # Make sure we have data
59
+ if group_data.empty:
60
+ raise ValueError("Empty data array")
61
+
62
+ data = group_data[self.emotion_columns].values
63
+ if data.size == 0:
64
+ raise ValueError("Empty data array after selecting columns")
65
+
66
+ # Handle sequence length, ensuring minimum of 9 timepoints
67
+ current_length = data.shape[0]
68
+ if current_length < self.min_sequence_length:
69
+ # Padding for sequences shorter than minimum
70
+ pad_length = self.min_sequence_length - current_length
71
+ data = np.pad(data, ((0, pad_length), (0, 0)), mode='constant')
72
+ elif current_length > self.min_sequence_length:
73
+ # Truncate longer sequences
74
+ data = data[:self.min_sequence_length]
75
+
76
+ # Normalize the data with safeguards against division by zero
77
+ data_mean = data.mean(axis=0)
78
+ data_std = data.std(axis=0)
79
+ # Add small epsilon to avoid division by zero
80
+ data_std = np.where(data_std < 1e-8, 1e-8, data_std)
81
+ data = (data - data_mean) / data_std
82
+ data = data.astype(np.float32)
83
+
84
+ # Create 3D array
85
+ data_3d = np.zeros((1, len(self.emotion_columns), data.shape[0]), dtype=np.float32)
86
+ for i in range(len(self.emotion_columns)):
87
+ data_3d[0, i, :] = data[:, i]
88
+
89
+ return data_3d
90
+
91
+ except Exception as e:
92
+ logger.error(f"Error in _prepare_3d_array: {str(e)}")
93
+ # Return a safe default array if processing fails
94
+ empty_array = np.zeros((1, len(self.emotion_columns) if self.emotion_columns else 1,
95
+ self.min_sequence_length), dtype=np.float32)
96
+ return empty_array
97
+
98
+ def flatten_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
99
+ """Implement MiniRocket flattening with emotion-prefixed features."""
100
+ grouped = df.groupby('filename')
101
+ all_features = []
102
+ first_batch = True
103
+
104
+ # Process each patient
105
+ for filename, group in grouped:
106
+ try:
107
+ # Prepare 3D array for this patient
108
+ data_3d = self._prepare_3d_array(group)
109
+
110
+ # Initialize MiniRocket with first batch
111
+ if first_batch:
112
+ self._initialize_minirocket(data_3d)
113
+ first_batch = False
114
+
115
+ # Transform the data
116
+ features = self.minirocket.transform(data_3d)
117
+ features_array = features.to_numpy()
118
+ if len(features_array.shape) > 1:
119
+ features_array = features_array[0]
120
+
121
+ # Create features dictionary with emotion prefixes
122
+ features_dict = {'filename': filename}
123
+
124
+ # Distribute features among emotions
125
+ for emotion_idx, emotion in enumerate(self.emotion_columns):
126
+ start_idx = emotion_idx * self.features_per_emotion
127
+ end_idx = start_idx + self.features_per_emotion
128
+
129
+ # Add features for this emotion
130
+ emotion_features = {
131
+ f'{emotion}_feature_{i}': value
132
+ for i, value in enumerate(features_array[start_idx:end_idx])
133
+ }
134
+ features_dict.update(emotion_features)
135
+
136
+ all_features.append(features_dict)
137
+
138
+ except Exception as e:
139
+ print(f"Error processing patient {filename}: {str(e)}")
140
+ continue
141
+
142
+ # Create DataFrame from all features
143
+ result_df = pd.DataFrame(all_features)
144
+
145
+ # Scale the features if we have any
146
+ feature_cols = [col for col in result_df.columns if 'feature_' in col]
147
+ if feature_cols:
148
+ features_array = result_df[feature_cols].values
149
+ scaled_features = self.scaler.fit_transform(features_array)
150
+
151
+ # Update the DataFrame with scaled features
152
+ for i, col in enumerate(feature_cols):
153
+ result_df[col] = scaled_features[:, i]
154
+
155
+ return result_df
preprocessing/flattening_statistical.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from preprocessing.flattening_base import BaseFlattening
3
+
4
+
5
+ class StatisticalFlattening(BaseFlattening):
6
+ def __init__(self):
7
+ super().__init__()
8
+ self.name = "statistical"
9
+
10
+ def flatten_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
11
+ """
12
+ Implement statistical flattening for the combined DataFrame.
13
+
14
+ Args:
15
+ df: DataFrame containing all inference results
16
+
17
+ Returns:
18
+ DataFrame containing statistical features for each unique filename
19
+ """
20
+ grouped = df.groupby('filename')
21
+
22
+ results = []
23
+ for filename, group in grouped:
24
+ stats = {'filename': filename}
25
+
26
+ # Calculate statistics for each feature column
27
+ for column in group.columns:
28
+ if column in ['start', 'end', 'filename']:
29
+ continue
30
+
31
+ # Calculate statistical features
32
+ stats[f'{column}_mean'] = group[column].mean()
33
+ stats[f'{column}_std'] = group[column].std()
34
+ stats[f'{column}_max'] = group[column].max()
35
+ stats[f'{column}_min'] = group[column].min()
36
+ stats[f'{column}_mode'] = group[column].mode().iloc[0]
37
+ stats[f'{column}_skewness'] = group[column].skew()
38
+ stats[f'{column}_kurtosis'] = group[column].kurtosis()
39
+ stats[f'{column}_median'] = group[column].median()
40
+ stats[f'{column}_q1'] = group[column].quantile(0.25)
41
+ stats[f'{column}_q3'] = group[column].quantile(0.75)
42
+ stats[f'{column}_iqr'] = group[column].quantile(0.75) - group[column].quantile(0.25)
43
+ stats[f'{column}_range'] = group[column].max() - group[column].min()
44
+ stats[f'{column}_variance'] = group[column].var()
45
+ stats[f'{column}_sem'] = group[column].sem()
46
+ stats[f'{column}_cv'] = group[column].std() / group[column].mean()
47
+
48
+ results.append(stats)
49
+
50
+ return pd.DataFrame(results)
utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # utils/__init__.py
2
+ from utils.config import ProjectConfig
3
+ from utils.config_types import ChunkLength
4
+
5
+ __all__ = ["ProjectConfig", "ChunkLength"]
utils/config.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+ from .config_types import (
4
+ MachinePaths, Labels, ModelPaths, Settings,
5
+ Data, ChunkLength, ChunkConfig
6
+ )
7
+
8
+
9
+ class ProjectConfig:
10
+ """Project configuration class."""
11
+
12
+ def __init__(self):
13
+ self.machine_paths = MachinePaths()
14
+ self.paths = {
15
+ 'labels': Labels(),
16
+ 'models': ModelPaths(),
17
+ 'data': Data()
18
+ }
19
+ self.settings = Settings()
20
+
21
+ def get_full_path(self, relative_path: str) -> str:
22
+ """Combine machine base path with relative path"""
23
+ base_path = self.machine_paths.get_current_machine_path()
24
+ return os.path.join(base_path, relative_path)
25
+
26
+ def get_model_length_path(self, model_name: str, length: ChunkLength) -> str:
27
+ """Get the folder location for a specific model and length"""
28
+ base_path = self.machine_paths.get_current_machine_path()
29
+ folder_name = f"{model_name}-{str(length)}"
30
+ return os.path.join(base_path, 'processed', folder_name)
31
+
32
+ def get_chunk_params(self, length: ChunkLength) -> ChunkConfig:
33
+ """Get chunk configuration for a specific length."""
34
+ if length not in self.paths['data'].chunks:
35
+ raise ValueError(f"Invalid length {length}. Must be one of {list(self.paths['data'].chunks.keys())}")
36
+ return self.paths['data'].chunks[length]
37
+
38
+ def get_available_lengths(self) -> List[ChunkLength]:
39
+ """Get list of available chunk lengths."""
40
+ return list(self.paths['data'].chunks.keys())
41
+
42
+ def get_model_path(self, name, length, task):
43
+ pass
utils/config_types.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+ from enum import Enum
4
+ from pathlib import Path
5
+ from typing import Dict
6
+ import torch
7
+
8
+
9
+ class ChunkLength(str, Enum):
10
+ """Enumeration of available chunk lengths."""
11
+ LENGTH_1_5 = "1_5"
12
+ LENGTH_3_5 = "3_5"
13
+ LENGTH_4_5 = "4_5"
14
+
15
+ def __str__(self) -> str:
16
+ return self.value
17
+
18
+
19
+ @dataclass
20
+ class ChunkConfig:
21
+ """Configuration for audio chunk processing."""
22
+ length: ChunkLength
23
+ segment_duration: int # in milliseconds
24
+ overlap_duration: int # in milliseconds
25
+
26
+
27
+ @dataclass
28
+ class MachinePaths:
29
+ """Base paths for different machines/environments."""
30
+ default: Path = Path('./data')
31
+
32
+ def get_current_machine_path(self) -> Path:
33
+ """Get base path for current machine based on environment variable."""
34
+ # Use EMOTION_DATA_PATH environment variable if set, otherwise use default
35
+ env_path = os.getenv('EMOTION_DATA_PATH')
36
+ if env_path:
37
+ return Path(env_path)
38
+ return self.default
39
+
40
+
41
+ @dataclass
42
+ class Labels:
43
+ """Paths to label files."""
44
+ demographic: str = 'processed/labels/demographic.csv'
45
+ multi: str = 'processed/labels/multiclass.csv'
46
+ regression: str = 'processed/labels/regression.csv'
47
+
48
+
49
+ @dataclass
50
+ class ModelPaths:
51
+ """Paths for model storage."""
52
+ base_dir: str = 'models/'
53
+ checkpoints: str = 'models/checkpoints/'
54
+ configs: str = 'models/configs/'
55
+
56
+
57
+ @dataclass
58
+ class Settings:
59
+ """General framework settings."""
60
+ seed: int = 42
61
+ batch_size: int = 16
62
+ num_workers: int = -1
63
+ device: str = None
64
+
65
+ def __post_init__(self):
66
+ if self.device is None:
67
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
68
+
69
+
70
+ @dataclass
71
+ class Data:
72
+ """Data paths and chunk configurations."""
73
+ raw: str = 'raw/'
74
+ processed: str = 'processed/inference/'
75
+ table: str = 'processed/tables/'
76
+ splits: str = 'processed/splits/'
77
+ results: str = 'results/'
78
+ chunks: Dict[ChunkLength, ChunkConfig] = None
79
+
80
+ def __post_init__(self):
81
+ if self.chunks is None:
82
+ self.chunks = {
83
+ ChunkLength.LENGTH_1_5: ChunkConfig(
84
+ length=ChunkLength.LENGTH_1_5,
85
+ segment_duration=1500,
86
+ overlap_duration=500
87
+ ),
88
+ ChunkLength.LENGTH_3_5: ChunkConfig(
89
+ length=ChunkLength.LENGTH_3_5,
90
+ segment_duration=3500,
91
+ overlap_duration=500
92
+ ),
93
+ ChunkLength.LENGTH_4_5: ChunkConfig(
94
+ length=ChunkLength.LENGTH_4_5,
95
+ segment_duration=4500,
96
+ overlap_duration=500
97
+ )
98
+ }
utils/logger.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ # Configure logging
4
+ # logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
5
+ logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
6
+ logger = logging.getLogger(__name__)
utils/tabular_transformation.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import pandas as pd
3
+ from sklearn.model_selection import StratifiedKFold, KFold
4
+
5
+ from utils.config_types import LearningTask
6
+ from utils.logger import logger
7
+ from preprocessing.flattening_base import BaseFlattening
8
+
9
+
10
+ class TabularTransformer:
11
+ """Static class for transforming tabular data in various ways."""
12
+
13
+ @staticmethod
14
+ def get_id(filename: str) -> int:
15
+ id_value = filename.split('_')[-1].split('-')[0]
16
+ return int(id_value)
17
+
18
+ @staticmethod
19
+ def attach_demographic(df_data: pd.DataFrame, demographic_path: str) -> pd.DataFrame:
20
+ """
21
+ Attaches demographic information to the data DataFrame.
22
+
23
+ Args:
24
+ df_data: DataFrame containing the processed data
25
+ demographic_path: Path to the demographic CSV file
26
+
27
+ Returns:
28
+ DataFrame with demographic information attached
29
+ """
30
+ df_demographic = pd.read_csv(demographic_path)
31
+
32
+ df_data['id'] = df_data['filename'].apply(TabularTransformer.get_id)
33
+ merged_df = pd.merge(df_data, df_demographic, on='id')
34
+ merged_df.drop(columns=['id'], inplace=True)
35
+ return merged_df
36
+
37
+ @staticmethod
38
+ def transform_to_binary(df: pd.DataFrame) -> pd.DataFrame:
39
+ """
40
+ Transforms the DataFrame into a binary classification problem.
41
+
42
+ Args:
43
+ df: DataFrame with demographic information
44
+
45
+ Returns:
46
+ DataFrame with binary target labels (0 or 1)
47
+ """
48
+ df['target'] = df['filename'].apply(lambda x: 1 if x.split('_')[0] == 'dementia' else 0)
49
+ return df
50
+
51
+ @staticmethod
52
+ def transform_to_multiclass(df: pd.DataFrame, path_to_label: str) -> pd.DataFrame:
53
+ """
54
+ Transforms the DataFrame into a multiclass classification problem.
55
+
56
+ Args:
57
+ df: DataFrame with demographic information
58
+ path_to_label: Path to the multiclass labels CSV file
59
+
60
+ Returns:
61
+ DataFrame with multiclass target labels
62
+ """
63
+ df_labels = pd.read_csv(path_to_label)
64
+ df_labels = df_labels[['ID', 'target_multi']]
65
+ df_labels.rename(columns={'target_multi': 'target'}, inplace=True)
66
+ df['ID'] = df['filename']
67
+ df = pd.merge(df, df_labels, on='ID')
68
+ df.drop(columns=['ID'], inplace=True)
69
+ return df
70
+
71
+ @staticmethod
72
+ def transform_to_regression(df: pd.DataFrame, path_to_label: str) -> pd.DataFrame:
73
+ """
74
+ Transforms the DataFrame into a regression problem.
75
+
76
+ Args:
77
+ df: DataFrame with demographic information
78
+ path_to_label: Path to the regression labels CSV file
79
+
80
+ Returns:
81
+ DataFrame with regression target labels
82
+ """
83
+ df_labels = pd.read_csv(path_to_label)
84
+ df_labels.rename(columns={'mms': 'target'}, inplace=True)
85
+
86
+ df['id'] = df['filename'].apply(TabularTransformer.get_id)
87
+ df = pd.merge(df, df_labels, on='id')
88
+ df.drop(columns=['id'], inplace=True)
89
+ return df
90
+
91
+ @staticmethod
92
+ def process_and_transform(df: pd.DataFrame,
93
+ transform_func: callable,
94
+ label_path_func: Optional[callable]) -> pd.DataFrame:
95
+ """Process and transform data based on the task requirements."""
96
+ if label_path_func:
97
+ return transform_func(df.copy(), label_path_func())
98
+ return transform_func(df.copy())
99
+
100
+ @staticmethod
101
+ def get_flattened_data(strategy: BaseFlattening, file_path: str, demographic_path: Optional[str] = None) -> pd.DataFrame:
102
+ """Get flattened data with optional demographic information."""
103
+ logger.info(f"Flattening data with strategy: {strategy.get_name()}")
104
+ # Flatten the data
105
+ df_inference = pd.read_csv(file_path) # Read the inference data
106
+ df_flatten = strategy.flatten_dataframe(df_inference) # Invoke the flattening strategy method
107
+
108
+ # Check if I want to attach demographic information
109
+ if demographic_path:
110
+ return TabularTransformer.attach_demographic(df_flatten, demographic_path)
111
+
112
+ # Return the flattened data
113
+ return df_flatten
114
+
115
+ @staticmethod
116
+ def get_memory_usage(df: pd.DataFrame) -> str:
117
+ """Get memory usage of dataframe in a human-readable format."""
118
+ memory_usage = df.memory_usage(deep=True).sum()
119
+ if memory_usage < 1024:
120
+ return f"{memory_usage} bytes"
121
+ elif memory_usage < 1024 ** 2:
122
+ return f"{memory_usage / 1024:.2f} KB"
123
+ elif memory_usage < 1024 ** 3:
124
+ return f"{memory_usage / 1024 ** 2:.2f} MB"
125
+ else:
126
+ return f"{memory_usage / 1024 ** 3:.2f} GB"
127
+
128
+ @staticmethod
129
+ def generate_folds(inference_data, folds, task, tasks):
130
+ df_folds = inference_data.groupby('filename').agg(list).reset_index()
131
+
132
+ # Check if the task is binary, multiclass, or regression
133
+ df_folds = df_folds[['filename']]
134
+
135
+ # Attach labels
136
+ df_task = TabularTransformer.process_and_transform(
137
+ df_folds,
138
+ tasks[task][0],
139
+ tasks[task][2]
140
+ )
141
+
142
+ # Split into features and target
143
+ target_column = df_task['target']
144
+ file_column = df_task['filename']
145
+
146
+ # Create the folds
147
+ fold_object = None
148
+ if task == LearningTask.BINARY or task == LearningTask.MULTICLASS:
149
+ fold_object = StratifiedKFold(n_splits=folds, shuffle=True, random_state=42)
150
+ elif task == LearningTask.REGRESSION:
151
+ fold_object = KFold(n_splits=folds, shuffle=True, random_state=42)
152
+
153
+ folds_data = []
154
+ for fold, (train_index, test_index) in enumerate(fold_object.split(file_column, target_column)):
155
+ current_fold = {
156
+ 'fold': fold,
157
+ 'train': file_column.iloc[train_index].tolist(),
158
+ 'test': file_column.iloc[test_index].tolist(),
159
+ }
160
+
161
+ folds_data.append(current_fold)
162
+
163
+ # Turn the folds into a DataFrame
164
+ fold_df = pd.DataFrame(folds_data, columns=['fold', 'train', 'test'])
165
+ return fold_df