seronk committed on
Commit
99f1647
·
verified ·
1 Parent(s): 98e71a8

Update tasks/audio.py

Browse files
Files changed (1) hide show
  1. tasks/audio.py +30 -28
tasks/audio.py CHANGED
@@ -60,47 +60,49 @@ async def evaluate_audio(request: AudioEvaluationRequest):
60
  model_path = "quantized_teacher_m5_static.pth"
61
  model, device = load_model(model_path)
62
 
63
- def preprocess_audio(example, target_length=32000):
64
- """
65
- Convert dataset into tensors:
66
- - Convert to tensor
67
- - Normalize waveform
68
- - Pad/truncate to `target_length`
69
- """
70
- waveform = torch.tensor(example["audio"]["array"], dtype=torch.float32).unsqueeze(0) # Add batch dim
71
 
72
- # Normalize waveform
73
- waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-6)
74
 
75
- # Pad or truncate to fixed length
76
- if waveform.shape[1] < target_length:
77
- pad = torch.zeros(1, target_length - waveform.shape[1])
78
- waveform = torch.cat((waveform, pad), dim=1) # Pad
79
- else:
80
- waveform = waveform[:, :target_length] # Truncate
81
 
82
- label = torch.tensor(example["label"], dtype=torch.long) # Ensure int64
83
- return {"waveform": waveform, "label": label}
84
 
85
 
86
 
87
- train_test = train_test.map(preprocess_audio, batched=True)
88
- test_dataset = train_test.map(preprocess_audio)
89
 
90
- train_loader = DataLoader(train_test, batch_size=32, shuffle=True)
91
 
92
 
93
  true_labels = train_dataset["label"]
94
  predictions = []
 
 
95
 
96
- with torch.no_grad():
97
- for waveforms, labels in train_loader:
98
- waveforms, labels = waveforms.to(device), labels.to(device)
99
 
100
- outputs = model(waveforms)
101
- predicted_label = torch.argmax(F.softmax(outputs, dim=1), dim=1)
102
- true_labels.extend(labels.cpu().numpy())
103
- predicted_labels.extend(predicted_label.cpu().numpy())
104
 
105
  #--------------------------------------------------------------------------------------------
106
  # YOUR MODEL INFERENCE STOPS HERE
 
60
  model_path = "quantized_teacher_m5_static.pth"
61
  model, device = load_model(model_path)
62
 
63
+ # def preprocess_audio(example, target_length=32000):
64
+ # """
65
+ # Convert dataset into tensors:
66
+ # - Convert to tensor
67
+ # - Normalize waveform
68
+ # - Pad/truncate to `target_length`
69
+ # """
70
+ # waveform = torch.tensor(example["audio"]["array"], dtype=torch.float32).unsqueeze(0) # Add batch dim
71
 
72
+ # # Normalize waveform
73
+ # waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-6)
74
 
75
+ # # Pad or truncate to fixed length
76
+ # if waveform.shape[1] < target_length:
77
+ # pad = torch.zeros(1, target_length - waveform.shape[1])
78
+ # waveform = torch.cat((waveform, pad), dim=1) # Pad
79
+ # else:
80
+ # waveform = waveform[:, :target_length] # Truncate
81
 
82
+ # label = torch.tensor(example["label"], dtype=torch.long) # Ensure int64
83
+ # return {"waveform": waveform, "label": label}
84
 
85
 
86
 
87
+ # train_test = train_test.map(preprocess_audio, batched=True)
88
+ # test_dataset = train_test.map(preprocess_audio)
89
 
90
+ # train_loader = DataLoader(train_test, batch_size=32, shuffle=True)
91
 
92
 
93
  true_labels = train_dataset["label"]
94
  predictions = []
95
+
96
+ predictions = [random.randint(0, 1) for _ in range(len(true_labels))]
97
 
98
+ # with torch.no_grad():
99
+ # for waveforms, labels in train_loader:
100
+ # waveforms, labels = waveforms.to(device), labels.to(device)
101
 
102
+ # outputs = model(waveforms)
103
+ # predicted_label = torch.argmax(F.softmax(outputs, dim=1), dim=1)
104
+ # true_labels.extend(labels.cpu().numpy())
105
+ # predicted_labels.extend(predicted_label.cpu().numpy())
106
 
107
  #--------------------------------------------------------------------------------------------
108
  # YOUR MODEL INFERENCE STOPS HERE