Merlintxu commited on
Commit
33704de
verified
1 Parent(s): 53f6104

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -93
app.py CHANGED
@@ -4,135 +4,244 @@ import tempfile
4
  import torch
5
  import numpy as np
6
  import datetime
 
7
  import whisper
8
- from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
9
  from pyannote.audio import Audio
10
  from pyannote.core import Segment
11
  from sklearn.cluster import AgglomerativeClustering
12
  import gradio as gr
13
  import warnings
 
14
 
15
  warnings.filterwarnings("ignore", category=UserWarning)
 
16
 
17
- # --- Configuraci贸n de Modelos (Ligeros para Spaces) ---
18
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Cargar Whisper Small (multiling眉e, m谩s ligero)
21
- print("Cargando modelo Whisper small...")
22
- whisper_model = whisper.load_model("small", device=DEVICE)
23
- print("Modelo Whisper cargado.")
24
-
25
- # Cargar modelo de embeddings de hablante
26
- print("Cargando modelo de embeddings...")
27
- embedding_model = PretrainedSpeakerEmbedding(
28
- "speechbrain/spkrec-ecapa-voxceleb",
29
- device=DEVICE
30
- )
31
  audio_processor = Audio()
32
- print("Modelos cargados.")
33
 
34
  def time(secs):
35
  return datetime.timedelta(seconds=round(secs))
36
 
37
  def convert_to_wav(input_path):
38
- """Convierte cualquier audio a WAV usando ffmpeg."""
39
  if input_path.lower().endswith('.wav'):
40
- return input_path
 
41
 
42
- output_path = input_path.rsplit('.', 1)[0] + '_converted.wav'
43
- os.system(f"ffmpeg -y -i '{input_path}' -ac 1 -ar 16000 -acodec pcm_s16le '{output_path}'")
 
 
 
 
 
 
 
 
 
 
44
  return output_path
45
 
46
  def get_duration(path):
47
  import soundfile as sf
48
- info = sf.info(path)
49
- return info.duration
 
 
 
 
 
 
 
 
 
 
50
 
51
  def segment_embedding(path, segment, duration):
52
  start = segment["start"]
53
  end = min(duration, segment["end"])
54
  clip = Segment(start, end)
55
- waveform, sample_rate = audio_processor.crop(path, clip)
56
- embedding = embedding_model(waveform[None])
57
- return embedding.cpu().detach().numpy().squeeze()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def transcribe_and_diarize(audio_file, num_speakers):
60
  """Funci贸n principal de transcripci贸n y diarizaci贸n."""
61
- status_update = ""
62
-
63
- # --- 1. Conversi贸n ---
64
- status_update += "1. Convirtiendo audio a formato WAV...\n"
65
- yield status_update, ""
66
-
67
- wav_path = convert_to_wav(audio_file)
68
-
69
- # --- 2. Duraci贸n ---
70
- status_update += "2. Obteniendo duraci贸n del audio...\n"
71
- yield status_update, ""
72
-
73
- duration = get_duration(wav_path)
74
- if duration > 30 * 60: # Limitar a 30 minutos para evitar tiempos excesivos
75
- yield status_update + "Error: El audio es demasiado largo (m谩ximo 30 minutos).\n", ""
76
- return
77
-
78
- # --- 3. Transcripci贸n ---
79
- status_update += "3. Transcribiendo audio con Whisper (modelo 'small')...\n"
80
- yield status_update, ""
81
-
82
- # Transcribir en espa帽ol
83
- result = whisper_model.transcribe(wav_path, language='es', task='transcribe', verbose=False)
84
- segments = result["segments"]
85
 
86
- if not segments:
87
- yield status_update + "Error: No se detect贸 habla en el audio.\n", ""
88
- return
 
 
 
 
89
 
90
- # --- 4. Diarizaci贸n ---
91
- status_update += "4. Preparando para diarizaci贸n...\n"
92
- yield status_update, ""
93
-
94
- # Limitar n煤mero de hablantes
95
- num_speakers = max(2, min(6, int(num_speakers))) # Entre 2 y 6
96
- num_speakers = min(num_speakers, len(segments))
97
-
98
- if len(segments) <= 1:
99
- segments[0]['speaker'] = 'HABLANTE 1'
100
- else:
101
- status_update += " -> Extrayendo embeddings de audio...\n"
102
  yield status_update, ""
