# Author: Prathamesh Sarjerao Vaidya
# Status: completed the project
# Commit: 3e27995
"""
Noise Reduction Module for PS-6 Requirements
This module provides speech enhancement capabilities to handle noisy audio
conditions as required for SNR -5 to 20 dB operation.
"""
import numpy as np
import torch
import torchaudio
from typing import Optional, Tuple
import logging
from pathlib import Path
import warnings
warnings.filterwarnings("ignore")
logger = logging.getLogger(__name__)
class NoiseReducer:
"""
Speech enhancement system for noise reduction and robustness.
Handles various noise conditions to improve ASR performance.
"""
def __init__(self, device: str = "cpu", cache_dir: str = "./model_cache"):
    """Create a noise reducer and try to load enhancement models.

    Args:
        device: Torch device string used when loading enhancement models.
        cache_dir: Directory where downloaded model weights are cached.
    """
    # Fixed processing rate expected by the downstream ASR pipeline.
    self.sample_rate = 16000
    self.device = device
    self.cache_dir = Path(cache_dir)
    self.enhancement_model = None
    # Populates self.enhancement_models (dict of loaded models, or None).
    self._initialize_model()
def _initialize_model(self):
    """Initialize advanced speech enhancement models.

    Tries each known enhancement backend in order; any backend that is
    missing or fails to load is skipped with a warning.  On success,
    ``self.enhancement_models`` maps model names to loaded model objects;
    if nothing loads it is set to ``None`` and callers fall back to pure
    signal processing.
    """
    try:
        # Candidate enhancement backends, tried in order.
        # Fix: removed "microsoft/DialoGPT-medium" — it is a text model
        # that the dispatch below never handled, so listing it only
        # misled readers while doing nothing at runtime.
        models_to_try = [
            "speechbrain/sepformer-wham",
            "speechbrain/sepformer-wsj02mix",
            "facebook/demucs",
        ]
        self.enhancement_models = {}
        for model_name in models_to_try:
            try:
                if "speechbrain" in model_name:
                    from speechbrain.pretrained import SepformerSeparation
                    self.enhancement_models[model_name] = SepformerSeparation.from_hparams(
                        source=model_name,
                        savedir=f"{self.cache_dir}/speechbrain_enhancement/{model_name.split('/')[-1]}",
                        run_opts={"device": self.device},
                    )
                    logger.info(f"Loaded SpeechBrain enhancement model: {model_name}")
                elif "demucs" in model_name:
                    # Demucs is an optional dependency for source separation.
                    try:
                        import demucs.api
                        self.enhancement_models[model_name] = demucs.api.Separator()
                        logger.info(f"Loaded Demucs model: {model_name}")
                    except ImportError:
                        logger.warning("Demucs not available, skipping")
            except Exception as model_error:
                # One backend failing must not prevent trying the others.
                logger.warning(f"Failed to load {model_name}: {model_error}")
                continue
        if not self.enhancement_models:
            logger.info("No advanced models loaded, using enhanced signal processing")
            self.enhancement_models = None
        else:
            logger.info(f"Loaded {len(self.enhancement_models)} enhancement models")
    except Exception as e:
        # Model loading is best-effort; signal processing is the fallback.
        logger.warning(f"Could not load advanced noise reduction models: {e}")
        logger.info("Using enhanced signal processing for noise reduction")
        self.enhancement_models = None
def enhance_audio(self, audio_path: str, output_path: Optional[str] = None) -> str:
    """
    Enhance audio using advanced noise reduction and speech enhancement.

    Args:
        audio_path: Path to input audio file
        output_path: Path for enhanced audio output (optional; defaults to
            "<stem>_enhanced<suffix>" next to the input file)

    Returns:
        Path to the enhanced audio file, or the original path if
        enhancement fails.
    """
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
        # Downmix multichannel audio to mono by averaging channels.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Resample to the fixed 16 kHz processing rate if necessary.
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.sample_rate)
            waveform = resampler(waveform)
        enhanced_waveform = self._apply_advanced_noise_reduction(waveform, audio_path)
        if output_path is None:
            input_path = Path(audio_path)
            output_path = input_path.parent / f"{input_path.stem}_enhanced{input_path.suffix}"
        # Fix: pass a plain string — the auto-generated output_path is a
        # pathlib.Path, and older torchaudio releases reject Path objects
        # in save().
        torchaudio.save(str(output_path), enhanced_waveform, self.sample_rate)
        logger.info(f"Audio enhanced using advanced methods and saved to: {output_path}")
        return str(output_path)
    except Exception as e:
        logger.error(f"Error enhancing audio: {e}")
        return audio_path  # Return original path if enhancement fails
