avsv committed on
Commit
db7a8ec
·
1 Parent(s): 936f253

✅ Fix: use correct extractor for superb/wav2vec2-base-superb-er

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -21,37 +21,38 @@ def convert_to_wav(uploaded_file):
21
 
22
  def get_emotion_label(logits):
23
  emotions = ["angry", "happy", "neutral", "sad"]
24
- scores = torch.softmax(torch.tensor(logits), dim=0).tolist()
25
  top_idx = scores.index(max(scores))
26
  return emotions[top_idx], scores
27
 
28
  def analyze_emotion(audio_path):
29
  extractor, model = load_model()
30
  waveform, sr = torchaudio.load(audio_path)
 
31
  if sr != 16000:
32
- waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
33
 
34
- inputs = extractor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
35
  with torch.no_grad():
36
  logits = model(**inputs).logits[0]
37
 
38
  emotion, scores = get_emotion_label(logits)
39
  return emotion.capitalize(), scores
40
 
41
- # Streamlit UI
42
  st.set_page_config(page_title="🎧 Audio Emotion Detector", layout="centered")
43
  st.title("🎧 Audio Emotion Analysis (Wav2Vec2)")
44
 
45
  uploaded_file = st.file_uploader("Upload an MP3 or WAV audio file", type=["mp3", "wav"])
46
 
47
  if uploaded_file:
48
- st.audio(uploaded_file, format='audio/wav')
49
  with st.spinner("Analyzing emotion..."):
50
  wav_path = convert_to_wav(uploaded_file)
51
  emotion, scores = analyze_emotion(wav_path)
52
 
53
- st.subheader("Emotion Analysis Result:")
54
- st.markdown(f"🧠 **Detected Emotion:** `{emotion}`")
55
 
56
  st.subheader("Confidence Scores:")
57
  emotions = ["angry", "happy", "neutral", "sad"]
 
21
 
22
def get_emotion_label(logits):
    """Convert raw model logits into a predicted emotion and its score list.

    Args:
        logits: 1-D torch tensor of class scores, ordered to match the
            four-class label list below.

    Returns:
        (label, scores): the winning label string and the full softmax
        probability list (plain floats summing to 1).
    """
    labels = ["angry", "happy", "neutral", "sad"]
    # Softmax over the single class dimension; .tolist() yields plain floats.
    probs = torch.softmax(logits, dim=0).tolist()
    # First index of the maximum probability (ties resolve to the earliest
    # label, matching list.index semantics).
    winner = max(range(len(probs)), key=probs.__getitem__)
    return labels[winner], probs
27
 
28
def analyze_emotion(audio_path):
    """Run emotion classification on a speech recording.

    Args:
        audio_path: filesystem path to a WAV file.

    Returns:
        (emotion, scores): capitalized predicted label plus the softmax
        probability list produced by get_emotion_label.
    """
    extractor, model = load_model()

    signal, rate = torchaudio.load(audio_path)

    # The checkpoint expects 16 kHz audio; resample anything else first.
    if rate != 16000:
        signal = torchaudio.transforms.Resample(rate, 16000)(signal)

    # Feed the first channel only — the extractor wants a 1-D array.
    features = extractor(signal[0].numpy(), sampling_rate=16000, return_tensors="pt")

    # Inference only; skip gradient tracking.
    with torch.no_grad():
        logits = model(**features).logits[0]

    label, scores = get_emotion_label(logits)
    return label.capitalize(), scores
41
 
42
+ # UI
43
  st.set_page_config(page_title="🎧 Audio Emotion Detector", layout="centered")
44
  st.title("🎧 Audio Emotion Analysis (Wav2Vec2)")
45
 
46
  uploaded_file = st.file_uploader("Upload an MP3 or WAV audio file", type=["mp3", "wav"])
47
 
48
  if uploaded_file:
49
+ st.audio(uploaded_file)
50
  with st.spinner("Analyzing emotion..."):
51
  wav_path = convert_to_wav(uploaded_file)
52
  emotion, scores = analyze_emotion(wav_path)
53
 
54
+ st.subheader("Detected Emotion:")
55
+ st.markdown(f"🧠 **{emotion}**")
56
 
57
  st.subheader("Confidence Scores:")
58
  emotions = ["angry", "happy", "neutral", "sad"]