import gradio as gr
import torch
import librosa
import numpy as np
import torch.nn.functional as F
import os
from encoders.transformer import Wav2Vec2EmotionClassifier
# Define the emotions
emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
label_mapping = {str(idx): emotion for idx, emotion in enumerate(emotions)}
# Load the trained model
model_path = "lora_only_model.pth"
cfg = {
    "model": {
        "encoder": "Wav2Vec2Classifier",
        "optimizer": {
            "name": "Adam",
            "lr": 0.0003,
            "weight_decay": 3e-4
        },
        "l1_lambda": 0.0
    }
}
model = Wav2Vec2EmotionClassifier(num_classes=len(emotions), optimizer_cfg=cfg["model"]["optimizer"])
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(state_dict, strict=False)
model.eval()
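# Optional sketch (assumption, not part of the original setup): run on GPU when one
# is available. The input tensor in predict_emotion would then also need .to(device),
# so this is left commented out rather than enabled by default.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)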
# Debug check: print the parameters that remain trainable (expected to be the LoRA weights)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.data}")
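# Optional sketch (not part of the original flow): summarise how many parameters
# remain trainable after loading the LoRA-only checkpoint, instead of dumping raw tensors.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")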
# Minimum number of input samples; shorter clips trigger errors in the Wav2Vec2 conv layers
MIN_SAMPLES = 10  # use 16000 to require at least 1 second of 16 kHz audio
# Preprocessing function
def preprocess_audio(file_path, sample_rate=16000):
    """
    Safely load the file at file_path and return a (1, samples) torch tensor.
    Return None if the file is invalid or too short.
    """
    if not file_path or not os.path.exists(file_path):
        # file_path can be None or an empty string if the user didn't record properly
        return None

    # Load with librosa (resamples to sample_rate and downmixes to mono by default)
    waveform, sr = librosa.load(file_path, sr=sample_rate)

    # Reject clips that are too short for the model
    if len(waveform) < MIN_SAMPLES:
        return None

    # Convert to a torch tensor of shape (1, samples)
    waveform_tensor = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)
    return waveform_tensor
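# Hypothetical helper (not called anywhere in this script): the UI text below advertises
# a 1-minute limit, but nothing enforces it. A sketch of a trim step, assuming 16 kHz audio.
def trim_to_max_length(waveform_tensor, max_samples=60 * 16000):
    """Return at most the first max_samples samples of a (1, samples) tensor."""
    return waveform_tensor[:, :max_samples]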
# Prediction function
def predict_emotion(audio_file):
    """
    audio_file is a file path from Gradio (type='filepath').
    """
    # Preprocess
    waveform = preprocess_audio(audio_file, sample_rate=16000)

    # If invalid or too short, return an error-like message
    if waveform is None:
        return (
            "Audio is too short or invalid. Please record/upload a longer clip.",
            ""
        )

    # Perform inference
    with torch.no_grad():
        logits = model(waveform)
        probabilities = F.softmax(logits, dim=-1).cpu().numpy()[0]

    # Get the predicted class
    predicted_class = np.argmax(probabilities)
    predicted_emotion = label_mapping[str(predicted_class)]

    # Format probabilities as simple HTML bar rows for visualization
    probabilities_output = [
        f"""
        <div style='display: flex; align-items: center; margin: 5px 0;'>
            <div style='width: 20%; text-align: right; margin-right: 10px; font-weight: bold;'>{emotions[i]}</div>
            <div style='flex-grow: 1; background-color: #374151; border-radius: 4px; overflow: hidden;'>
                <div style='width: {probabilities[i]*100:.2f}%; background-color: #FFA500; height: 10px;'></div>
            </div>
            <div style='width: 10%; text-align: right; margin-left: 10px;'>{probabilities[i]*100:.2f}%</div>
        </div>
        """
        for i in range(len(emotions))
    ]
    return predicted_emotion, "\n".join(probabilities_output)
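# Alternative sketch (not wired into the UI below): gr.Label can render its own
# confidence bars when given a {label: probability} dict, which would replace the
# hand-rolled HTML above.
def probabilities_to_label_dict(probabilities):
    """Map a softmax vector to {emotion: probability} for gr.Label."""
    return {emotions[i]: float(probabilities[i]) for i in range(len(emotions))}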
# Gradio callback: wraps the prediction and formats the two outputs
def gradio_interface(audio):
    detected_emotion, probabilities_html = predict_emotion(audio)
    # The gr.HTML output component accepts the raw markup string directly
    return detected_emotion, probabilities_html
# Define Gradio UI
with gr.Blocks(css="""
    body {
        background-color: #121212;
        color: white;
        font-family: Arial, sans-serif;
    }
    h1 {
        color: #FFA500;
        font-size: 48px;
        text-align: center;
        margin-bottom: 10px;
    }
    p {
        text-align: center;
        font-size: 18px;
    }
    .gradio-row {
        justify-content: center;
        align-items: center;
    }
    #submit_button {
        background-color: #FFA500 !important;
        color: black !important;
        font-size: 18px;
        padding: 10px 20px;
        margin-top: 20px;
    }
    #detected_emotion {
        font-size: 24px;
        font-weight: bold;
        text-align: center;
    }
    .probabilities-container {
        margin-top: 20px;
        padding: 10px;
        background-color: #1F2937;
        border-radius: 8px;
    }
""") as demo:
    gr.Markdown(
        """
        <div>
            <h1>Speech Emotion Recognition</h1>
            <p>🎵 Upload or record an audio file (max 1 minute) to detect emotions.</p>
            <p>Supported Emotions: 😊 Happy | 😭 Sad | 😡 Angry | 😐 Neutral | 😨 Fear | 🤢 Disgust | 😮 Surprise</p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1, elem_id="audio-block"):
            # type="filepath" means we get a temporary file path from Gradio
            audio_input = gr.Audio(label="🎤 Record or Upload Audio", type="filepath")
            submit_button = gr.Button("Submit", elem_id="submit_button")
        with gr.Column(scale=1):
            detected_emotion_label = gr.Label(label="Detected Emotion", elem_id="detected_emotion")
            probabilities_html = gr.HTML(label="Probabilities", elem_id="probabilities")

    submit_button.click(
        fn=gradio_interface,
        inputs=audio_input,
        outputs=[detected_emotion_label, probabilities_html]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)