def _apply_advanced_noise_reduction(self, waveform: torch.Tensor, audio_path: str) -> torch.Tensor:
    """
    Apply advanced noise reduction to a waveform.

    Prefers ML-based enhancement when models are loaded; otherwise (or if
    the models produce nothing) falls back to signal-processing stages.

    Args:
        waveform: Input audio waveform
        audio_path: Path to audio file for context

    Returns:
        Enhanced waveform, or the input unchanged if processing fails.
    """
    try:
        if self.enhancement_models:
            ml_result = self._apply_ml_enhancement(waveform)
            if ml_result is not None:
                return ml_result
        # No models loaded, or none produced output: use DSP fallback.
        return self._apply_enhanced_signal_processing(waveform)
    except Exception as e:
        logger.error(f"Error in advanced noise reduction: {e}")
        return waveform  # best-effort: never fail the caller
def _apply_ml_enhancement(self, waveform: torch.Tensor) -> Optional[torch.Tensor]:
"""Apply machine learning-based enhancement models."""
try:
audio = waveform.squeeze().numpy()
for model_name, model in self.enhancement_models.items():
try:
if "speechbrain" in model_name:
# Use SpeechBrain Sepformer for speech enhancement
enhanced_audio = model.separate_batch(waveform.unsqueeze(0))
if enhanced_audio is not None and len(enhanced_audio) > 0:
return enhanced_audio[0, 0, :].unsqueeze(0) # Take first source
elif "demucs" in model_name:
# Use Demucs for source separation
import demucs.api
separated = model.separate_tensor(waveform)
if separated is not None and len(separated) > 0:
return separated[0] # Take first separated source
except Exception as model_error:
logger.warning(f"Error with {model_name}: {model_error}")
continue
return None
except Exception as e:
logger.error(f"Error in ML enhancement: {e}")
return None
def _apply_enhanced_signal_processing(self, waveform: torch.Tensor) -> torch.Tensor:
    """
    Run the full signal-processing enhancement pipeline.

    Stages, in order: adaptive spectral subtraction, adaptive Wiener
    filtering, Kalman smoothing, non-local means denoising, wavelet
    denoising.  Each stage is itself best-effort.

    Args:
        waveform: Input audio waveform

    Returns:
        Enhanced waveform, or the input unchanged if processing fails.
    """
    try:
        samples = waveform.squeeze().numpy()
        # Apply the denoising stages in a fixed sequence.
        for stage in (
            self._advanced_spectral_subtraction,
            self._adaptive_wiener_filtering,
            self._kalman_filtering,
            self._non_local_means_denoising,
            self._wavelet_denoising,
        ):
            samples = stage(samples)
        return torch.from_numpy(samples).unsqueeze(0)
    except Exception as e:
        logger.error(f"Error in enhanced signal processing: {e}")
        return waveform  # Return original if processing fails