103
 
104
- embeddings = np.zeros(shape=(len(segments), 192))
105
- for i, segment in enumerate(segments):
106
- embeddings[i] = segment_embedding(wav_path, segment, duration)
107
- embeddings = np.nan_to_num(embeddings)
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- status_update += " -> Agrupando hablantes...\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  yield status_update, ""
111
 
112
- # Clustering
113
- clustering = AgglomerativeClustering(n_clusters=num_speakers).fit(embeddings)
114
- labels = clustering.labels_
115
- for i in range(len(segments)):
116
- segments[i]["speaker"] = f'HABLANTE {labels[i] + 1}'
117
-
118
- # --- 5. Formateo de salida ---
119
- status_update += "5. Generando transcripci贸n final...\n"
120
- yield status_update, ""
121
-
122
- output_text = ""
123
- for (i, segment) in enumerate(segments):
124
- if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
125
- if i != 0:
126
- output_text += '\n\n'
127
- output_text += f"{segment['speaker']} [{time(segment['start'])}]\n\n"
128
- output_text += segment["text"].strip() + ' '
129
 
130
- yield status_update + "隆Proceso completado!\n", output_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  # --- Interfaz Gradio ---
133
  with gr.Blocks(title="Diarizaci贸n de Audio en Espa帽ol") as demo:
134
  gr.Markdown("# 馃帳 Diarizaci贸n de Audio en Espa帽ol")
135
  gr.Markdown("Sube un archivo de audio (hasta 30 minutos) y obt茅n una transcripci贸n separada por hablantes. Optimizado para espa帽ol.")
 
136
 
137
  with gr.Row():
138
  with gr.Column():
@@ -140,22 +249,21 @@ with gr.Blocks(title="Diarizaci贸n de Audio en Espa帽ol") as demo:
140
  num_speakers = gr.Slider(2, 6, value=3, step=1, label="N煤mero aproximado de hablantes")
141
  run_button = gr.Button("馃殌 Iniciar Diarizaci贸n")
142
  with gr.Column():
143
- status_output = gr.Textbox(label="Estado", interactive=False, lines=10)
144
  text_output = gr.Textbox(label="Transcripci贸n con Hablantes", interactive=False, lines=20)
145
 
146
  run_button.click(
147
  fn=transcribe_and_diarize,
148
  inputs=[audio_input, num_speakers],
149
  outputs=[status_output, text_output],
150
- queue=True # Importante para procesos largos
 
151
  )
152
 
153
  gr.Markdown("---")
154
- gr.Markdown("**Nota:** Este demo usa modelos ligeros. Para audio con mucho ruido o m谩s de 10 minutos, los resultados pueden ser menos precisos.")
155
-
156
- # Para ejecutar localmente (opcional)
157
- # if __name__ == "__main__":
158
- # demo.launch()
159
 
160
  # Para Hugging Face Spaces
161
- demo.launch()
 
4
  import torch
5
  import numpy as np
6
  import datetime
7
+ import gc
8
  import whisper
 
9
  from pyannote.audio import Audio
10
  from pyannote.core import Segment
11
  from sklearn.cluster import AgglomerativeClustering
12
  import gradio as gr
13
  import warnings
14
+ from huggingface_hub import hf_hub_download
15
 
16
  warnings.filterwarnings("ignore", category=UserWarning)
17
+ warnings.filterwarnings("ignore", category=FutureWarning)
18
 
