import gradio as gr
import torch
import librosa
import numpy as np
import torch.nn.functional as F
import os

from encoders.transformer import Wav2Vec2EmotionClassifier
# Define the emotions
emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
label_mapping = {str(idx): emotion for idx, emotion in enumerate(emotions)}
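# Example of the resulting mapping: {"0": "happy", "1": "sad", ..., "6": "surprise"}.
# Note (assumption): this list must match the label order used during training,
# otherwise predicted indices will be decoded to the wrong emotion names.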
# Load the trained model
model_path = "lora_only_model.pth"
cfg = {
    "model": {
        "encoder": "Wav2Vec2Classifier",
        "optimizer": {
            "name": "Adam",
            "lr": 0.0003,
            "weight_decay": 3e-4
        },
        "l1_lambda": 0.0
    }
}
model = Wav2Vec2EmotionClassifier(num_classes=len(emotions), optimizer_cfg=cfg["model"]["optimizer"])
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
# strict=False because the checkpoint holds only the fine-tuned (LoRA) weights;
# the remaining encoder parameters keep their pretrained values.
model.load_state_dict(state_dict, strict=False)
model.eval()
# Debugging aid: print the trainable (LoRA) parameters that were loaded.
# This is verbose; remove or guard it once the checkpoint has been verified.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.data}")
# Optional: we define a minimum number of samples to avoid Wav2Vec2 conv errors
MIN_SAMPLES = 10  # or 16000 if you want at least 1 second
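# Note (assumption): if the encoder wraps a standard Wav2Vec2 feature extractor, its
# conv stack has a receptive field of roughly 400 samples (25 ms at 16 kHz), so clips
# shorter than that can still trigger conv errors even with MIN_SAMPLES = 10;
# raising MIN_SAMPLES to 16000 (1 second) is the safer choice.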
# Preprocessing function
def preprocess_audio(file_path, sample_rate=16000):
    """
    Safely loads the file at file_path and returns a (1, samples) torch tensor.
    Returns None if the file is invalid or too short.
    """
    if not file_path or not os.path.exists(file_path):
        # file_path can be None or an empty string if the user didn't record properly
        return None

    # Load with librosa: resamples to `sample_rate` and merges multi-channel audio to mono
    waveform, sr = librosa.load(file_path, sr=sample_rate)

    # Check length
    if len(waveform) < MIN_SAMPLES:
        return None

    # Convert to a torch tensor of shape (1, samples)
    waveform_tensor = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)
    return waveform_tensor
# Prediction function
def predict_emotion(audio_file):
    """
    audio_file is a file path from Gradio (type='filepath').
    """
    # Preprocess
    waveform = preprocess_audio(audio_file, sample_rate=16000)

    # If invalid or too short, return an error-like message
    if waveform is None:
        return (
            "Audio is too short or invalid. Please record/upload a longer clip.",
            ""
        )

    # Perform inference
    with torch.no_grad():
        logits = model(waveform)
        probabilities = F.softmax(logits, dim=-1).cpu().numpy()[0]

    # Get the predicted class
    predicted_class = np.argmax(probabilities)
    predicted_emotion = label_mapping[str(predicted_class)]

    # Format probabilities for visualization
    probabilities_output = [
        f"""
        <div style='display: flex; align-items: center; margin: 5px 0;'>
            <div style='width: 20%; text-align: right; margin-right: 10px; font-weight: bold;'>{emotions[i]}</div>
            <div style='flex-grow: 1; background-color: #374151; border-radius: 4px; overflow: hidden;'>
                <div style='width: {probabilities[i]*100:.2f}%; background-color: #FFA500; height: 10px;'></div>
            </div>
            <div style='width: 10%; text-align: right; margin-left: 10px;'>{probabilities[i]*100:.2f}%</div>
        </div>
        """
        for i in range(len(emotions))
    ]

    return predicted_emotion, "\n".join(probabilities_output)
# Create Gradio interface
def gradio_interface(audio):
    detected_emotion, probabilities_html = predict_emotion(audio)
    return detected_emotion, gr.HTML(probabilities_html)
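# Note (assumption, recent Gradio versions): because the output component below is
# already gr.HTML, returning the raw HTML string would also work here; wrapping it
# in gr.HTML(...) simply sends a component update instead of a plain value.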
# Define Gradio UI
with gr.Blocks(css="""
body {
    background-color: #121212;
    color: white;
    font-family: Arial, sans-serif;
}
h1 {
    color: #FFA500;
    font-size: 48px;
    text-align: center;
    margin-bottom: 10px;
}
p {
    text-align: center;
    font-size: 18px;
}
.gradio-row {
    justify-content: center;
    align-items: center;
}
#submit_button {
    background-color: #FFA500 !important;
    color: black !important;
    font-size: 18px;
    padding: 10px 20px;
    margin-top: 20px;
}
#detected_emotion {
    font-size: 24px;
    font-weight: bold;
    text-align: center;
}
.probabilities-container {
    margin-top: 20px;
    padding: 10px;
    background-color: #1F2937;
    border-radius: 8px;
}
""") as demo:
    gr.Markdown(
        """
        <div>
            <h1>Speech Emotion Recognition</h1>
            <p>🎵 Upload or record an audio file (max 1 minute) to detect emotions.</p>
            <p>Supported Emotions: 😊 Happy | 😭 Sad | 😡 Angry | 😐 Neutral | 😨 Fear | 🤢 Disgust | 😮 Surprise</p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1, elem_id="audio-block"):
            # type="filepath" means we get a temporary file path from Gradio
            audio_input = gr.Audio(label="🎤 Record or Upload Audio", type="filepath")
            submit_button = gr.Button("Submit", elem_id="submit_button")

        with gr.Column(scale=1):
            detected_emotion_label = gr.Label(label="Detected Emotion", elem_id="detected_emotion")
            probabilities_html = gr.HTML(label="Probabilities", elem_id="probabilities")

    submit_button.click(
        fn=gradio_interface,
        inputs=audio_input,
        outputs=[detected_emotion_label, probabilities_html]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)
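# Usage note: running `python app.py` locally with share=True asks Gradio to create a
# temporary public *.gradio.live link in addition to the local server; when hosted on
# Hugging Face Spaces the app is already served publicly, so share=True is not needed.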