Spaces:
Running
Running
import typing | |
import types # fusion of forward() of Wav2Vec2 | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import spaces | |
import torch | |
import torch.nn as nn | |
from transformers import Wav2Vec2Processor | |
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model | |
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel | |
import audiofile | |
import audresample | |
# Inference device: CUDA device index 0 when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = "cpu"
# Number of seconds read from the start of every input file.
duration = 2
# Hugging Face Hub identifiers of the two pretrained models used by the demo.
age_gender_model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
expression_model_name = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
class AgeGenderHead(nn.Module):
    r"""Age-gender model head.

    A small MLP applied to pooled wav2vec2 features:
    dropout -> dense -> tanh -> dropout -> out_proj.

    Args:
        config: model config providing ``hidden_size`` and ``final_dropout``.
        num_labels: number of outputs of the final projection.
    """

    def __init__(self, config, num_labels):
        super().__init__()
        # Attribute names and creation order are part of the parameter
        # naming used when pretrained weights are loaded; keep them stable.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):
        """Map ``features`` (last dim ``hidden_size``) to ``num_labels`` outputs."""
        hidden = torch.tanh(self.dense(self.dropout(features)))
        return self.out_proj(self.dropout(hidden))
class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Age-gender recognition model.

    A wav2vec2 backbone whose time-averaged hidden states feed two heads:
    ``age`` (one output) and ``gender`` (three outputs, softmax-normalised).
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Submodule names must stay as-is so the checkpoint weights map onto them.
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = AgeGenderHead(config, 1)
        self.gender = AgeGenderHead(config, 3)
        self.init_weights()

    def forward(
        self,
        frozen_cnn7,
    ):
        """Predict age and gender from precomputed CNN features.

        ``self.wav2vec2.forward`` is replaced at module level by ``_forward``,
        so the backbone here runs only the feature projection and the
        Transformer encoder on ``frozen_cnn7``.

        Returns:
            ``(pooled, age_logits, gender_probs)``: hidden states averaged over
            dim 1, the raw age output, and softmax probabilities over the three
            gender classes (read by ``process_func`` as female, male, child).
        """
        pooled = self.wav2vec2(frozen_cnn7=frozen_cnn7).mean(dim=1)
        age_logits = self.age(pooled)
        gender_probs = self.gender(pooled).softmax(dim=1)
        return pooled, age_logits, gender_probs
# Replacement for age_gender_model.wav2vec2.forward(): it accepts the frozen CNN7 features already computed by ExpressionModel instead of raw audio.
def _forward(
        self,
        frozen_cnn7=None,  # wav2vec2 CNN7 features, computed once by the CNN feature extractor
        attention_mask=None):
    """Wav2vec2 forward pass that starts from precomputed CNN features.

    Bound at module level as ``age_gender_model.wav2vec2.forward``, so ``self``
    is the outer model that owns ``.wav2vec2``, not the wav2vec2 submodule.
    The CNN feature extractor is skipped; only the feature projection and the
    Transformer encoder are run.

    Args:
        frozen_cnn7: CNN feature-extractor output with the frame axis at dim 1,
            as returned by ``_forward_and_cnn7`` (presumably
            ``(batch, frames, conv_dim)`` -- confirm).
        attention_mask: optional mask over input samples; reduced to one entry
            per feature frame before use.

    Returns:
        The first element of the encoder output (its last hidden state).
    """
    mask = attention_mask
    if mask is not None:
        # Shrink the sample-level mask to the length of the feature sequence.
        mask = self._get_feature_vector_attention_mask(
            frozen_cnn7.shape[1], mask, add_adapter=False
        )
    projected, _ = self.wav2vec2.feature_projection(frozen_cnn7)
    encoder_outputs = self.wav2vec2.encoder(
        projected,
        attention_mask=mask,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    )
    return encoder_outputs[0]
