"""Streamlit app that classifies the emotion of an uploaded audio clip.

The clip is converted to 16 kHz mono WAV with pydub, trimmed to a maximum
length, and scored with the ``superb/wav2vec2-base-superb-er`` Wav2Vec2
sequence-classification checkpoint.
"""

import os
import tempfile

import streamlit as st
import torch
import torchaudio
from pydub import AudioSegment
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

MODEL_NAME = "superb/wav2vec2-base-superb-er"

# Sample rate used both for the converted WAV and for the feature extractor.
TARGET_SAMPLE_RATE = 16000

# Longest stretch of audio (in seconds) that is sent through the model.
MAX_DURATION_SEC = 30

# Display names for abbreviated class labels; any label not listed here is
# shown as it appears in the model config.
_LABEL_DISPLAY_NAMES = {
    "ang": "angry",
    "hap": "happy",
    "neu": "neutral",
    "sad": "sad",
}


@st.cache_resource
def load_model():
    """Load the feature extractor and classifier once per Streamlit process.

    Returns:
        A ``(extractor, model)`` tuple; the model is switched to eval mode.
    """
    extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
    model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME)
    model.eval()
    return extractor, model


def get_emotion_labels():
    """Return the class labels in the model's own index order.

    The order is read from ``model.config.id2label`` instead of being
    hard-coded, so that score ``i`` is always paired with the label the
    checkpoint actually assigns to class ``i``.
    """
    _, model = load_model()
    id2label = model.config.id2label
    labels = []
    for idx in sorted(id2label):
        raw = str(id2label[idx])
        labels.append(_LABEL_DISPLAY_NAMES.get(raw.lower(), raw))
    return labels


def convert_to_wav(uploaded_file) -> str:
    """Convert an uploaded audio file to a 16 kHz mono WAV on disk.

    Args:
        uploaded_file: A path or file-like object that pydub can decode.

    Returns:
        Path of the temporary WAV file. The caller owns the file and is
        responsible for deleting it.
    """
    audio = AudioSegment.from_file(uploaded_file)
    audio = audio.set_frame_rate(TARGET_SAMPLE_RATE).set_channels(1)

    # mktemp() only invents a name, which is race-prone; NamedTemporaryFile
    # creates the file atomically. It is closed before exporting so the path
    # can be reopened on every platform.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        temp_path = tmp.name
    try:
        audio.export(temp_path, format="wav")
    except Exception:
        # Do not leave an orphaned temp file behind when the export fails.
        os.remove(temp_path)
        raise
    return temp_path


def get_emotion_label(logits, labels=None):
    """Turn a 1-D logits tensor into the winning label and all probabilities.

    Args:
        logits: 1-D tensor with one logit per emotion class.
        labels: Optional label names in class-index order. Defaults to the
            labels declared by the loaded model's config.

    Returns:
        A ``(label, scores)`` tuple where ``scores`` is a list of softmax
        probabilities aligned with ``labels``.
    """
    if labels is None:
        labels = get_emotion_labels()
    scores = torch.softmax(logits, dim=-1).tolist()
    top_idx = scores.index(max(scores))
    return labels[top_idx], scores


def analyze_emotion(audio_path):
    """Run emotion classification on the audio file at ``audio_path``.

    Only the first ``MAX_DURATION_SEC`` seconds and the first channel are
    analyzed.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        A ``(emotion, scores, duration_sec)`` tuple: the capitalized top
        label, the per-class probabilities, and the number of seconds of
        audio that were analyzed.
    """
    extractor, model = load_model()
    waveform, sr = torchaudio.load(audio_path)

    # Trim long recordings to keep inference time bounded.
    max_samples = sr * MAX_DURATION_SEC
    if waveform.size(1) > max_samples:
        waveform = waveform[:, :max_samples]
    duration_sec = waveform.size(1) / sr

    # The extractor is told the audio is TARGET_SAMPLE_RATE, so make that
    # true for files that were not produced by convert_to_wav().
    if sr != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, TARGET_SAMPLE_RATE)

    inputs = extractor(
        waveform[0].numpy(),
        sampling_rate=TARGET_SAMPLE_RATE,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**inputs).logits[0]

    emotion, scores = get_emotion_label(logits)
    return emotion.capitalize(), scores, duration_sec


# Streamlit UI
st.set_page_config(page_title="🎧 Audio Emotion Detector", layout="centered")
st.title("🎧 Audio Emotion Analysis (Wav2Vec2)")

uploaded_file = st.file_uploader("Upload an MP3 or WAV audio file", type=["mp3", "wav"])

if uploaded_file:
    # Use the upload's own MIME type so MP3 files are not announced as WAV.
    st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")

    with st.spinner("Analyzing emotion..."):
        wav_path = convert_to_wav(uploaded_file)
        try:
            emotion, scores, duration_sec = analyze_emotion(wav_path)
        finally:
            # Best-effort cleanup of the temporary WAV; a failure to delete
            # must not hide the analysis result or its exception.
            try:
                os.remove(wav_path)
            except OSError:
                pass

    st.subheader("⏱ Audio Info:")
    st.write(f"Duration analyzed: **{duration_sec:.2f} seconds**")

    st.subheader("🧠 Detected Emotion:")
    st.markdown(f"**{emotion}**")

    st.subheader("🎯 Confidence Scores:")
    for label, score in zip(get_emotion_labels(), scores):
        st.write(f"- **{label.capitalize()}**: {score*100:.2f}%")