Jose Jorge Muñoz committed
Commit 0dfa3b6 · verified · 1 Parent(s): 3b765bd

Create app.py

Files changed (1)
  1. app.py +295 -0
app.py ADDED
@@ -0,0 +1,295 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from speechbrain.pretrained import EncoderClassifier
+ import numpy as np
+ from scipy.spatial.distance import cosine
+ import librosa
+ import torchaudio
+ import gradio as gr
+ import noisereduce as nr  # Ensure this package is installed (e.g., via pip install noisereduce)
+
+ # Import WavLM components from Hugging Face
+ from transformers import WavLMForXVector, Wav2Vec2FeatureExtractor
+
+ # ---------------- Noise Reduction and Silence Removal Functions ----------------
+ def reduce_noise(waveform, sample_rate=16000):
+     """
+     Apply a mild noise reduction to the waveform specialized for voice audio.
+     The parameters are chosen to minimize alteration to the original voice.
+
+     Parameters:
+         waveform (torch.Tensor): Audio tensor of shape (1, n_samples)
+         sample_rate (int): Sampling rate of the audio
+
+     Returns:
+         torch.Tensor: Denoised audio tensor of shape (1, n_samples)
+     """
+     # Convert tensor to numpy array
+     waveform_np = waveform.squeeze(0).cpu().numpy()
+     # Perform noise reduction with conservative parameters.
+     reduced_noise = nr.reduce_noise(y=waveform_np, sr=sample_rate, prop_decrease=0.5)
+     return torch.from_numpy(reduced_noise).unsqueeze(0)
+
+ def remove_long_silence(waveform, sample_rate=16000, top_db=20, max_silence_length=1.0):
+     """
+     Remove silence segments longer than max_silence_length seconds from the audio.
+     This function uses librosa.effects.split to detect non-silent intervals and
+     preserves at most max_silence_length seconds of silence between speech segments.
+
+     Parameters:
+         waveform (torch.Tensor): Audio tensor of shape (1, n_samples)
+         sample_rate (int): Sampling rate of the audio
+         top_db (int): The threshold (in decibels) below reference to consider as silence
+         max_silence_length (float): Maximum allowed silence duration in seconds
+
+     Returns:
+         torch.Tensor: Processed audio tensor with long silences removed
+     """
+     # Convert tensor to numpy array
+     waveform_np = waveform.squeeze(0).cpu().numpy()
+     # Identify non-silent intervals
+     non_silent_intervals = librosa.effects.split(waveform_np, top_db=top_db)
+     if len(non_silent_intervals) == 0:
+         return waveform
+
+     output_segments = []
+     max_silence_samples = int(max_silence_length * sample_rate)
+
+     # Handle silence before the first non-silent interval
+     if non_silent_intervals[0][0] > 0:
+         output_segments.append(waveform_np[:min(non_silent_intervals[0][0], max_silence_samples)])
+
+     # Process each non-silent interval and the gap following it
+     for i, (start, end) in enumerate(non_silent_intervals):
+         output_segments.append(waveform_np[start:end])
+         if i < len(non_silent_intervals) - 1:
+             next_start = non_silent_intervals[i + 1][0]
+             gap = next_start - end
+             if gap > max_silence_samples:
+                 output_segments.append(waveform_np[end:end + max_silence_samples])
+             else:
+                 output_segments.append(waveform_np[end:next_start])
+
+     # Handle silence after the last non-silent interval
+     if non_silent_intervals[-1][1] < len(waveform_np):
+         gap = len(waveform_np) - non_silent_intervals[-1][1]
+         if gap > max_silence_samples:
+             output_segments.append(waveform_np[-max_silence_samples:])
+         else:
+             output_segments.append(waveform_np[non_silent_intervals[-1][1]:])
+
+     processed_waveform = np.concatenate(output_segments)
+     return torch.from_numpy(processed_waveform).unsqueeze(0)
+ # -----------------------------------------------------------------------------
+
+ class EnhancedECAPATDNN(nn.Module):
+     def __init__(self):
+         super().__init__()
+         # Primary pretrained model from SpeechBrain (ECAPA-TDNN, trained on VoxCeleb)
+         self.ecapa = EncoderClassifier.from_hparams(
+             source="speechbrain/spkrec-ecapa-voxceleb",
+             savedir="pretrained_models/spkrec-ecapa-voxceleb",
+             run_opts={"device": "cuda" if torch.cuda.is_available() else "cpu"}
+         )
+
+         # Secondary pretrained model: Microsoft WavLM for Speaker Verification
+         self.wavlm_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-sv")
+         self.wavlm = WavLMForXVector.from_pretrained("microsoft/wavlm-base-sv")
+         self.wavlm.to("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Projection layer to map WavLM's embedding (512-dim) to 192-dim (to match ECAPA)
+         self.wavlm_proj = nn.Linear(512, 192)
+
+         # Enhancement network: deeper enhancement layers that
+         # increase dimensionality, then reduce back to 192.
+         self.enhancement = nn.Sequential(
+             nn.Linear(192, 256),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(256, 192)
+         )
+
+         # Transformer encoder block (with batch_first=True)
+         self.transformer = nn.TransformerEncoder(
+             nn.TransformerEncoderLayer(d_model=192, nhead=4, dropout=0.3, batch_first=True),
+             num_layers=2
+         )
+
+     @torch.no_grad()
+     def forward(self, x):
+         """
+         x: input waveform tensor of shape (1, T) on device.
+         """
+         # Extract ECAPA embedding
+         emb_ecapa = self.ecapa.encode_batch(x)
+
+         # Prepare input for WavLM:
+         # x is a waveform tensor of shape (1, T)
+         waveform_np = x.squeeze(0).cpu().numpy()  # shape (T,)
+         wavlm_inputs = self.wavlm_feature_extractor(waveform_np, sampling_rate=16000, return_tensors="pt")
+         wavlm_inputs = {k: v.to(x.device) for k, v in wavlm_inputs.items()}
+         wavlm_out = self.wavlm(**wavlm_inputs)
+         # Extract embeddings; expected shape (batch, 512)
+         emb_wavlm = wavlm_out.embeddings
+         # Project WavLM embedding to 192-dim
+         emb_wavlm_proj = self.wavlm_proj(emb_wavlm)
+
+         # Process ECAPA embedding:
+         if emb_ecapa.dim() > 2 and emb_ecapa.size(1) > 1:
+             emb_ecapa_proc = self.transformer(emb_ecapa)
+             emb_ecapa_proc = emb_ecapa_proc.mean(dim=1)
+         else:
+             emb_ecapa_proc = emb_ecapa
+
+         # Fuse the two embeddings by averaging
+         fused = (emb_ecapa_proc + emb_wavlm_proj) / 2
+
+         # Apply enhancement layers and normalize
+         enhanced = self.enhancement(fused)
+         output = F.normalize(enhanced, p=2, dim=-1)
+         return output
+
+ class ForensicSpeakerVerification:
+     def __init__(self):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         print(f"Using device: {self.device}")
+         self.model = EnhancedECAPATDNN().to(self.device)
+         self.model.eval()
+
+         # Optimize only the enhancement and transformer layers if fine-tuning
+         trainable_params = list(self.model.enhancement.parameters()) + list(self.model.transformer.parameters())
+         self.optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)
+         self.training_embeddings = []
+
+     def preprocess_audio(self, file_path, max_duration=10):
+         try:
+             waveform, sample_rate = torchaudio.load(file_path)
+             if waveform.shape[0] > 1:
+                 waveform = torch.mean(waveform, dim=0, keepdim=True)
+             if sample_rate != 16000:
+                 resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+                 waveform = resampler(waveform)
+             max_length = int(16000 * max_duration)
+             if waveform.shape[1] > max_length:
+                 waveform = waveform[:, :max_length]
+             waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
+             # Apply noise reduction
+             waveform = reduce_noise(waveform, sample_rate=16000)
+             # Remove silences longer than 1 second
+             waveform = remove_long_silence(waveform, sample_rate=16000)
+             return waveform.to(self.device)
+         except Exception as e:
+             raise ValueError(f"Error preprocessing audio: {str(e)}")
+
+     @torch.no_grad()
+     def extract_embedding(self, file_path, chunk_duration=3, overlap=0.5):
+         waveform = self.preprocess_audio(file_path)
+         sample_rate = 16000
+         chunk_size = int(chunk_duration * sample_rate)
+         hop_size = int(chunk_size * (1 - overlap))
+         embeddings = []
+         if waveform.shape[1] > chunk_size:
+             for start in range(0, waveform.shape[1] - chunk_size + 1, hop_size):
+                 chunk = waveform[:, start:start + chunk_size]
+                 emb = self.model(chunk)
+                 embeddings.append(emb)
+             final_emb = torch.mean(torch.cat(embeddings, dim=0), dim=0, keepdim=True)
+         else:
+             final_emb = self.model(waveform)
+         return final_emb.cpu().numpy()
+
+     def verify_speaker(self, questioned_audio, suspect_audio, progress=gr.Progress()):
+         if not questioned_audio or not suspect_audio:
+             return "⚠️ Please provide both audio samples"
+         try:
+             progress(0.2, desc="Processing questioned audio...")
+             questioned_emb = self.extract_embedding(questioned_audio)
+             progress(0.4, desc="Processing suspect audio...")
+             suspect_emb = self.extract_embedding(suspect_audio)
+             progress(0.6, desc="Computing similarity...")
+             score = 1 - cosine(questioned_emb.flatten(), suspect_emb.flatten())
+
+             # Convert similarity score to probability (percentage)
+             probability = score * 100
+
+             # Create heat bar HTML
+             heat_bar = f"""
+             <div style="width:100%; height:30px; position:relative; margin-bottom:10px;">
+                 <div style="width:100%; height:20px; background: linear-gradient(to right, #FF0000, #FFFF00, #00FF00); border-radius:10px;"></div>
+                 <div style="position:absolute; left:{probability}%; top:0; transform:translateX(-50%);">
+                     <div style="width:0; height:0; border-left:8px solid transparent; border-right:8px solid transparent; border-bottom:10px solid black;"></div>
+                     <div style="width:2px; height:20px; background-color:black; margin-left:7px;"></div>
+                 </div>
+             </div>
+             """
+
+             # Determine color based on probability
+             if probability <= 50:
+                 color = f"rgb(255, {int(255 * (probability / 50))}, 0)"
+             else:
+                 color = f"rgb({int(255 * (2 - probability / 50))}, 255, 0)"
+
+             # Determine verdict text
+             if score >= 0.6:
+                 verdict_text = '✅ Same Speaker'
+             else:
+                 verdict_text = '⚠️ Different Speakers'
+
+             result = f"""
+             <div style='font-family: Arial, sans-serif; font-size: 16px; background-color: #f5f5f5; padding: 20px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);'>
+                 <h2 style='color: #333; margin-bottom: 20px;'>Speaker Verification Analysis Results</h2>
+                 <p style='margin-bottom: 10px; color: black;'>Similarity Score: <strong style='color:{color};'>{probability:.1f}%</strong></p>
+                 {heat_bar}
+                 <p style='margin-top: 20px; font-size: 18px; font-weight: bold; color: #333;'>{verdict_text}</p>
+             </div>
+             """
+             progress(1.0)
+             return result
+         except Exception as e:
+             return f"❌ Error during verification: {str(e)}"
+
+ # Initialize the system
+ speaker_verification = ForensicSpeakerVerification()
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+         # 🎙️ Forensic Speaker Verification System
+         Upload or record two audio samples to compare and verify if they belong to the same speaker.
+         """
+     )
+
+     with gr.Column():
+         questioned_audio = gr.Audio(
+             sources=["upload", "microphone"],
+             type="filepath",
+             label="Questioned Audio Sample"
+         )
+         suspect_audio = gr.Audio(
+             sources=["upload", "microphone"],
+             type="filepath",
+             label="Suspect Audio Sample"
+         )
+         test_button = gr.Button("🔍 Compare Speakers", variant="primary")
+         test_output = gr.HTML()
+
+     test_button.click(
+         fn=speaker_verification.verify_speaker,
+         inputs=[questioned_audio, suspect_audio],
+         outputs=test_output
+     )
+
+     gr.Markdown(
+         """
+         ### How it works
+         1. Upload or record the questioned audio sample.
+         2. Upload or record the suspect audio sample.
+         3. Click "Compare Speakers" to analyze the similarity between the two samples.
+         4. View the results, including the similarity score and verdict.
+
+         Note: For best results, use clear audio samples with minimal background noise.
+         """
+     )
+
+ # Launch the interface
+ demo.launch(share=True)
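
For reference, here is a minimal sketch of how the verification pipeline in this file could be exercised programmatically, outside the Gradio UI. It assumes the class definitions above are available in the current session without triggering the module-level demo.launch(share=True) call (for example, by guarding the UI setup behind if __name__ == "__main__":, which this commit does not do), and it uses questioned.wav / suspect.wav as hypothetical file paths. The 0.6 decision threshold mirrors the verdict logic in verify_speaker.

# Usage sketch (assumptions: the classes above are importable without launching the
# Gradio demo, and "questioned.wav" / "suspect.wav" are placeholder local recordings).
from scipy.spatial.distance import cosine

verifier = ForensicSpeakerVerification()  # downloads ECAPA and WavLM weights on first run

# Each call preprocesses (mono, 16 kHz, denoise, de-silence), chunks, and averages embeddings.
emb_q = verifier.extract_embedding("questioned.wav")
emb_s = verifier.extract_embedding("suspect.wav")

# Same cosine-similarity score the Gradio handler reports as a percentage.
score = 1 - cosine(emb_q.flatten(), emb_s.flatten())
print(f"Similarity: {score * 100:.1f}% -> {'same speaker' if score >= 0.6 else 'different speakers'}")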