# Hey_Alpha_KWS / app.py
# Author: faizandigi009 — "Update app.py" (commit f28bfdd, verified)
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
# Load model and processor.
# The fine-tuned classifier weights are loaded from the Space's repo root
# ("./"); NOTE(review): assumes config.json + weights are present alongside
# this script — confirm in the repo.
model = Wav2Vec2ForSequenceClassification.from_pretrained("./")
# The processor (feature extractor) comes from the base checkpoint the
# classifier was presumably fine-tuned from — TODO confirm it matches training.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
# Inference mode: disables dropout and other train-only behavior.
model.eval()
def classify(audio):
    """Classify a recorded/uploaded audio clip and return the predicted keyword.

    Args:
        audio: Filesystem path to the audio clip supplied by the Gradio
            ``Audio(type="filepath")`` component, or ``None`` when the user
            submits without recording anything.

    Returns:
        A human-readable prediction string.
    """
    # Gradio passes None when no recording was made — fail gracefully
    # instead of letting torchaudio.load crash on a None path.
    if audio is None:
        return "No audio received — please record or upload a clip."

    waveform, sample_rate = torchaudio.load(audio)

    # Downmix multi-channel (e.g. stereo) recordings to mono; the original
    # squeeze() left a 2-D array for stereo input, confusing the processor.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Wav2Vec2 expects 16 kHz input — resample if needed.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)

    # Preprocess: drop the channel dim (now guaranteed to be 1) and featurize.
    inputs = processor(waveform.squeeze(0).numpy(), sampling_rate=16000, return_tensors="pt", padding=True)

    with torch.no_grad():
        logits = model(**inputs).logits
    predicted = torch.argmax(logits, dim=-1).item()

    # Prefer the label name from the model config over a bare class index;
    # fall back to the index if the mapping is missing.
    label = model.config.id2label.get(predicted, str(predicted))
    return f"Predicted Keyword: {label}"
# Wire up the UI: a microphone/file audio input feeding classify(), with the
# prediction rendered as plain text, then start the Gradio server.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Audio(type="filepath", label="Record from microphone"),
    outputs="text",
    title="Hey Alpha Keyword Spotting",
)
demo.launch()