xy0v0123 committed on
Commit 161cd5c · verified · 1 Parent(s): 4034c0c

Update app.py

Files changed (1)
app.py +58 -4
app.py CHANGED
@@ -1,7 +1,61 @@
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+import torch
+import librosa
+import numpy as np
+from transformers import Wav2Vec2Model, Wav2Vec2Processor
+import torch.nn as nn
+
+# Define emotions
+emotion_list = ['anger', 'disgust', 'fear', 'happy', 'neutral', 'sad']
+
+# Define the model
+class EmotionClassifier(nn.Module):
+    def __init__(self, num_classes):
+        super(EmotionClassifier, self).__init__()
+        self.wav2vec2 = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base')
+        encoder_layer = nn.TransformerEncoderLayer(d_model=self.wav2vec2.config.hidden_size, nhead=8, batch_first=True)
+        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
+        self.classifier = nn.Linear(self.wav2vec2.config.hidden_size, num_classes)
+
+    def forward(self, input_values):
+        outputs = self.wav2vec2(input_values).last_hidden_state
+        encoded = self.transformer_encoder(outputs)
+        logits = self.classifier(encoded[:, 0, :])
+        return logits
+
+# Load your trained model
+model_path = "best_model_state_dict.pth"
+num_classes = len(emotion_list)
+model = EmotionClassifier(num_classes)
+model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+model.eval()
+
+# Define the processor
+processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base')
+
+def predict_emotion(audio):
+    # Load and resample the audio to 16 kHz
+    audio, sr = librosa.load(audio, sr=16000)
+    inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True).input_values
+    if inputs.ndimension() == 2:  # Ensure correct input shape
+        inputs = inputs.squeeze(0)
+    with torch.no_grad():
+        logits = model(inputs.unsqueeze(0)).squeeze()
+
+    # Get predicted emotion probabilities
+    probabilities = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
+    predictions = {emotion: float(prob) for emotion, prob in zip(emotion_list, probabilities)}
+    return predictions
+
+# Create the Gradio interface
+interface = gr.Interface(
+    fn=predict_emotion,
+    inputs=gr.Audio(type="filepath"),
+    outputs=gr.Label(num_top_classes=3),
+    title="Speech Emotion Recognition",
+    description="Upload an audio file (.wav or .mp3) or record your voice to predict the emotion."
+)
+
+# Launch the app
+if __name__ == "__main__":
+    interface.launch()
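
For anyone who wants to exercise the new prediction path outside the Gradio UI, a minimal sketch follows. It is not part of the commit; it assumes the checkpoint best_model_state_dict.pth is present in the working directory and uses a hypothetical test clip sample.wav:

# Minimal local sanity check (assumes app.py is importable and that a
# hypothetical test clip sample.wav exists in the working directory).
from app import predict_emotion

predictions = predict_emotion("sample.wav")  # hypothetical sample file
for emotion, prob in sorted(predictions.items(), key=lambda kv: -kv[1]):
    print(f"{emotion}: {prob:.3f}")

Because the launch call now sits behind the __main__ guard, importing app loads the wav2vec2 weights and the checkpoint but does not start the web interface.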