seronk committed on
Commit
5c293a7
·
verified ·
1 Parent(s): 163b2c6

Update tasks/audio.py

Browse files
Files changed (1) hide show
  1. tasks/audio.py +28 -0
tasks/audio.py CHANGED
@@ -4,6 +4,7 @@ from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
  import random
6
  import os
 
7
 
8
  from .utils.evaluation import AudioEvaluationRequest
9
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
@@ -54,7 +55,34 @@ async def evaluate_audio(request: AudioEvaluationRequest):
54
  #--------------------------------------------------------------------------------------------
55
 
56
  # Make random predictions (placeholder for actual model inference)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  true_labels = test_dataset["label"]
 
 
 
 
 
 
 
 
 
 
 
58
  predictions = [random.randint(0, 1) for _ in range(len(true_labels))]
59
 
60
  #--------------------------------------------------------------------------------------------
 
4
  from sklearn.metrics import accuracy_score
5
  import random
6
  import os
7
+ import torch
8
 
9
  from .utils.evaluation import AudioEvaluationRequest
10
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
55
  #--------------------------------------------------------------------------------------------
56
 
57
  # Make random predictions (placeholder for actual model inference)
58
+ def preprocess_audio(example):  # builds a (waveform, label) tensor pair from one example mapping
59
+ """Convert dataset into tensors."""
60
+ waveform = torch.tensor(example["audio"]["array"], dtype=torch.float32).unsqueeze(0) # unsqueeze(0) prepends a size-1 dimension (original note: batch dim — TODO confirm it is not meant as a channel dim once a DataLoader batches the samples)
61
+ label = torch.tensor(example["label"], dtype=torch.long) # torch.long is the 64-bit integer dtype (int64)
62
+ return waveform, label # NOTE(review): returns a tuple; if this is passed to map() on a Hugging Face `datasets` object, map() expects a dict (or None) — confirm
63
+
64
+ model_path = "quantized_teacher_m5_static.pth"
65
+ model, device = load_model(model_path)
66
+
67
+
68
+ train_test = train_test.map(preprocess_audio)
69
+ test_dataset = train_test.map(preprocess_audio)
70
+
71
+ test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
72
+
73
+
74
  true_labels = test_dataset["label"]
75
+
76
+ with torch.no_grad():
77
+ for waveforms, labels in test_loader:
78
+ waveforms, labels = waveforms.to(device), labels.to(device)
79
+
80
+ # Run Model
81
+ outputs = model(waveforms)
82
+ predicted_label = torch.argmax(F.softmax(outputs, dim=1), dim=1)
83
+
84
+ true_labels.extend(labels.cpu().numpy())
85
+ predicted_labels.extend(predicted_label.cpu().numpy())
86
  predictions = [random.randint(0, 1) for _ in range(len(true_labels))]
87
 
88
  #--------------------------------------------------------------------------------------------