def _forward_and_cnn7(
        self,
        input_values,
        attention_mask=None):
    """Full wav2vec2 forward pass that also exposes the CNN features.

    Bound at module level as ``expression_model.wav2vec2.forward``, so ``self``
    is the outer model that owns ``.wav2vec2``.

    Args:
        input_values: processed waveform batch fed to the CNN feature extractor.
        attention_mask: optional mask over input samples; reduced to one entry
            per feature frame before use.

    Returns:
        ``(last_hidden, cnn_features)``: the encoder's last hidden state and the
        CNN feature-extractor output with dims 1 and 2 swapped. The features are
        taken before the feature projection, which (per the original author's
        note) is trainable, so only the pre-projection features can be shared.
    """
    cnn_features = self.wav2vec2.feature_extractor(input_values).transpose(1, 2)
    mask = attention_mask
    if mask is not None:
        # Shrink the sample-level mask to the length of the feature sequence.
        mask = self.wav2vec2._get_feature_vector_attention_mask(
            cnn_features.shape[1], mask, add_adapter=False
        )
    projected, _ = self.wav2vec2.feature_projection(cnn_features)
    last_hidden = self.wav2vec2.encoder(
        projected,
        attention_mask=mask,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    )[0]
    return last_hidden, cnn_features
class ExpressionHead(nn.Module):
    r"""Expression model head.

    A small MLP applied to pooled wav2vec2 features:
    dropout -> dense -> tanh -> dropout -> out_proj, with
    ``config.num_labels`` outputs.
    """

    def __init__(self, config):
        super().__init__()
        # Attribute names and creation order are part of the parameter
        # naming used when pretrained weights are loaded; keep them stable.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        """Map ``features`` (last dim ``hidden_size``) to ``num_labels`` outputs."""
        activated = torch.tanh(self.dense(self.dropout(features)))
        return self.out_proj(self.dropout(activated))