19
+ # --- Configuraci贸n de Modelos ---
20
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+ print(f"Usando dispositivo: {DEVICE}")
22
+
23
+ # --- Cargar Whisper (intentar con una versi贸n m谩s reciente si es viable) ---
24
+ WHISPER_MODEL_NAME = "small" # Empezar con 'small' para Spaces. Probar 'medium' o 'large-v3' si hay recursos.
25
+ try:
26
+ print(f"Cargando modelo Whisper '{WHISPER_MODEL_NAME}'...")
27
+ whisper_model = whisper.load_model(WHISPER_MODEL_NAME, device=DEVICE)
28
+ print(f"Modelo Whisper '{WHISPER_MODEL_NAME}' cargado exitosamente.")
29
+ except Exception as e:
30
+ print(f"Error cargando Whisper '{WHISPER_MODEL_NAME}': {e}")
31
+ print("Intentando cargar 'base' como fallback...")
32
+ WHISPER_MODEL_NAME = "base"
33
+ whisper_model = whisper.load_model(WHISPER_MODEL_NAME, device=DEVICE)
34
+ print(f"Modelo Whisper '{WHISPER_MODEL_NAME}' cargado.")
35
+
36
+ # --- Cargar modelo de embeddings de Pyannote v3.x ---
37
+ # Usar el nuevo modelo de embedding recomendado para pyannote.audio 3.x
38
+ EMBEDDING_MODEL_NAME = "pyannote/embedding"
39
+ EMBEDDING_REVISION = "main" # O especificar un commit si es necesario
40
+
41
+ try:
42
+ print(f"Cargando modelo de embeddings '{EMBEDDING_MODEL_NAME}'...")
43
+ # Importar el pipeline de embedding de pyannote v3
44
+ from pyannote.audio import Model
45
+ embedding_model = Model.from_pretrained(
46
+ EMBEDDING_MODEL_NAME,
47
+ use_auth_token=False, # No se necesita token para modelos p煤blicos
48
+ revision=EMBEDDING_REVISION
49
+ )
50
+ embedding_model.to(DEVICE)
51
+ print(f"Modelo de embeddings '{EMBEDDING_MODEL_NAME}' cargado.")
52
+ except Exception as e:
53
+ print(f"Error cargando el modelo de embeddings '{EMBEDDING_MODEL_NAME}': {e}")
54
+ print("Intentando con speechbrain como fallback...")
55
+ # Fallback al modelo SpeechBrain si el de Pyannote falla
56
+ try:
57
+ from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
58
+ embedding_model = PretrainedSpeakerEmbedding(
59
+ "speechbrain/spkrec-ecapa-voxceleb",
60
+ device=DEVICE
61
+ )
62
+ print("Modelo de embeddings 'speechbrain/spkrec-ecapa-voxceleb' cargado como fallback.")
63
+ except Exception as e_fallback:
64
+ print(f"Error cr铆tico cargando modelo de embeddings: {e_fallback}")
65
+ raise RuntimeError("No se pudo cargar ning煤n modelo de embeddings.")
66
 
 
 
 
 
 
 
 
 
 
 
 
67
  audio_processor = Audio()
 
68
 
69
  def time(secs):
70
  return datetime.timedelta(seconds=round(secs))
71
 
72
  def convert_to_wav(input_path):
73
+ """Convierte cualquier audio a WAV mono 16kHz usando ffmpeg."""
74
  if input_path.lower().endswith('.wav'):
75
+ # Verificar si ya es mono y 16kHz podr铆a ser 煤til, pero para simplificar, convertimos siempre
76
+ pass
77
 
78
+ # Usar un nombre temporal seguro
79
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmpfile:
80
+ output_path = tmpfile.name
81
+
82
+ # Comando ffmpeg para convertir a WAV mono 16kHz
83
+ cmd = f"ffmpeg -y -i '{input_path}' -ac 1 -ar 16000 -acodec pcm_s16le '{output_path}'"
84
+ print(f"Ejecutando conversi贸n: {cmd}")
85
+ os.system(cmd)
86
+
87
+ if not os.path.exists(output_path) or os.path.getsize(output_path) == 0:
88
+ raise RuntimeError("La conversi贸n a WAV fall贸 o produjo un archivo vac铆o.")
89
+
90
  return output_path
91
 
92
  def get_duration(path):
93
  import soundfile as sf
94
+ try:
95
+ info = sf.info(path)
96
+ return info.duration
97
+ except Exception as e:
98
+ print(f"Error obteniendo duraci贸n con soundfile: {e}")
99
+ # Fallback a wave (menos robusto)
100
+ import wave
101
+ import contextlib
102
+ with contextlib.closing(wave.open(path,'r')) as f:
103
+ frames = f.getnframes()
104
+ rate = f.getframerate()
105
+ return frames / float(rate)
106
 
