CuriousMonkey7 committed on
Commit
8715f24
·
verified ·
1 Parent(s): 97a86a0

remove compute_performance.py

Browse files
Files changed (2) hide show
  1. .gitignore +2 -1
  2. compute_performance.py +0 -99
.gitignore CHANGED
@@ -1 +1,2 @@
1
- *.feather
 
 
1
+ *.feather
2
+ *.py
compute_performance.py DELETED
@@ -1,99 +0,0 @@
1
- import numpy as np
2
- import pandas as pd
3
- from sklearn.metrics import roc_auc_score
4
-
5
- import pandas as pd
6
- from concurrent.futures import ProcessPoolExecutor
7
- from tqdm import tqdm
8
- from silero_vad import read_audio, get_speech_timestamps
9
- from functools import partial
10
-
11
- from pathlib import Path
12
- str(Path().resolve() / "silero-vad/src/silero-vad")
13
- from silero_vad import utils_vad
14
- # from utils_vad import init_jit_model, OnnxWrapper
15
- import torch
16
- torch.set_num_threads(1)
17
-
18
def load_silero_vad(onnx=False, model_file_path=None):
    """Load a VAD model from ``model_file_path``.

    When ``onnx`` is truthy the file is wrapped in ``utils_vad.OnnxWrapper``
    (with ``force_onnx_cpu=True``); otherwise it is loaded through
    ``utils_vad.init_jit_model``. The loaded model object is returned.
    """
    if not onnx:
        return utils_vad.init_jit_model(model_file_path)
    return utils_vad.OnnxWrapper(model_file_path, force_onnx_cpu=True)
25
-
26
def init_worker(model_file_path):
    """Pool initializer: load the non-ONNX model once per worker process.

    The loaded model is stored in the module-level global ``model`` so that
    tasks executed later in the same process can reuse it.
    """
    global model
    model = load_silero_vad(model_file_path=model_file_path, onnx=False)
30
-
31
-
32
-
33
def get_vad(file, threshold):
    """Run speech detection on one audio file with the global ``model``.

    Returns ``None`` when ``file`` is a missing value (per ``pd.isna``);
    otherwise returns whatever ``get_speech_timestamps`` produces for the
    audio, requested in seconds and at the given ``threshold``.
    """
    if pd.isna(file):
        return None
    audio = read_audio(file)
    return get_speech_timestamps(
        audio,
        model,
        threshold=threshold,
        return_seconds=True,
    )
44
-
45
def process_vad_parallel(df, threshold, column_name, model_file_path, max_workers=8):
    """Run ``get_vad`` over ``df["audio_path"]`` in a process pool.

    Args:
        df: DataFrame with an ``audio_path`` column; it is modified in place.
        threshold: Threshold forwarded to ``get_vad`` for every file.
        column_name: Name of the column that receives the per-row results.
        model_file_path: Model file each worker loads via ``init_worker``.
        max_workers: Number of worker processes (defaults to the previously
            hard-coded value of 8).

    Returns:
        The same ``df`` object, with ``column_name`` added.
    """
    with ProcessPoolExecutor(
        max_workers=max_workers,
        initializer=init_worker,
        initargs=(model_file_path,),
    ) as executor:
        # Futures are kept in submission order so results line up with rows.
        futures = [
            executor.submit(get_vad, file, threshold)
            for file in df["audio_path"]
        ]
        results = [
            future.result()
            for future in tqdm(
                futures, total=len(futures), desc=f"Processing {column_name}"
            )
        ]
    df[column_name] = results
    return df
53
-
54
-
55
-
56
def create_frame_labels(segments, duration, frame_size=0.01):
    """Convert time segments into a per-frame 0/1 label array.

    Args:
        segments: Iterable of mappings with ``'start'`` and ``'end'`` times.
        duration: Total time span to cover, in the same unit as the segments.
        frame_size: Length of one frame, in the same unit.

    Returns:
        A float NumPy array with one entry per frame; frames covered by any
        segment are 1, all others 0.
    """
    # A tiny epsilon keeps binary floating-point error from pushing an exact
    # frame boundary one frame down (0.29 / 0.01 == 28.999999999999996).
    eps = 1e-9

    def to_index(t):
        # Clamp at 0 so a negative time cannot wrap around as a negative
        # slice index.
        return max(0, int(t / frame_size + eps))

    frames = np.zeros(to_index(duration))
    for seg in segments:
        frames[to_index(seg['start']):to_index(seg['end'])] = 1
    return frames
