Update app.py
app.py
CHANGED
@@ -64,7 +64,7 @@ if not audio_files:
 print(f"✅ Found {len(audio_files)} audio files in dataset!")
 
 # ================================
-# 3️⃣ Preprocess Dataset (
+# 3️⃣ Preprocess Dataset (Fixed input_features)
 # ================================
 def load_and_process_audio(audio_path):
     """Loads and processes a single audio file into model format."""
@@ -73,13 +73,13 @@ def load_and_process_audio(audio_path):
     # Resample to 16kHz
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
-    # Convert to model input format
-
+    # Convert to model input format (Fixed key: use input_features instead of input_values)
+    input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
 
-    return
+    return input_features
 
 # Manually create dataset structure
-dataset = [{"
+dataset = [{"input_features": load_and_process_audio(f), "labels": []} for f in audio_files[:100]]  # Load first 100
 
 print(f"✅ Dataset Loaded! Processed {len(dataset)} audio files.")
 
@@ -140,12 +140,12 @@ if audio_file:
     waveform, sample_rate = torchaudio.load(audio_path)
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
-    # Convert audio to model input
-
+    # Convert audio to model input (Fixed key: use input_features)
+    input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
 
     # Perform ASR inference
     with torch.no_grad():
-        input_tensor = torch.tensor([
+        input_tensor = torch.tensor([input_features]).to(device)
         logits = model(input_tensor).logits
         predicted_ids = torch.argmax(logits, dim=-1)
         transcription = processor.batch_decode(predicted_ids)[0]
@@ -164,7 +164,7 @@ if audio_file:
     corrected_input = processor.tokenizer(user_correction).input_ids
 
     # Dynamically add new example to dataset
-    dataset.append({"
+    dataset.append({"input_features": input_features, "labels": corrected_input})
 
     # Perform quick re-training (1 epoch)
     trainer.args.num_train_epochs = 1