seayala committed on
Commit
c12eabf
verified
1 Parent(s): a2ca1c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -29
app.py CHANGED
@@ -1,30 +1,89 @@
1
  import gradio as gr
2
- import torch
3
- import torchaudio
4
-
5
- # Cargar el modelo
6
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
- model = M5() # Asegúrate de que la clase M5 esté definida aquí o importada.
8
- model.load_state_dict(torch.load("modelo_entrenado.pth", map_location=device))
9
- model.to(device)
10
- model.eval()
11
-
12
- # Definir la función de inferencia
13
- def predict(audio):
14
- # Aquí debes implementar la lógica para procesar el audio y aplicar el modelo.
15
- # Usa la misma lógica que la función `predict` de tu cuaderno.
16
- waveform, sample_rate = torchaudio.load(audio)
17
- # ... (resto de la lógica para predecir) ...
18
- return prediction
19
-
20
- # Crear la interfaz de Gradio
21
- iface = gr.Interface(
22
- fn=predict,
23
- inputs=gr.Audio(source="microphone", type="filepath"),
24
- outputs="text",
25
- title="Reconocimiento de comandos de voz",
26
- description="Graba un comando de voz y el modelo lo predecirá."
27
- )
28
-
29
- # Lanzar la interfaz
30
- iface.launch(share=True) # share=True para crear un enlace público
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ # Definición de la clase M5
8
+ class M5(nn.Module):
9
+ def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
10
+ super().__init__()
11
+ self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
12
+ self.bn1 = nn.BatchNorm1d(n_channel)
13
+ self.pool1 = nn.MaxPool1d(4)
14
+ self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
15
+ self.bn2 = nn.BatchNorm1d(n_channel)
16
+ self.pool2 = nn.MaxPool1d(4)
17
+ self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
18
+ self.bn3 = nn.BatchNorm1d(2 * n_channel)
19
+ self.pool3 = nn.MaxPool1d(4)
20
+ self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
21
+ self.bn4 = nn.BatchNorm1d(2 * n_channel)
22
+ self.pool4 = nn.MaxPool1d(4)
23
+ self.fc1 = nn.Linear(2 * n_channel, n_output)
24
+
25
+ def forward(self, x):
26
+ x = self.conv1(x)
27
+ x = F.relu(self.bn1(x))
28
+ x = self.pool1(x)
29
+ x = self.conv2(x)
30
+ x = F.relu(self.bn2(x))
31
+ x = self.pool2(x)
32
+ x = self.conv3(x)
33
+ x = F.relu(self.bn3(x))
34
+ x = self.pool3(x)
35
+ x = self.conv4(x)
36
+ x = F.relu(self.bn4(x))
37
+ x = self.pool4(x)
38
+ x = F.avg_pool1d(x, x.shape[-1])
39
+ x = x.permute(0, 2, 1)
40
+ x = self.fc1(x)
41
+ return F.log_softmax(x, dim=2)
42
+
43
+ # Definición de etiquetas
44
+ labels = ['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow',
45
+ 'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine',
46
+ 'no', 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three',
47
+ 'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero']
48
+
49
+ # Funciones auxiliares
50
+ def label_to_index(word):
51
+ return torch.tensor(labels.index(word))
52
+
53
+ def index_to_label(index):
54
+ return labels[index]
55
+
56
+ def get_likely_index(tensor):
57
+ return tensor.argmax(dim=-1)
58
+
59
+ # Cargar el modelo
60
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
+ model = M5()
62
+ model.load_state_dict(torch.load("modelo_entrenado.pth", map_location=device))
63
+ model.to(device)
64
+ model.eval()
65
+
66
+ # Definir la función de inferencia
67
+ def predict(audio):
68
+ waveform, sample_rate = torchaudio.load(audio)
69
+ transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=8000).to(device)
70
+ waveform = waveform.to(device)
71
+ waveform = transform(waveform)
72
+
73
+ with torch.no_grad():
74
+ output = model(waveform.unsqueeze(0))
75
+ tensor = get_likely_index(output)
76
+ prediction = index_to_label(tensor.squeeze())
77
+ return prediction
78
+
79
+ # Crear la interfaz de Gradio
80
+ iface = gr.Interface(
81
+ fn=predict,
82
+ inputs=gr.Audio(source="microphone", type="filepath"),
83
+ outputs="text",
84
+ title="Reconocimiento de comandos de voz",
85
+ description="Graba un comando de voz y el modelo lo predecirá."
86
+ )
87
+
88
+ # Lanzar la interfaz
89
+ iface.launch(share=True)