Spaces:
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor

# Load the fine-tuned model from the Space root and the matching processor
model = Wav2Vec2ForSequenceClassification.from_pretrained("./")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
model.eval()

def classify(audio):
    # Load the recorded audio from the temporary file Gradio passes in
    waveform, sample_rate = torchaudio.load(audio)

    # Resample to 16 kHz if needed (the rate wav2vec 2.0 expects)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)

    # Preprocess into model inputs
    inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)

    # Run inference and pick the most likely class index
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted = torch.argmax(logits, dim=-1).item()
    return f"Predicted Keyword: {predicted}"
# Build the Gradio interface with an audio input and launch it
gr.Interface(
    fn=classify,
    inputs=gr.Audio(type="filepath", label="Record from microphone"),
    outputs="text",
    title="Hey Alpha Keyword Spotting",
).launch()
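
The app returns the raw class index, so the output reads like "Predicted Keyword: 0" rather than the keyword itself. If the fine-tuned checkpoint was saved with its label names, the index can be turned into a readable keyword through the config's id2label mapping; a minimal sketch, assuming the checkpoint in "./" actually stores that mapping (the example labels below are hypothetical, not taken from the Space):

# Optional: map the predicted index to a human-readable keyword.
# Assumes the checkpoint was saved with an id2label mapping in its config,
# e.g. {0: "hey_alpha", 1: "other"} (hypothetical labels); otherwise
# transformers falls back to generic names such as "LABEL_0".
def index_to_keyword(index: int) -> str:
    id2label = getattr(model.config, "id2label", None) or {}
    return id2label.get(index, f"LABEL_{index}")

# Inside classify(), the return line would then become:
#     return f"Predicted Keyword: {index_to_keyword(predicted)}"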