Spaces:
Sleeping
Sleeping
Update tasks/audio.py
Browse files- tasks/audio.py +31 -12
tasks/audio.py
CHANGED
@@ -5,6 +5,9 @@ from sklearn.metrics import accuracy_score
|
|
5 |
import random
|
6 |
import os
|
7 |
import torch
|
|
|
|
|
|
|
8 |
|
9 |
from .utils.evaluation import AudioEvaluationRequest
|
10 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
@@ -14,7 +17,7 @@ load_dotenv()
|
|
14 |
|
15 |
router = APIRouter()
|
16 |
|
17 |
-
DESCRIPTION = "
|
18 |
ROUTE = "/audio"
|
19 |
|
20 |
|
@@ -55,30 +58,46 @@ async def evaluate_audio(request: AudioEvaluationRequest):
|
|
55 |
#--------------------------------------------------------------------------------------------
|
56 |
|
57 |
# Make random predictions (placeholder for actual model inference)
|
58 |
-
def preprocess_audio(example):
|
59 |
-
"""Convert dataset into tensors."""
|
60 |
-
waveform = torch.tensor(example["audio"]["array"], dtype=torch.float32).unsqueeze(0) # Add batch dim
|
61 |
-
label = torch.tensor(example["label"], dtype=torch.long) # Ensure labels are `int64`
|
62 |
-
return waveform, label
|
63 |
-
|
64 |
model_path = "quantized_teacher_m5_static.pth"
|
65 |
model, device = load_model(model_path)
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
-
train_test = train_test.map(preprocess_audio)
|
69 |
test_dataset = train_test.map(preprocess_audio)
|
70 |
|
71 |
-
|
72 |
|
73 |
|
74 |
-
true_labels =
|
75 |
predictions = []
|
76 |
|
77 |
with torch.no_grad():
|
78 |
-
for waveforms, labels in
|
79 |
waveforms, labels = waveforms.to(device), labels.to(device)
|
80 |
|
81 |
-
# Run Model
|
82 |
outputs = model(waveforms)
|
83 |
predicted_label = torch.argmax(F.softmax(outputs, dim=1), dim=1)
|
84 |
true_labels.extend(labels.cpu().numpy())
|
|
|
5 |
import random
|
6 |
import os
|
7 |
import torch
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
from Model_Loader import load_model
|
10 |
+
|
11 |
|
12 |
from .utils.evaluation import AudioEvaluationRequest
|
13 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
|
|
17 |
|
18 |
router = APIRouter()
|
19 |
|
20 |
+
DESCRIPTION = "Quantized M5"
|
21 |
ROUTE = "/audio"
|
22 |
|
23 |
|
|
|
58 |
#--------------------------------------------------------------------------------------------
|
59 |
|
60 |
# Make random predictions (placeholder for actual model inference)
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
model_path = "quantized_teacher_m5_static.pth"
|
62 |
model, device = load_model(model_path)
|
63 |
|
64 |
+
def preprocess_audio(example, target_length=32000):
|
65 |
+
"""
|
66 |
+
Convert dataset into tensors:
|
67 |
+
- Convert to tensor
|
68 |
+
- Normalize waveform
|
69 |
+
- Pad/truncate to `target_length`
|
70 |
+
"""
|
71 |
+
waveform = torch.tensor(example["audio"]["array"], dtype=torch.float32).unsqueeze(0) # Add batch dim
|
72 |
+
|
73 |
+
# Normalize waveform
|
74 |
+
waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-6)
|
75 |
+
|
76 |
+
# Pad or truncate to fixed length
|
77 |
+
if waveform.shape[1] < target_length:
|
78 |
+
pad = torch.zeros(1, target_length - waveform.shape[1])
|
79 |
+
waveform = torch.cat((waveform, pad), dim=1) # Pad
|
80 |
+
else:
|
81 |
+
waveform = waveform[:, :target_length] # Truncate
|
82 |
+
|
83 |
+
label = torch.tensor(example["label"], dtype=torch.long) # Ensure int64
|
84 |
+
return {"waveform": waveform, "label": label}
|
85 |
+
|
86 |
+
|
87 |
|
88 |
+
train_test = train_test.map(preprocess_audio, batched=True)
|
89 |
test_dataset = train_test.map(preprocess_audio)
|
90 |
|
91 |
+
train_loader = DataLoader(train_test, batch_size=32, shuffle=True)
|
92 |
|
93 |
|
94 |
+
true_labels = train_dataset["label"]
|
95 |
predictions = []
|
96 |
|
97 |
with torch.no_grad():
|
98 |
+
for waveforms, labels in train_loader:
|
99 |
waveforms, labels = waveforms.to(device), labels.to(device)
|
100 |
|
|
|
101 |
outputs = model(waveforms)
|
102 |
predicted_label = torch.argmax(F.softmax(outputs, dim=1), dim=1)
|
103 |
true_labels.extend(labels.cpu().numpy())
|