import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification

# Load the fine-tuned classification model from the current directory and the
# feature extractor from the base checkpoint. A classification head has no
# tokenizer, so Wav2Vec2FeatureExtractor is used rather than Wav2Vec2Processor,
# which would fail to load a tokenizer from facebook/wav2vec2-base.
model = Wav2Vec2ForSequenceClassification.from_pretrained("./")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
model.eval()


def classify(audio):
    waveform, sample_rate = torchaudio.load(audio)

    # Resample to the 16 kHz rate the model expects, if needed.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)

    # Downmix multi-channel recordings to mono; squeeze() alone would leave a
    # stereo clip as a 2-D tensor.
    waveform = waveform.mean(dim=0)

    # Preprocess the raw audio into model inputs.
    inputs = feature_extractor(
        waveform.numpy(), sampling_rate=16000, return_tensors="pt", padding=True
    )

    # Run inference without tracking gradients.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted = torch.argmax(logits, dim=-1).item()

    # Report the human-readable label if the config provides one.
    label = model.config.id2label.get(predicted, str(predicted))
    return f"Predicted Keyword: {label}"


gr.Interface(
    fn=classify,
    inputs=gr.Audio(type="filepath", label="Record from microphone"),
    outputs="text",
    title="Hey Alpha Keyword Spotting",
).launch()
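
# A minimal sketch of a local smoke test, assuming a 16 kHz WAV file named
# "sample.wav" (a hypothetical path) sits next to this script. Run it before
# launch() blocks the process, or in a separate session with the UI disabled.
# print(classify("sample.wav"))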