Update app.py
app.py CHANGED
@@ -1,7 +1,61 @@
 import gradio as gr
+import torch
+import librosa
+import numpy as np
+from transformers import Wav2Vec2Model, Wav2Vec2Processor
+import torch.nn as nn

+# Define emotions
+emotion_list = ['anger', 'disgust', 'fear', 'happy', 'neutral', 'sad']

+# Define the model
+class EmotionClassifier(nn.Module):
+    def __init__(self, num_classes):
+        super(EmotionClassifier, self).__init__()
+        self.wav2vec2 = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base')
+        encoder_layer = nn.TransformerEncoderLayer(d_model=self.wav2vec2.config.hidden_size, nhead=8, batch_first=True)
+        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
+        self.classifier = nn.Linear(self.wav2vec2.config.hidden_size, num_classes)
+
+    def forward(self, input_values):
+        outputs = self.wav2vec2(input_values).last_hidden_state
+        encoded = self.transformer_encoder(outputs)
+        logits = self.classifier(encoded[:, 0, :])
+        return logits
+
+# Load your trained model
+model_path = "best_model_state_dict.pth"
+num_classes = len(emotion_list)
+model = EmotionClassifier(num_classes)
+model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+model.eval()
+
+# Define processor
+processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base')
+
+def predict_emotion(audio):
+    # Load and process audio
+    audio, sr = librosa.load(audio, sr=16000)
+    inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True).input_values
+    if inputs.ndimension() == 2:  # Ensure correct input shape
+        inputs = inputs.squeeze(0)
+    with torch.no_grad():
+        logits = model(inputs.unsqueeze(0)).squeeze()
+
+    # Get predicted emotions
+    probabilities = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
+    predictions = {emotion: float(prob) for emotion, prob in zip(emotion_list, probabilities)}
+    return predictions
+
+# Create Gradio interface
+interface = gr.Interface(
+    fn=predict_emotion,
+    inputs=gr.Audio(type="filepath"),
+    outputs=gr.Label(num_top_classes=3),
+    title="Speech Emotion Recognition",
+    description="Upload an audio file (.wav or .mp3) or record your voice to predict its emotion."
+)
+
+# Launch the app
+if __name__ == "__main__":
+    interface.launch()
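For a quick local sanity check before pushing this change to the Space, a minimal sketch along the lines below should work. It assumes the checkpoint best_model_state_dict.pth sits next to app.py, that the machine can download facebook/wav2vec2-base from the Hub, and that sample.wav is any short audio clip; the script name and the audio path are placeholders, not part of this commit.

# smoke_test.py: hypothetical local check, run from the folder containing app.py and the checkpoint
from app import predict_emotion

# predict_emotion() takes a file path, resamples the audio to 16 kHz with librosa,
# and returns a dict mapping each label in emotion_list to its softmax probability.
scores = predict_emotion("sample.wav")  # "sample.wav" is a placeholder path
for emotion, prob in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{emotion}: {prob:.3f}")

Because the interface is built with gr.Audio(type="filepath"), Gradio hands the same function a temporary file path when a user uploads or records audio, so this local call exercises the same code path as the web UI.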