tahirsher committed on
Commit
7d7504d
·
verified ·
1 Parent(s): 1bb8243

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -64,7 +64,7 @@ if not audio_files:
64
  print(f"✅ Found {len(audio_files)} audio files in dataset!")
65
 
66
  # ================================
67
- # 3️⃣ Preprocess Dataset (Manually)
68
  # ================================
69
  def load_and_process_audio(audio_path):
70
  """Loads and processes a single audio file into model format."""
@@ -73,13 +73,13 @@ def load_and_process_audio(audio_path):
73
  # Resample to 16kHz
74
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
75
 
76
- # Convert to model input format
77
- input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
78
 
79
- return input_values
80
 
81
  # Manually create dataset structure
82
- dataset = [{"input_values": load_and_process_audio(f), "labels": []} for f in audio_files[:100]] # Load first 100
83
 
84
  print(f"✅ Dataset Loaded! Processed {len(dataset)} audio files.")
85
 
@@ -140,12 +140,12 @@ if audio_file:
140
  waveform, sample_rate = torchaudio.load(audio_path)
141
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
142
 
143
- # Convert audio to model input
144
- input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
145
 
146
  # Perform ASR inference
147
  with torch.no_grad():
148
- input_tensor = torch.tensor([input_values]).to(device)
149
  logits = model(input_tensor).logits
150
  predicted_ids = torch.argmax(logits, dim=-1)
151
  transcription = processor.batch_decode(predicted_ids)[0]
@@ -164,7 +164,7 @@ if audio_file:
164
  corrected_input = processor.tokenizer(user_correction).input_ids
165
 
166
  # Dynamically add new example to dataset
167
- dataset.append({"input_values": input_values, "labels": corrected_input})
168
 
169
  # Perform quick re-training (1 epoch)
170
  trainer.args.num_train_epochs = 1
 
64
  print(f"✅ Found {len(audio_files)} audio files in dataset!")
65
 
66
  # ================================
67
+ # 3️⃣ Preprocess Dataset (Fixed input_features)
68
  # ================================
69
  def load_and_process_audio(audio_path):
70
  """Loads and processes a single audio file into model format."""
 
73
  # Resample to 16kHz
74
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
75
 
76
+ # Convert to model input format (Fixed key: use input_features instead of input_values)
77
+ input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
78
 
79
+ return input_features
80
 
81
  # Manually create dataset structure
82
+ dataset = [{"input_features": load_and_process_audio(f), "labels": []} for f in audio_files[:100]] # Load first 100
83
 
84
  print(f"✅ Dataset Loaded! Processed {len(dataset)} audio files.")
85
 
 
140
  waveform, sample_rate = torchaudio.load(audio_path)
141
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
142
 
143
+ # Convert audio to model input (Fixed key: use input_features)
144
+ input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
145
 
146
  # Perform ASR inference
147
  with torch.no_grad():
148
+ input_tensor = torch.tensor([input_features]).to(device)
149
  logits = model(input_tensor).logits
150
  predicted_ids = torch.argmax(logits, dim=-1)
151
  transcription = processor.batch_decode(predicted_ids)[0]
 
164
  corrected_input = processor.tokenizer(user_correction).input_ids
165
 
166
  # Dynamically add new example to dataset
167
+ dataset.append({"input_features": input_features, "labels": corrected_input})
168
 
169
  # Perform quick re-training (1 epoch)
170
  trainer.args.num_train_epochs = 1