107
  def segment_embedding(path, segment, duration):
108
  start = segment["start"]
109
  end = min(duration, segment["end"])
110
  clip = Segment(start, end)
111
+ try:
112
+ waveform, sample_rate = audio_processor.crop(path, clip)
113
+ with torch.no_grad():
114
+ # Para modelos Pyannote v3
115
+ if hasattr(embedding_model, 'encode'):
116
+ # Modelos nuevos de pyannote devuelven diccionarios
117
+ output = embedding_model.encode(waveform[None].to(DEVICE))
118
+ if isinstance(output, dict) and 'embedding' in output:
119
+ embedding = output['embedding']
120
+ else:
121
+ embedding = output
122
+ else:
123
+ # Fallback para modelos compatibles con la API antigua o SpeechBrain
124
+ embedding = embedding_model(waveform[None].to(DEVICE))
125
+
126
+ # Asegurar que el embedding sea un tensor y luego numpy
127
+ if isinstance(embedding, torch.Tensor):
128
+ return embedding.squeeze().cpu().numpy()
129
+ else:
130
+ # Para embeddings que ya son numpy (ej. SpeechBrain wrapper)
131
+ return np.squeeze(embedding)
132
+ except Exception as e:
133
+ print(f"Error extrayendo embedding para segmento {start}-{end}: {e}")
134
+ # Devolver un embedding de ceros en caso de error
135
+ return np.zeros(512) # Ajustar tama帽o si se sabe el dim del embedding
136
+
137
 
138
  def transcribe_and_diarize(audio_file, num_speakers):
139
  """Funci贸n principal de transcripci贸n y diarizaci贸n."""
140
+ temp_files = []
141
+ try:
142
+ status_update = ""
143
+
144
+ # --- 1. Conversi贸n ---
145
+ status_update += "1. Convirtiendo audio a formato WAV (16kHz, mono)...\n"
146
+ yield status_update, ""
147
+
148
+ wav_path = convert_to_wav(audio_file)
149
+ temp_files.append(wav_path) # Para limpieza posterior
150
+
151
+ # --- 2. Duraci贸n ---
152
+ status_update += "2. Obteniendo duraci贸n del audio...\n"
153
+ yield status_update, ""
154
+
155
+ duration = get_duration(wav_path)
156
+ if duration > 30 * 60: # Limitar a 30 minutos
157
+ yield status_update + "Error: El audio es demasiado largo (m谩ximo 30 minutos).\n", ""
158
+ return
 
 
 
 
 
159
 
160
+ # --- 3. Transcripci贸n ---
161
+ status_update += f"3. Transcribiendo audio con Whisper (modelo '{WHISPER_MODEL_NAME}')...\n"
162
+ yield status_update, ""
163
+
164
+ # Transcribir en espa帽ol
165
+ result = whisper_model.transcribe(wav_path, language='es', task='transcribe', verbose=False)
166
+ segments = result["segments"]
167
 
168
+ if not segments:
169
+ yield status_update + "Error: No se detect贸 habla en el audio.\n", ""
170
+ return
171
+
172
+ # --- 4. Diarizaci贸n ---
173
+ status_update += "4. Preparando para diarizaci贸n...\n"
 
 
 
 
 
 
174
  yield status_update, ""
175
 
176
+ # Limitar n煤mero de hablantes
177
+ num_speakers = max(2, min(6, int(num_speakers)))
178
+ num_speakers = min(num_speakers, len(segments))
179
+
180
+ if len(segments) <= 1:
181
+ segments[0]['speaker'] = 'HABLANTE 1'
182
+ status_update += " -> Solo se detect贸 1 segmento de habla. Asignando un hablante.\n"
183
+ else:
184
+ status_update += " -> Extrayendo embeddings de audio...\n"
185
+ yield status_update, ""
186
+
187
+ # Determinar la dimensi贸n del embedding con una muestra
188
+ sample_embedding = segment_embedding(wav_path, segments[0], duration)
189
+ embedding_dim = sample_embedding.shape[-1] if hasattr(sample_embedding, 'shape') else 512
190
+ print(f"Dimensi贸n del embedding detectada: {embedding_dim}")
191
 