class ExpressionModel(Wav2Vec2PreTrainedModel):
    r"""Speech expression model.

    A wav2vec2 backbone whose time-averaged hidden states feed a single
    ``classifier`` head (``config.num_labels`` outputs).
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Submodule names must stay as-is so the checkpoint weights map onto them.
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ExpressionHead(config)
        self.init_weights()

    def forward(self, input_values):
        """Predict expression values and expose the backbone's CNN features.

        ``self.wav2vec2.forward`` is replaced at module level by
        ``_forward_and_cnn7``, so the backbone returns the encoder output
        together with the CNN feature-extractor output.

        Returns:
            ``(pooled, logits, cnn_features)``.
        """
        encoded, cnn_features = self.wav2vec2(input_values)
        pooled = encoded.mean(dim=1)
        return pooled, self.classifier(pooled), cnn_features
# Load processors and models from the Hugging Face Hub. The models are moved to
# `device`, the same device process_func() moves its input tensor to. Without
# this the models would stay on the CPU while the input is on the GPU whenever
# CUDA is available, and inference would fail with a device mismatch.
age_gender_processor = Wav2Vec2Processor.from_pretrained(age_gender_model_name)
age_gender_model = AgeGenderModel.from_pretrained(age_gender_model_name).to(device)
expression_processor = Wav2Vec2Processor.from_pretrained(expression_model_name)
expression_model = ExpressionModel.from_pretrained(expression_model_name).to(device)
# Re-wire the inner wav2vec2 forward() of both models so the CNN feature
# extractor runs only once per input:
# - the expression model runs its full backbone and also returns the CNN features;
# - the age/gender model skips its own CNN and consumes those features instead.
# NOTE(review): this presumes both checkpoints share identical (frozen) CNN
# weights, as the "frozen_cnn7" naming suggests -- not verifiable from here.
age_gender_model.wav2vec2.forward = types.MethodType(_forward, age_gender_model)
expression_model.wav2vec2.forward = types.MethodType(_forward_and_cnn7, expression_model)
def process_func(x: np.ndarray, sampling_rate: int) -> typing.Tuple[str, dict, str]:
    """Predict age, gender and expression for a single waveform.

    Args:
        x: audio samples; ``recognize`` passes a 16 kHz signal.
        sampling_rate: sampling rate of ``x`` in Hz.

    Returns:
        ``(age, gender, expression_file)``: the predicted age formatted as
        ``"<n> years"``, a dict of gender probabilities keyed by ``"female"``,
        ``"male"`` and ``"child"``, and the path of the PNG holding the
        arousal/dominance/valence plot.
    """
    # Normalise the waveform with the expression model's processor and turn it
    # into a batch holding a single example.
    y = expression_processor(x, sampling_rate=sampling_rate)
    y = y['input_values'][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)
    with torch.no_grad():
        # The expression model also returns its CNN features, which the
        # age/gender model consumes instead of recomputing them from audio.
        _, logits_expression, frozen_cnn7 = expression_model(y)
        _, logits_age, logits_gender = age_gender_model(frozen_cnn7=frozen_cnn7)
    # .item() converts each selected element to a Python float.
    plot_expression(logits_expression[0, 0].item(),  # arousal
                    logits_expression[0, 1].item(),  # dominance
                    logits_expression[0, 2].item())  # valence
    expression_file = "expression.png"
    # plot_expression() leaves its newly created figure as the current figure.
    # Save it and then close it: pyplot keeps every figure alive until it is
    # closed, so without this each request would leak one figure.
    fig = plt.gcf()
    fig.savefig(expression_file)
    plt.close(fig)
    return (
        f"{round(100 * logits_age[0, 0].item())} years",  # age
        {
            "female": logits_gender[0, 0].item(),
            "male": logits_gender[0, 1].item(),
            "child": logits_gender[0, 2].item(),
        },
        expression_file,
    )
def recognize(input_file: str) -> typing.Tuple[str, dict, str]:
    """Analyse the beginning of an uploaded or recorded audio file.

    Reads at most ``duration`` seconds, mixes multi-channel audio down to
    mono, resamples to 16 kHz and hands the signal to ``process_func``.

    Args:
        input_file: path to the audio file, or ``None`` if nothing was submitted.

    Returns:
        The ``(age, gender, expression_file)`` tuple from ``process_func``.

    Raises:
        gr.Error: if no audio file was submitted.
    """
    if input_file is None:
        raise gr.Error(
            "No audio file submitted! "
            "Please upload or record an audio file "
            "before submitting your request."
        )
    signal, sampling_rate = audiofile.read(input_file, duration=duration)
    # Multi-channel files are returned as a 2-D (channels, samples) array.
    # Average the channels. Otherwise the processor in process_func presumably
    # treats the rows as a batch, and its `['input_values'][0]` would keep only
    # the first channel.
    if signal.ndim > 1:
        signal = signal.mean(axis=0)
    # Resample to the sampling rate supported by the models.
    target_rate = 16000
    signal = audresample.resample(signal, sampling_rate, target_rate)
    return process_func(signal, target_rate)
def plot_expression_RIGID(arousal, dominance, valence):
    r"""Draw a 3D voxel plot marking an (arousal, dominance, valence) point.

    Each value is multiplied by the grid size and rounded to a voxel index.
    The voxel at that index is filled in one colour, and the voxels below it
    along the valence (z) axis are filled in a second colour. Draws into a
    new pyplot figure and returns nothing. Not called anywhere in this file.
    """
    voxels = 7  # grid cells per axis
    a_idx, d_idx, v_idx = (round(value * voxels)
                           for value in (arousal, dominance, valence))
    x, y, z = np.indices((voxels + 1,) * 3)
    in_column = (x == a_idx) & (y == d_idx)
    voxel = in_column & (z == v_idx)
    projection = in_column & (z < v_idx)
    occupied = voxel | projection

    colors = np.empty(occupied.shape, dtype=object)
    colors[voxel] = "#fcb06c"
    colors[projection] = "#fed7a9"

    ax = plt.figure().add_subplot(projection='3d')
    ax.voxels(occupied, facecolors=colors, edgecolor='k')
    for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_limits([0, voxels])
    ax.set_aspect("equal")
    ax.set_xlabel("arousal", fontsize="large", labelpad=0)
    ax.set_ylabel("dominance", fontsize="large", labelpad=0)
    ax.set_zlabel("valence", fontsize="large", labelpad=0)

    # Only the two ends of every axis are labelled (0 and 1).
    tick_positions = list(range(voxels + 1))
    tick_labels = [0] + [None] * (voxels - 1) + [1]
    ax.set_xticks(tick_positions, labels=tick_labels, verticalalignment="bottom")
    ax.set_yticks(tick_positions, labels=tick_labels, verticalalignment="bottom")
    ax.set_zticks(tick_positions, labels=tick_labels, verticalalignment="top")
def explode(data):
    """Interleave zeros between the elements of ``data`` along every axis.

    Each axis of length ``s`` becomes ``2 * s - 1`` long (0 for an empty axis).
    Input values sit at even indices and zeros of the same dtype fill the rest.
    This is used to create the visual separation between voxels. Works for
    arrays of any rank; the voxel plot uses 3-D arrays.

    Args:
        data: array-like to expand.

    Returns:
        A new ``numpy.ndarray`` with the same dtype as ``data``.
    """
    data = np.asarray(data)
    # max(..., 0) keeps zero-length axes at length 0 instead of an invalid -1.
    shape_new = tuple(max(2 * size - 1, 0) for size in data.shape)
    retval = np.zeros(shape_new, dtype=data.dtype)
    # Every second element along every axis, whatever the rank.
    retval[tuple(slice(None, None, 2) for _ in data.shape)] = data
    return retval
def plot_expression(arousal, dominance, valence):
    """Draw an (arousal, dominance, valence) prediction as a 3D voxel cube.

    The cube is split into ``N_PIX**3`` voxels. The voxel containing the
    prediction gets colormap value and alpha 0.22; all other voxels are drawn
    with the minimum alpha of 0.01. The figure is created with
    ``plt.figure()`` and left as the current figure for the caller to save and
    close. Returns nothing.

    Args:
        arousal, dominance, valence: predicted values, presumably in [0, 1];
            after the dominance flip below they are clipped to [0, 0.99].
    """
    N_PIX = 5
    # Per-voxel intensity. Background voxels are 0, which after the colormap
    # and the alpha clip below renders exactly like the tiny random noise used
    # previously, but deterministically.
    intensity = np.zeros((N_PIX, N_PIX, N_PIX))
    # The dominance axis is drawn reversed (its tick labels are 1 - n/N_PIX),
    # so flip the value. NOTE(review): 0.994 rather than 1.0 presumably makes
    # values on a cell edge land in the same cell as on the unflipped axes.
    adv = np.array([arousal, .994 - dominance, valence]).clip(0, .99)
    a_idx, d_idx, v_idx = (adv * N_PIX).astype(np.int64)  # voxel holding the prediction
    intensity[a_idx, d_idx, v_idx] = .22

    filled = np.ones((N_PIX, N_PIX, N_PIX), dtype=bool)
    # Upscale the voxel grid, leaving one empty cell between neighbours.
    filled_2 = explode(filled)
    # Voxel corner coordinates: pair up the cells so that the empty gap cells
    # shrink to zero width, which leaves only thin visual gaps.
    x, y, z = np.indices(np.array(filled_2.shape) + 1).astype(float) // 2
    x[1::2, :, :] += 1
    y[:, 1::2, :] += 1
    z[:, :, 1::2] += 1

    ax = plt.figure().add_subplot(projection='3d')
    # RGBA face colours: RGB from the 'cool' colormap applied to the
    # intensity; alpha is the intensity clipped to [0.01, 0.74].
    side = 2 * N_PIX - 1
    f_2 = np.ones([side, side, side, 4], dtype=np.float64)
    f_2[..., 3] = explode(intensity)
    cm = plt.get_cmap('cool')
    f_2[..., :3] = cm(f_2[..., 3])[..., :3]
    f_2[..., 3] = f_2[..., 3].clip(.01, .74)
    ax.voxels(x, y, z, filled_2, facecolors=f_2, edgecolors=.006 * f_2)

    ax.set_aspect('equal')
    ax.set_zticks([0, N_PIX])
    ax.set_xticks([0, N_PIX])
    ax.set_yticks([0, N_PIX])
    ax.set_zticklabels([f'{n/N_PIX:.2f}' for n in ax.get_zticks()])
    ax.set_zlabel('valence', fontsize=10, labelpad=0)
    ax.set_xticklabels([f'{n/N_PIX:.2f}' for n in ax.get_xticks()])
    ax.set_xlabel('arousal', fontsize=10, labelpad=7)
    # Reversed dominance labels, rotated by 90 degrees.
    ax.set_yticklabels([f'{1-n/N_PIX:.2f}' for n in ax.get_yticks()], rotation=90)
    ax.set_ylabel('dominance', fontsize=10, labelpad=10)
    ax.grid(False)

    # Outline lines drawn at height z = N_PIX (the top face of the cube).
    ax.plot([N_PIX, N_PIX], [0, N_PIX + .2], [N_PIX, N_PIX], 'g', linewidth=1)
    ax.plot([0, N_PIX], [N_PIX, N_PIX + .24], [N_PIX, N_PIX], 'k', linewidth=1)
    ax.plot([0, 0], [0, N_PIX], [N_PIX, N_PIX], 'darkred', linewidth=1)
    ax.plot([0, N_PIX], [0, 0], [N_PIX, N_PIX], 'darkblue', linewidth=1)

    # Pane colours are set after plotting the lines. The z pane is fully transparent.
    ax.xaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
    ax.yaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
    ax.zaxis.set_pane_color((0.8, 0.8, 0.8, 0.0))

    # Restore the limits so the extra lines do not expand the plotted range.
    ax.set_xlim(0, N_PIX)
    ax.set_ylim(0, N_PIX)
    ax.set_zlim(0, N_PIX)
#plt.show() | |
# ------ | |
# Markdown shown above the audio input: what the demo estimates, with links
# to both model cards on the Hugging Face Hub.
description = "".join([
    "Estimate **age**, **gender**, and **expression** ",
    "of the speaker contained in an audio file or microphone recording. \n",
    f"The model [{age_gender_model_name}]",
    f"(https://huggingface.co/{age_gender_model_name}) ",
    "recognises age and gender, ",
    f"whereas [{expression_model_name}]",
    f"(https://huggingface.co/{expression_model_name}) ",
    "recognises the expression dimensions arousal, dominance, and valence. ",
])
# Gradio UI. The left column holds the audio input, examples and submit
# button; the right column shows age, gender probabilities and the
# expression plot returned by recognize().
with gr.Blocks() as demo:
    with gr.Tab(label="Speech analysis"):
        with gr.Row():
            with gr.Column():
                gr.Markdown(description)
                # Named input_audio rather than `input` to avoid shadowing the builtin.
                input_audio = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Audio input",
                    min_length=0.025,  # seconds
                )
                gr.Examples(
                    [
                        "female-46-neutral.wav",
                        "female-20-happy.wav",
                        "male-60-angry.wav",
                        "male-27-sad.wav",
                    ],
                    [input_audio],
                    label="Examples from CREMA-D, ODbL v1.0 license",
                )
                # Derived from `duration` so the notice stays in sync with
                # what recognize() actually reads.
                gr.Markdown(f"Only the first {duration} seconds of the audio will be processed.")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_age = gr.Textbox(label="Age")
                output_gender = gr.Label(label="Gender")
                output_expression = gr.Image(label="Expression")
        outputs = [output_age, output_gender, output_expression]
        submit_btn.click(recognize, input_audio, outputs)

if __name__ == "__main__":
    demo.launch(debug=True)