def _apply_noise_reduction(self, waveform: torch.Tensor) -> torch.Tensor:
"""
Apply basic noise reduction techniques to the waveform.
Args:
waveform: Input audio waveform
Returns:
Enhanced waveform
"""
try:
# Convert to numpy for processing
audio = waveform.squeeze().numpy()
# Apply various enhancement techniques
enhanced_audio = self._spectral_subtraction(audio)
enhanced_audio = self._wiener_filtering(enhanced_audio)
enhanced_audio = self._adaptive_filtering(enhanced_audio)
# Convert back to tensor
enhanced_waveform = torch.from_numpy(enhanced_audio).unsqueeze(0)
return enhanced_waveform
except Exception as e:
logger.error(f"Error in noise reduction: {e}")
return waveform # Return original if processing fails
def _spectral_subtraction(self, audio: np.ndarray) -> np.ndarray:
"""
Apply spectral subtraction for noise reduction.
Args:
audio: Input audio signal
Returns:
Enhanced audio signal
"""
try:
# Compute STFT
stft = np.fft.fft(audio)
magnitude = np.abs(stft)
phase = np.angle(stft)
# Estimate noise from first few frames (assuming they contain mostly noise)
noise_frames = min(10, len(magnitude) // 4)
noise_spectrum = np.mean(magnitude[:noise_frames])
# Apply spectral subtraction
alpha = 2.0 # Over-subtraction factor
beta = 0.01 # Spectral floor factor
enhanced_magnitude = magnitude - alpha * noise_spectrum
enhanced_magnitude = np.maximum(enhanced_magnitude, beta * magnitude)
# Reconstruct signal
enhanced_stft = enhanced_magnitude * np.exp(1j * phase)
enhanced_audio = np.real(np.fft.ifft(enhanced_stft))
return enhanced_audio
except Exception as e:
logger.error(f"Error in spectral subtraction: {e}")
return audio
def _wiener_filtering(self, audio: np.ndarray) -> np.ndarray:
"""
Apply Wiener filtering for noise reduction.
Args:
audio: Input audio signal
Returns:
Enhanced audio signal
"""
try:
# Simple Wiener filter implementation
# In practice, you would use more sophisticated methods
# Apply a simple high-pass filter to remove low-frequency noise
from scipy import signal
# Design high-pass filter
nyquist = self.sample_rate / 2
cutoff = 80 # Hz
normalized_cutoff = cutoff / nyquist
b, a = signal.butter(4, normalized_cutoff, btype='high', analog=False)
filtered_audio = signal.filtfilt(b, a, audio)
return filtered_audio
except Exception as e:
logger.error(f"Error in Wiener filtering: {e}")
return audio
def _adaptive_filtering(self, audio: np.ndarray) -> np.ndarray:
"""
Apply adaptive filtering for noise reduction.
Args:
audio: Input audio signal
Returns:
Enhanced audio signal
"""
try:
# Simple adaptive filtering using moving average
window_size = int(0.025 * self.sample_rate) # 25ms window
# Apply moving average filter
filtered_audio = np.convolve(audio, np.ones(window_size)/window_size, mode='same')
# Mix original and filtered signal
alpha = 0.7 # Mixing factor
enhanced_audio = alpha * audio + (1 - alpha) * filtered_audio
return enhanced_audio
except Exception as e:
logger.error(f"Error in adaptive filtering: {e}")
return audio
def estimate_snr(self, audio_path: str) -> float:
    """
    Estimate the Signal-to-Noise Ratio of an audio file.

    Total signal power is compared against a noise-floor estimate taken
    from the quietest 10% of 25 ms frames.

    Args:
        audio_path: Path to audio file

    Returns:
        Estimated SNR in dB (20.0 if estimation fails).
    """
    try:
        waveform, sr = torchaudio.load(audio_path)
        if waveform.shape[0] > 1:
            # Downmix to mono before analysis.
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        samples = waveform.squeeze().numpy()
        total_power = np.mean(samples ** 2)
        # Short-time frame energies: 25 ms frames, 10 ms hop.
        frame_len = int(0.025 * sr)
        hop = int(0.010 * sr)
        energies = np.array([
            np.mean(samples[start:start + frame_len] ** 2)
            for start in range(0, len(samples) - frame_len, hop)
        ])
        # The quietest 10% of frames approximate the noise floor.
        noise_cutoff = np.percentile(energies, 10)
        noise_power = np.mean(energies[energies <= noise_cutoff])
        if noise_power > 0:
            return float(10 * np.log10(total_power / noise_power))
        return 50.0  # effectively noise-free
    except Exception as e:
        logger.error(f"Error estimating SNR: {e}")
        return 20.0  # Default SNR estimate
def is_noisy_audio(self, audio_path: str, threshold: float = 15.0) -> bool:
    """
    Determine whether an audio file is noisy.

    Args:
        audio_path: Path to audio file
        threshold: SNR threshold in dB (below this is considered noisy)

    Returns:
        True if the estimated SNR is below the threshold; False on error.
    """
    try:
        return self.estimate_snr(audio_path) < threshold
    except Exception as e:
        logger.error(f"Error checking if audio is noisy: {e}")
        return False
def get_enhancement_stats(self, original_path: str, enhanced_path: str) -> dict:
    """
    Compare estimated SNR before and after enhancement.

    Args:
        original_path: Path to original audio
        enhanced_path: Path to enhanced audio

    Returns:
        Dict with original/enhanced SNR and their difference; on error a
        zeroed dict with 'enhancement_applied': False and the error text.
    """
    try:
        before = self.estimate_snr(original_path)
        after = self.estimate_snr(enhanced_path)
        return {
            'original_snr': before,
            'enhanced_snr': after,
            'snr_improvement': after - before,
            'enhancement_applied': True
        }
    except Exception as e:
        logger.error(f"Error getting enhancement stats: {e}")
        return {
            'original_snr': 0.0,
            'enhanced_snr': 0.0,
            'snr_improvement': 0.0,
            'enhancement_applied': False,
            'error': str(e)
        }
def _advanced_spectral_subtraction(self, audio: np.ndarray) -> np.ndarray:
"""Advanced spectral subtraction with adaptive parameters."""
try:
# Compute STFT with overlap
hop_length = 512
n_fft = 2048
stft = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
magnitude = np.abs(stft)
phase = np.angle(stft)
# Adaptive noise estimation
noise_frames = min(20, len(magnitude[0]) // 4)
noise_spectrum = np.mean(magnitude[:, :noise_frames], axis=1, keepdims=True)
# Adaptive over-subtraction factor based on SNR
snr_estimate = np.mean(magnitude) / (np.mean(noise_spectrum) + 1e-10)
alpha = max(1.5, min(3.0, 2.0 + 0.5 * (20 - snr_estimate) / 20))
# Apply spectral subtraction
enhanced_magnitude = magnitude - alpha * noise_spectrum
enhanced_magnitude = np.maximum(enhanced_magnitude, 0.01 * magnitude)
# Reconstruct signal
enhanced_stft = enhanced_magnitude * np.exp(1j * phase)
enhanced_audio = librosa.istft(enhanced_stft, hop_length=hop_length)
return enhanced_audio
except Exception as e:
logger.error(f"Error in advanced spectral subtraction: {e}")
return audio
def _adaptive_wiener_filtering(self, audio: np.ndarray) -> np.ndarray:
"""Adaptive Wiener filtering with frequency-dependent parameters."""
try:
from scipy import signal
# Design adaptive filter based on signal characteristics
nyquist = self.sample_rate / 2
# Adaptive cutoff based on signal energy distribution
f, psd = signal.welch(audio, self.sample_rate, nperseg=1024)
energy_80_percent = np.cumsum(psd) / np.sum(psd)
cutoff_idx = np.where(energy_80_percent >= 0.8)[0][0]
adaptive_cutoff = f[cutoff_idx]
# Ensure cutoff is within reasonable bounds
cutoff = max(80, min(adaptive_cutoff, 8000))
normalized_cutoff = cutoff / nyquist
# Design Butterworth filter
b, a = signal.butter(6, normalized_cutoff, btype='high', analog=False)
filtered_audio = signal.filtfilt(b, a, audio)
return filtered_audio
except Exception as e:
logger.error(f"Error in adaptive Wiener filtering: {e}")
return audio
def _kalman_filtering(self, audio: np.ndarray) -> np.ndarray:
"""Kalman filtering for noise reduction."""
try:
# Simple Kalman filter implementation
# State: [signal, derivative]
# Measurement: current sample
# Initialize Kalman filter parameters
dt = 1.0 / self.sample_rate
A = np.array([[1, dt], [0, 1]]) # State transition matrix
H = np.array([[1, 0]]) # Observation matrix
Q = np.array([[0.1, 0], [0, 0.1]]) # Process noise covariance
R = np.array([[0.5]]) # Measurement noise covariance
# Initialize state and covariance
x = np.array([[audio[0]], [0]]) # Initial state
P = np.eye(2) # Initial covariance
filtered_audio = np.zeros_like(audio)
filtered_audio[0] = audio[0]
for i in range(1, len(audio)):
# Predict
x_pred = A @ x
P_pred = A @ P @ A.T + Q
# Update
y = audio[i] - H @ x_pred
S = H @ P_pred @ H.T + R
K = P_pred @ H.T @ np.linalg.inv(S)
x = x_pred + K @ y
P = (np.eye(2) - K @ H) @ P_pred
filtered_audio[i] = x[0, 0]
return filtered_audio
except Exception as e:
logger.error(f"Error in Kalman filtering: {e}")
return audio
def _non_local_means_denoising(self, audio: np.ndarray) -> np.ndarray:
    """Non-local means denoising for audio.

    1-D adaptation of the image NL-means algorithm: each output sample is
    a weighted average of nearby samples, where a neighbor's weight
    depends on how similar its surrounding patch is to the patch around
    the current sample.

    Note: this is an O(len(audio) * search_size * window_size) pure
    Python loop — slow for long signals, but deterministic.

    Args:
        audio: Input audio signal (1-D numpy array).

    Returns:
        Denoised signal of the same length; the input is returned
        unchanged if processing fails.
    """
    try:
        # Simplified non-local means for 1D audio signal
        window_size = 5    # patch length used for similarity comparison
        search_size = 11   # neighborhood searched around each sample
        h = 0.1  # Filtering parameter: larger h -> heavier averaging
        denoised = np.zeros_like(audio)
        for i in range(len(audio)):
            # Define search window
            start = max(0, i - search_size // 2)
            end = min(len(audio), i + search_size // 2 + 1)
            weights = []
            values = []
            for j in range(start, end):
                # Calculate similarity between patches
                patch_i_start = max(0, i - window_size // 2)
                patch_i_end = min(len(audio), i + window_size // 2 + 1)
                patch_j_start = max(0, j - window_size // 2)
                patch_j_end = min(len(audio), j + window_size // 2 + 1)
                patch_i = audio[patch_i_start:patch_i_end]
                patch_j = audio[patch_j_start:patch_j_end]
                # Ensure patches are same size
                # (near the edges the two patches can have different
                # lengths; truncating keeps them comparable, at the cost
                # of a slightly biased distance there)
                min_len = min(len(patch_i), len(patch_j))
                patch_i = patch_i[:min_len]
                patch_j = patch_j[:min_len]
                # Calculate distance (mean squared difference of patches)
                distance = np.sum((patch_i - patch_j) ** 2) / len(patch_i)
                weight = np.exp(-distance / (h ** 2))
                weights.append(weight)
                values.append(audio[j])
            # Weighted average
            if weights:
                weights = np.array(weights)
                values = np.array(values)
                denoised[i] = np.sum(weights * values) / np.sum(weights)
            else:
                denoised[i] = audio[i]
        return denoised
    except Exception as e:
        logger.error(f"Error in non-local means denoising: {e}")
        return audio
def _wavelet_denoising(self, audio: np.ndarray) -> np.ndarray:
    """Wavelet-based denoising via soft thresholding.

    Requires the optional ``pywt`` (PyWavelets) package; when it is not
    installed the broad except returns the input unchanged.

    Args:
        audio: Input audio signal (1-D array).

    Returns:
        Denoised signal of the input's length (input unchanged on failure).
    """
    try:
        import pywt
        wavelet_name = 'db4'
        # 4-level Daubechies-4 decomposition.
        coeffs = pywt.wavedec(audio, wavelet_name, level=4)
        # Noise sigma from the finest detail coefficients via the median
        # absolute deviation (0.6745 converts MAD to sigma for Gaussians).
        sigma = np.median(np.abs(coeffs[-1])) / 0.6745
        # Universal (VisuShrink) threshold.
        threshold = sigma * np.sqrt(2 * np.log(len(audio)))
        shrunk = [pywt.threshold(c, threshold, mode='soft') for c in coeffs]
        denoised_audio = pywt.waverec(shrunk, wavelet_name)
        # waverec can overshoot by a sample; trim to the input length.
        if len(denoised_audio) != len(audio):
            denoised_audio = denoised_audio[:len(audio)]
        return denoised_audio
    except Exception as e:
        logger.error(f"Error in wavelet denoising: {e}")
        return audio