Spaces:
Running
Running
import os
import tempfile

import streamlit as st
import torch
import torchaudio
from pydub import AudioSegment
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification
@st.cache_resource
def load_model(model_id="superb/wav2vec2-base-superb-er"):
    """Load (once) the feature extractor and emotion-recognition model.

    analyze_emotion() calls this for every uploaded file, and Streamlit re-runs
    the whole script on each interaction, so without caching the checkpoint was
    re-instantiated every time. ``st.cache_resource`` keeps a single shared
    instance per ``model_id`` across reruns and sessions.

    Args:
        model_id: Hugging Face Hub id of a Wav2Vec2 sequence-classification
            checkpoint. Defaults to the checkpoint this app was built for.

    Returns:
        ``(extractor, model)``, with the model already in eval mode.
    """
    extractor = AutoFeatureExtractor.from_pretrained(model_id)
    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
    # Inference only: switch off training-time behaviour such as dropout.
    model.eval()
    return extractor, model
def convert_to_wav(uploaded_file):
    """Decode an audio file and save it as a 16 kHz mono WAV temp file.

    Args:
        uploaded_file: Path or file-like object that pydub can decode.

    Returns:
        Path of the newly created ``.wav`` file. The caller owns the file and
        is responsible for deleting it.
    """
    audio = AudioSegment.from_file(uploaded_file)
    audio = audio.set_frame_rate(16000).set_channels(1)
    # NamedTemporaryFile creates the file atomically, unlike the deprecated
    # tempfile.mktemp(), which only returns a name and leaves a race window.
    # Writing through a handle we own also guarantees it is closed before the
    # path is handed back.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        with tmp:
            audio.export(tmp, format="wav")
    except Exception:
        # Don't leave a half-written file behind; the `with` has closed it.
        os.remove(tmp.name)
        raise
    return tmp.name
# Sample rate the extractor is told the audio has (convert_to_wav targets it).
_MODEL_SR = 16000

# Display names for label codes read from the checkpoint's config.id2label.
# NOTE(review): for superb/wav2vec2-base-superb-er the codes are presumably
# "neu"/"hap"/"ang"/"sad"; any other code is displayed unchanged.
_LABEL_NAMES = {"neu": "neutral", "hap": "happy", "ang": "angry", "sad": "sad"}

# Fallback label order, used only when get_emotion_label() receives no labels.
# It must follow the classifier's output-index order. The published id2label
# for superb/wav2vec2-base-superb-er is 0=neu, 1=hap, 2=ang, 3=sad; the old
# list ("angry", "happy", "neutral", "sad") swapped angry and neutral.
EMOTIONS = ("neutral", "happy", "angry", "sad")


def get_emotion_label(logits, labels=None):
    """Return the most likely label and the softmax scores for one utterance.

    Args:
        logits: 1-D tensor of raw classifier outputs, one per class.
        labels: Class names in output-index order; defaults to ``EMOTIONS``.

    Returns:
        ``(top_label, scores)``, where ``scores`` is a list of probabilities
        aligned with ``labels``.

    Raises:
        ValueError: if the number of labels differs from the number of logits.
    """
    if labels is None:
        labels = EMOTIONS
    scores = torch.softmax(logits, dim=0).tolist()
    if len(scores) != len(labels):
        raise ValueError(f"Got {len(scores)} class scores but {len(labels)} labels")
    top_idx = scores.index(max(scores))
    return labels[top_idx], scores


def _label_names(model, num_labels):
    """Return lowercase display names for the model outputs, in index order."""
    id2label = getattr(model.config, "id2label", None) or {}
    raw = [str(id2label.get(idx, f"label_{idx}")).lower() for idx in range(num_labels)]
    return [_LABEL_NAMES.get(code, code) for code in raw]


def _classify(audio_path, max_duration_sec=30):
    """Run the emotion classifier on the first ``max_duration_sec`` of a file.

    Returns:
        ``(emotion, scores, duration_sec, labels)``: the capitalised top label,
        per-class probabilities, seconds of audio analysed, and the label names
        aligned with ``scores`` (taken from the model's own config).

    Raises:
        ValueError: if the file contains no samples.
    """
    extractor, model = load_model()
    waveform, sr = torchaudio.load(audio_path)
    # Trim (before any resampling) to bound the work on long uploads.
    waveform = waveform[:, : int(sr * max_duration_sec)]
    duration_sec = waveform.size(1) / sr
    if waveform.size(1) == 0:
        raise ValueError("Audio file contains no samples")
    # Downmix to mono; identical to taking channel 0 for mono input.
    mono = waveform.mean(dim=0)
    if sr != _MODEL_SR:
        # The extractor below is told the audio is 16 kHz, so resample rather
        # than silently feeding audio at the wrong rate.
        mono = torchaudio.functional.resample(mono, sr, _MODEL_SR)
    inputs = extractor(mono.numpy(), sampling_rate=_MODEL_SR, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits[0]
    labels = _label_names(model, logits.shape[-1])
    emotion, scores = get_emotion_label(logits, labels)
    return emotion.capitalize(), scores, duration_sec, labels


def analyze_emotion(audio_path, max_duration_sec=30):
    """Classify the emotion in an audio file.

    Returns:
        ``(emotion, scores, duration_sec)``: the capitalised top label, the
        per-class probabilities, and the number of seconds analysed.
    """
    emotion, scores, duration_sec, _ = _classify(audio_path, max_duration_sec)
    return emotion, scores, duration_sec


# Streamlit UI
st.set_page_config(page_title="π§ Audio Emotion Detector", layout="centered")
st.title("π§ Audio Emotion Analysis (Wav2Vec2)")
uploaded_file = st.file_uploader("Upload an MP3 or WAV audio file", type=["mp3", "wav"])
if uploaded_file is not None:
    # Use the upload's own MIME type so MP3 files aren't announced as WAV.
    st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")
    with st.spinner("Analyzing emotion..."):
        wav_path = convert_to_wav(uploaded_file)
        try:
            emotion, scores, duration_sec, labels = _classify(wav_path)
        finally:
            # convert_to_wav() leaves a temp file on disk; remove it once read.
            os.remove(wav_path)
    st.subheader("β± Audio Info:")
    st.write(f"Duration analyzed: **{duration_sec:.2f} seconds**")
    st.subheader("π§ Detected Emotion:")
    st.markdown(f"**{emotion}**")
    st.subheader("π― Confidence Scores:")
    # Labels come from the same source as the prediction, so names and scores
    # can't drift apart.
    for label, score in zip(labels, scores):
        st.write(f"- **{label.capitalize()}**: {score*100:.2f}%")