from pydub import AudioSegment, silence
import tempfile
import hashlib
import matplotlib.pyplot as plt
import librosa
from transformers import pipeline
import re
import torch
import numpy as np
import os
from scipy.io import wavfile
from scipy.signal import resample_poly

_ref_audio_cache = {}
asr_pipe = None

def resample_to_24khz(input_path: str, output_path: str):
    """
    Resample a WAV audio file to 24,000 Hz using scipy.

    Parameters:
    - input_path (str): Path to the input WAV file.
    - output_path (str): Path to save the output WAV file.
    """
    # Load WAV file
    orig_sr, audio = wavfile.read(input_path)

    # Convert to mono if stereo
    if len(audio.shape) == 2:
        audio = audio.mean(axis=1)

    # Convert to float32 in [-1, 1]; only integer dtypes need rescaling
    if np.issubdtype(audio.dtype, np.integer):
        audio = audio.astype(np.float32) / np.iinfo(audio.dtype).max
    else:
        audio = audio.astype(np.float32)

    # Resample to the target rate
    target_sr = 24000
    resampled = resample_poly(audio, target_sr, orig_sr)

    # Convert back to int16 for saving (clip to avoid overflow)
    resampled_int16 = (np.clip(resampled, -1.0, 1.0) * 32767).astype(np.int16)

    # Save output
    wavfile.write(output_path, target_sr, resampled_int16)

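# Illustrative usage (a minimal sketch; "ref.wav" is a hypothetical local WAV file,
# e.g. recorded at 44.1 kHz):
#
#   resample_to_24khz("ref.wav", "ref_24k.wav")
#   sr, data = wavfile.read("ref_24k.wav")
#   assert sr == 24000  # mono int16 output at 24 kHz
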
def chunk_text(text, max_chars=135):
    # print(text)
    # Step 1: split into sentences on ". "
    sentences = [s.strip() for s in text.split('. ') if s.strip()]

    # Merge sentences shorter than 4 words with an adjacent sentence
    i = 0
    while i < len(sentences):
        if len(sentences[i].split()) < 4:
            if i == 0 and i + 1 < len(sentences):
                # Merge with the following sentence
                sentences[i + 1] = sentences[i] + ', ' + sentences[i + 1]
                del sentences[i]
            else:
                if i - 1 >= 0:
                    # Merge with the previous sentence
                    sentences[i - 1] = sentences[i - 1] + ', ' + sentences[i]
                    del sentences[i]
                    i -= 1
                else:
                    # Single short sentence with nothing to merge into: keep it and move on
                    i += 1
        else:
            i += 1
    # print(sentences)

    # Step 2: split overly long sentences on ", "
    final_sentences = []
    for sentence in sentences:
        parts = [p.strip() for p in sentence.split(', ')]
        buffer = []
        for part in parts:
            buffer.append(part)
            total_words = sum(len(p.split()) for p in buffer)
            if total_words > 20:
                # Flush the buffered parts as one chunk
                long_part = ', '.join(buffer)
                final_sentences.append(long_part)
                buffer = []
        if buffer:
            final_sentences.append(', '.join(buffer))
    # print(final_sentences)

    # If the last chunk is very short, merge it into the previous one
    if len(final_sentences) >= 2 and len(final_sentences[-1].split()) < 4:
        final_sentences[-2] = final_sentences[-2] + ", " + final_sentences[-1]
        final_sentences = final_sentences[0:-1]
    # print(final_sentences)
    return final_sentences

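# Illustrative usage (a minimal sketch; the sample text is made up):
#
#   text = "Hi. This is a fairly long example sentence, written only to show how the text gets split into chunks. Thanks."
#   for chunk in chunk_text(text):
#       print(chunk)
#
# Sentences shorter than 4 words are merged into a neighbour, and long sentences
# are flushed roughly every 20 words at a ", " boundary.
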
def initialize_asr_pipeline(device="cuda", dtype=None):
    if dtype is None:
        dtype = (
            torch.float16
            if "cuda" in device
            and torch.cuda.get_device_properties(device).major >= 6
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
            else torch.float32
        )
    global asr_pipe
    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="vinai/PhoWhisper-medium",
        torch_dtype=dtype,
        device=device,
    )

# transcribe
def transcribe(ref_audio, language=None):
    global asr_pipe
    if asr_pipe is None:
        initialize_asr_pipeline(device="cuda")
    return asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
        return_timestamps=False,
    )["text"].strip()

def caculate_spec(audio):
    # Compute spectrogram (Short-Time Fourier Transform)
    stft = librosa.stft(audio, n_fft=512, hop_length=256, win_length=512)
    spectrogram = np.abs(stft)
    # Convert to dB
    spectrogram_db = librosa.amplitude_to_db(spectrogram, ref=np.max)
    return spectrogram_db


def save_spectrogram(audio, path):
    spectrogram = caculate_spec(audio)
    plt.figure(figsize=(12, 4))
    plt.imshow(spectrogram, origin="lower", aspect="auto")
    plt.colorbar()
    plt.savefig(path)
    plt.close()

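# Illustrative usage (a minimal sketch; "ref_24k.wav" and "spec.png" are
# hypothetical paths):
#
#   wav, sr = librosa.load("ref_24k.wav", sr=24000)
#   save_spectrogram(wav, "spec.png")  # writes a dB-scaled magnitude spectrogram image
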
def remove_silence_edges(audio, silence_threshold=-42):
    # Remove silence from the start
    non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
    audio = audio[non_silent_start_idx:]

    # Remove silence from the end
    non_silent_end_duration = audio.duration_seconds
    for ms in reversed(audio):
        if ms.dBFS > silence_threshold:
            break
        non_silent_end_duration -= 0.001
    trimmed_audio = audio[: int(non_silent_end_duration * 1000)]
    return trimmed_audio

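# Illustrative usage (a minimal sketch; "ref.wav" is a hypothetical path):
#
#   seg = AudioSegment.from_file("ref.wav")
#   trimmed = remove_silence_edges(seg, silence_threshold=-42)
#   trimmed.export("ref_trimmed.wav", format="wav")
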
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device="cuda"):
    show_info("Converting audio...")
    # ref_audio_orig_converted = ref_audio_orig.replace(".wav", "_24k.wav").replace(".mp3", "_24k.mp3").replace(".m4a", "_24k.m4a").replace(".flac", "_24k.flac")
    # resample_to_24khz(ref_audio_orig, ref_audio_orig_converted)
    # ref_audio_orig = ref_audio_orig_converted
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        aseg = AudioSegment.from_file(ref_audio_orig)

        if clip_short:
            # 1. try to find long silence for clipping
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                    show_info("Audio is over 15s, clipping short. (1)")
                    break
                non_silent_wave += non_silent_seg

            # 2. try to find short silence for clipping if 1. failed
            if len(non_silent_wave) > 15000:
                non_silent_segs = silence.split_on_silence(
                    aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
                )
                non_silent_wave = AudioSegment.silent(duration=0)
                for non_silent_seg in non_silent_segs:
                    if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                        show_info("Audio is over 15s, clipping short. (2)")
                        break
                    non_silent_wave += non_silent_seg

            aseg = non_silent_wave

            # 3. if no proper silence found for clipping
            if len(aseg) > 15000:
                aseg = aseg[:15000]
                show_info("Audio is over 15s, clipping short. (3)")

        aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        aseg.export(f.name, format="wav")
        ref_audio = f.name

    # Compute a hash of the reference audio file
    with open(ref_audio, "rb") as audio_file:
        audio_data = audio_file.read()
        audio_hash = hashlib.md5(audio_data).hexdigest()

    if not ref_text.strip():
        global _ref_audio_cache
        if audio_hash in _ref_audio_cache:
            # Use cached ASR transcription
            show_info("Using cached reference text...")
            ref_text = _ref_audio_cache[audio_hash]
        else:
            show_info("No reference text provided, transcribing reference audio...")
            ref_text = transcribe(ref_audio)
            # Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
            _ref_audio_cache[audio_hash] = ref_text
    else:
        show_info("Using custom reference text...")

    # Ensure ref_text ends with proper sentence-ending punctuation
    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += ". "

    print("\nref_text ", ref_text)

    return ref_audio, ref_text
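
# Illustrative end-to-end usage (a minimal sketch; "ref.wav" is a hypothetical path
# and a CUDA device is assumed for the fallback transcription):
#
#   ref_audio_path, ref_text = preprocess_ref_audio_text("ref.wav", "")
#   # ref_audio_path points at a <=15 s, edge-trimmed temporary WAV file;
#   # ref_text is the provided or auto-transcribed reference text ending in ". "
#   chunks = chunk_text("Hello there. This is the text that will be synthesized with the cloned voice.")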