kanyekuthi committed
Commit 5cc5855 · verified · 1 Parent(s): a5ab88c

Update app.py

Files changed (1): app.py (+11, -15)
app.py CHANGED
@@ -1,36 +1,32 @@
 import gradio as gr
 import torch
+from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
 import torchaudio
-from transformers import AutoProcessor, AutoModelForCTC
 
-# Load model and processor
-model_id = "kanyekuthi/dsn_afrispeech"
+model_id = "kanyekuthi/dsn_afrispeech"  # or your correct model repo ID
+
 processor = AutoProcessor.from_pretrained(model_id)
-model = AutoModelForCTC.from_pretrained(model_id)
+model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
 
 def transcribe(audio):
-    # Load and resample audio to 16kHz if needed
     waveform, sr = torchaudio.load(audio)
     if sr != 16000:
         resampler = torchaudio.transforms.Resample(sr, 16000)
         waveform = resampler(waveform)
-
-    # Run model
+
     inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
+
     with torch.no_grad():
-        logits = model(**inputs).logits
-    predicted_ids = torch.argmax(logits, dim=-1)
-    transcription = processor.batch_decode(predicted_ids)[0]
-
+        generated_ids = model.generate(**inputs)
+    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
     return transcription
 
-# Build Gradio interface
 iface = gr.Interface(
     fn=transcribe,
     inputs=gr.Audio(source="microphone", type="filepath"),
     outputs="text",
-    title="DSN Afrispeech Transcriber",
-    description="Speak into your mic and this ASR model will transcribe it."
+    title="Whisper-based ASR Demo"
 )
 
-iface.launch()
+iface.launch()
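
Reviewer note, not part of the commit: the new code keeps one caveat from the old version. torchaudio.load returns a (channels, samples) tensor, so waveform.squeeze() only yields the 1-D audio the processor expects when the recording is mono; a stereo file would still reach the processor as a 2-D tensor. A minimal sketch of a downmix-and-resample helper that transcribe() could use instead, with load_mono_16k as a hypothetical name:

import torch
import torchaudio

def load_mono_16k(path: str) -> torch.Tensor:
    # torchaudio.load returns (waveform, sample_rate); waveform is (channels, samples)
    waveform, sr = torchaudio.load(path)
    if waveform.size(0) > 1:
        # Average the channels to mono so the processor sees 1-D audio
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
    return waveform.squeeze(0)

Gradio microphone input is typically mono, so in practice this mainly matters if the interface is later extended to accept uploaded audio files.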