63
-
64
def _as_segments(cell):
    """Return ``cell`` as a list of segments, or ``None`` if it is missing."""
    # Missing cells show up as None (get_vad's result for a missing path) or
    # as a float NaN; neither is a usable segment list.
    if cell is None or isinstance(cell, float):
        return None
    return list(cell)


def compute_auc_roc(df, actual_col, predicted_col, frame_size=0.01):
    """Compute a frame-level ROC AUC between two segment columns.

    Each row is framed on its own timeline and the per-row label arrays are
    concatenated before scoring. Pooling all rows onto one timeline would let
    segments from different rows overlap and be scored against each other.

    Args:
        df: DataFrame holding both columns.
        actual_col: Column with ground-truth segments (``'start'``/``'end'``).
        predicted_col: Column with predicted segments in the same format.
        frame_size: Frame length passed to ``create_frame_labels``.

    Returns:
        The value returned by ``roc_auc_score`` for the concatenated frames.

    Raises:
        ValueError: If no row has usable segments in both columns.
    """
    gt_parts = []
    pred_parts = []
    for actual_cell, predicted_cell in zip(df[actual_col], df[predicted_col]):
        actual = _as_segments(actual_cell)
        predicted = _as_segments(predicted_cell)
        # Rows missing either side cannot be scored, so they are skipped.
        if actual is None or predicted is None:
            continue
        end_times = [seg['end'] for seg in actual]
        end_times.extend(seg['end'] for seg in predicted)
        if not end_times:
            continue
        # The real audio length is not available here; the latest segment end
        # in the row is used as the row's duration.
        duration = max(end_times)
        gt_parts.append(create_frame_labels(actual, duration, frame_size))
        pred_parts.append(create_frame_labels(predicted, duration, frame_size))

    if not gt_parts:
        raise ValueError("no rows with both ground-truth and predicted segments")

    return roc_auc_score(np.concatenate(gt_parts), np.concatenate(pred_parts))
75
-
76
-
77
-
78
# Entry-point guard: process_vad_parallel starts a ProcessPoolExecutor, and
# under the "spawn" start method each worker re-imports this module. Without
# the guard every worker would re-run the whole evaluation.
if __name__ == "__main__":
    df = pd.read_feather("./val.feather")

    # First model, evaluated at both thresholds.
    model_file_path = "/home/sourabh/Desktop/dev/hum-vad/HumAware-VAD/humaware_vad.jit"
    df = process_vad_parallel(df, 0.5, "unhum_vad_output_0.5", model_file_path=model_file_path)
    df = process_vad_parallel(df, 0.9, "unhum_vad_output_0.9", model_file_path=model_file_path)

    # Second model, evaluated at the same thresholds.
    model_file_path = "/home/sourabh/Desktop/dev/hum-vad/.venv/lib/python3.12/site-packages/silero_vad/data/silero_vad.jit"
    df = process_vad_parallel(df, 0.5, "silero_vad_output_0.5", model_file_path=model_file_path)
    df = process_vad_parallel(df, 0.9, "silero_vad_output_0.9", model_file_path=model_file_path)

    # (column, printed label) pairs; the labels reproduce the original
    # output text exactly, including the unlabeled first line.
    reports = (
        ("unhum_vad_output_0.5", "AUC-ROC Score"),
        ("unhum_vad_output_0.9", "AUC-ROC Score unhum_vad_output_0.9"),
        ("silero_vad_output_0.5", "AUC-ROC Score silero_vad_output_0.5"),
        ("silero_vad_output_0.9", "AUC-ROC Score silero_vad_output_0.9"),
    )
    for column, label in reports:
        auc_roc_score = compute_auc_roc(df, "speech_ts", column)
        print(f"{label}: {auc_roc_score:.4f}")