192
+ embeddings = np.zeros(shape=(len(segments), embedding_dim))
193
+ for i, segment in enumerate(segments):
194
+ embeddings[i] = segment_embedding(wav_path, segment, duration)
195
+ embeddings = np.nan_to_num(embeddings)
196
+
197
+ status_update += " -> Agrupando hablantes...\n"
198
+ yield status_update, ""
199
+
200
+ # Clustering
201
+ clustering = AgglomerativeClustering(n_clusters=num_speakers).fit(embeddings)
202
+ labels = clustering.labels_
203
+ for i in range(len(segments)):
204
+ segments[i]["speaker"] = f'HABLANTE {labels[i] + 1}'
205
+
206
+ # --- 5. Formateo de salida ---
207
+ status_update += "5. Generando transcripci贸n final...\n"
208
  yield status_update, ""
209
 
210
+ output_text = ""
211
+ for (i, segment) in enumerate(segments):
212
+ if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
213
+ if i != 0:
214
+ output_text += '\n\n'
215
+ output_text += f"{segment['speaker']} [{time(segment['start'])}]\n\n"
216
+ output_text += segment["text"].strip() + ' '
 
 
 
 
 
 
 
 
 
 
217
 
218
+ yield status_update + "隆Proceso completado!\n", output_text
219
+
220
+ except Exception as e:
221
+ error_msg = f"Error durante el proceso: {str(e)}"
222
+ print(error_msg)
223
+ yield f"Error: {error_msg}\n", ""
224
+ finally:
225
+ # Limpiar archivos temporales
226
+ for f in temp_files:
227
+ try:
228
+ os.remove(f)
229
+ print(f"Archivo temporal eliminado: {f}")
230
+ except OSError:
231
+ pass
232
+ # Liberar memoria GPU/CPU
233
+ if 'whisper_model' in globals():
234
+ del whisper_model
235
+ if 'embedding_model' in globals():
236
+ del embedding_model
237
+ gc.collect()
238
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
239
 
240
  # --- Interfaz Gradio ---
241
  with gr.Blocks(title="Diarizaci贸n de Audio en Espa帽ol") as demo:
242
  gr.Markdown("# 馃帳 Diarizaci贸n de Audio en Espa帽ol")
243
  gr.Markdown("Sube un archivo de audio (hasta 30 minutos) y obt茅n una transcripci贸n separada por hablantes. Optimizado para espa帽ol.")
244
+ gr.Markdown("**Nota:** Este demo usa modelos ligeros. Para audio con mucho ruido o m谩s de 10 minutos, los resultados pueden ser menos precisos.")
245
 
246
  with gr.Row():
247
  with gr.Column():
 
249
  num_speakers = gr.Slider(2, 6, value=3, step=1, label="N煤mero aproximado de hablantes")
250
  run_button = gr.Button("馃殌 Iniciar Diarizaci贸n")
251
  with gr.Column():
252
+ status_output = gr.Textbox(label="Estado", interactive=False, lines=10, max_lines=10)
253
  text_output = gr.Textbox(label="Transcripci贸n con Hablantes", interactive=False, lines=20)
254
 
255
  run_button.click(
256
  fn=transcribe_and_diarize,
257
  inputs=[audio_input, num_speakers],
258
  outputs=[status_output, text_output],
259
+ queue=True,
260
+ concurrency_limit=1 # Limitar a 1 ejecuci贸n simult谩nea para evitar sobrecarga
261
  )
262
 
263
  gr.Markdown("---")
264
+ gr.Markdown("**Modelos Usados:**\n"
265
+ "* **Transcripci贸n:** Whisper (`small`)\n"
266
+ "* **Diarizaci贸n:** Pyannote.Audio (`pyannote/embedding` o `speechbrain/spkrec-ecapa-voxceleb`)\n")
 
 
267
 
268
  # Para Hugging Face Spaces
269
+ demo.launch()