|
|
|
""" |
|
ASR constrained to multiple reading candidates generated with MeCab (mecab-ipadic-neologd).
|
This module generates multiple hiragana reading candidates for Japanese text,
constrains speech recognition to those candidates, and scores them to pick the
reading that best matches the audio.
|
""" |
|
|
|
import warnings |
|
|
|
# Suppress FutureWarning messages process-wide; presumably this quiets
# deprecation notices from the third-party ML libraries imported below
# (not verified from this file).
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
|
import math |
|
import tempfile |
|
from pathlib import Path |
|
from typing import List, Tuple, Dict |
|
|
|
import fugashi |
|
import gradio as gr |
|
import librosa |
|
import numpy as np |
|
import torch |
|
from espnet.nets.scorer_interface import BatchScorerInterface |
|
from espnet2.bin.asr_inference import Speech2Text |
|
|
|
|
|
|
|
# mecab-ipadic-neologd system dictionary, expected to live next to this
# script; MeCab receives it as a plain string path.
MECAB_DIC_DIR = str(Path(__file__).parent.joinpath("mecab-ipadic-neologd"))

# Base directory that reported audio file paths are made relative to.
AUDIO_FILES_DIR = Path(".")

# Run inference on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
|
|
class MecabCandidateGenerator:
    """Generate and score reading candidates using MeCab with mecab-ipadic-neologd.

    A candidate is one of MeCab's n-best analyses of the input text: a list of
    morphs, where each morph is the whitespace-split fields of one line of
    ``-Odump`` output (index 1 is the surface form, index 2 the comma-separated
    feature string).
    """

    # Characters accepted as katakana when a surface form is used as its own
    # reading (entries without a reading field).
    KATAKANA_LIST = set(
        "ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヽヾー・"
    )

    # Punctuation that only triggers a low-severity ([1]) warning when it has
    # no reading; full-width and half-width forms are both listed.
    _PUNCTUATION = frozenset("、。！？!?「」『』【】〔〕［］[]〈〉《》・")

    # Full-width digits and a few full-width punctuation marks mapped to ASCII.
    # Written as code points so the table cannot be silently collapsed into an
    # identity mapping by editor or Unicode (NFKC) normalization.
    _NORMALIZE_TABLE = str.maketrans(
        {
            **{chr(0xFF10 + digit): str(digit) for digit in range(10)},  # ０-９
            "\uff01": "!",  # ！
            "\uff1f": "?",  # ？
            "\uff0e": ".",  # ．
            "\uff0c": ",",  # ，
        }
    )

    def __init__(self):
        # MeCab requires a resource (rc) file, so write a minimal one pointing
        # at the bundled dictionary. delete=False because MeCab reopens the file
        # by name, which Windows forbids while a delete-on-close handle is open;
        # the file is removed once both taggers have read it.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".mecabrc", delete=False) as tmp:
            tmp.write(f"dicdir = {MECAB_DIC_DIR}\n")
        rc_path = tmp.name
        try:
            # n-best analysis in dump format (one line per node, including BOS/EOS).
            self.tagger = fugashi.GenericTagger(
                f"-r {rc_path} -d {MECAB_DIC_DIR} -Odump"
            )
            # Partial (constrained) parsing: every node prints "0" and EOS
            # prints the total path cost (%pc); consumed by candidate_to_score.
            self.tagger_p = fugashi.GenericTagger(
                f"-r {rc_path} -d {MECAB_DIC_DIR} -p -F 0 -E %pc"
            )
        finally:
            Path(rc_path).unlink(missing_ok=True)

    def _normalize_text(self, text: str) -> str:
        """Normalize full-width digits and ！？．， to their half-width forms."""
        return text.translate(self._NORMALIZE_TABLE)

    def generate_candidates(self, text: str, nbest: int = 512) -> List[List[List[str]]]:
        """Return up to ``nbest`` morphological analyses of ``text``.

        Each analysis is a list of morphs (the split fields of a dump line);
        BOS lines are skipped and each EOS line closes one analysis.
        """
        text = self._normalize_text(text)
        dump = self.tagger.nbest(text, nbest)
        candidates: List[List[List[str]]] = []
        current: List[List[str]] = []
        for line in dump.split("\n"):
            if not line.strip():
                continue
            fields = line.split()
            # NOTE(review): BOS/EOS are recognized by surface only, so a real
            # token spelled "EOS"/"BOS" in the text would be misread as a marker.
            surface = fields[1]
            if surface == "BOS":
                continue
            if surface == "EOS":
                candidates.append(current)
                current = []
            else:
                current.append(fields)
        return candidates

    @staticmethod
    def _parse_path_cost(output: str) -> int:
        """Extract the path cost from ``tagger_p`` output.

        The output is one "0" per node followed by the cost, so the leading
        zeros are stripped. A cost of exactly 0 leaves nothing after stripping
        and is returned as 0 instead of failing in ``int("")``.
        """
        digits = output.strip().lstrip("0")
        return int(digits) if digits else 0

    def candidate_to_score(self, candidate: List[List[str]]) -> Tuple[float, List[str]]:
        """Return MeCab's path cost for ``candidate`` (lower is better) and no warnings.

        The candidate's surface/feature pairs are fed back through MeCab's
        partial-parsing mode so the cost of exactly this segmentation is computed.
        """
        query_lines = []
        for morph in candidate:
            assert len(morph) >= 3, (
                f"Expected morph to have at least 3 fields, got {len(morph)}: {morph}"
            )
            _, original_form, features, *_ = morph
            query_lines.append(f"{original_form}\t{features}\n")
        query = "".join(query_lines) + "EOS"
        return self._parse_path_cost(self.tagger_p.parse(query)), []

    @classmethod
    def _to_hiragana(cls, reading: str, label) -> Tuple[str, List[str]]:
        """Convert a katakana reading to hiragana, returning it and any warnings.

        ``label`` identifies the morph in warning messages. ヴァ/ヴィ/ヴェ/ヴォ
        become ば/び/べ/ぼ and a lone ヴ becomes ぶ; unsupported characters are
        dropped with a warning (level 1 for punctuation, otherwise 3).
        """
        pieces: List[str] = []
        messages: List[str] = []
        i = 0
        while i < len(reading):
            char = reading[i]
            if "ァ" <= char <= "ン" or char in "ヽヾ":
                # Each of these sits exactly 0x60 code points above its hiragana.
                pieces.append(chr(ord(char) - 0x60))
            elif char == "ー" or "ぁ" <= char <= "ん":
                pieces.append(char)
            elif char == "ヴ":
                if i + 1 < len(reading) and reading[i + 1] in "ァィェォ":
                    pieces.append("ばびべぼ"["ァィェォ".index(reading[i + 1])])
                    i += 1  # the small vowel was consumed as well
                else:
                    pieces.append("ぶ")
            else:
                warning_level = 1 if char in cls._PUNCTUATION else 3
                messages.append(
                    f"[{warning_level}] Unhandled character in reading: {label}"
                )
            i += 1
        return "".join(pieces), messages

    def candidate_to_yomi(
        self, candidate: List[List[str]], yomi_index: int = 7
    ) -> Tuple[str, List[str]]:
        """Convert a morphological analysis candidate to its hiragana reading (yomi).

        Args:
            candidate: morphs as produced by ``generate_candidates``.
            yomi_index: index of the reading within the comma-separated features.

        Returns:
            The concatenated hiragana reading and a list of warning messages,
            each prefixed with a severity such as "[1]" or "[3]".
        """
        pieces: List[str] = []
        warning_messages: List[str] = []
        for morph in candidate:
            if len(morph) < 2:
                warning_messages.append(
                    f"[3] Morph has less than 2 fields: {morph[1:3]}"
                )
                continue
            original_form = morph[1]
            warning_message = ""
            if len(morph) < 3:
                warning_message = f"[3] Morph has less than 3 fields: {morph[1:3]}"
                morph = morph + ["*"]
            features = morph[2].split(",")

            if len(features) > yomi_index:
                reading = features[yomi_index]
            elif all("ぁ" <= c <= "ん" or c in self.KATAKANA_LIST for c in original_form):
                # No reading field, but the surface is already kana: use it as-is.
                reading = original_form
            else:
                if not warning_message:
                    warning_level = (
                        1 if all(c in self._PUNCTUATION for c in original_form) else 3
                    )
                    warning_message = (
                        f"[{warning_level}] Morph has no reading: {morph[1:3]}"
                    )
                warning_messages.append(warning_message)
                continue

            normalized, messages = self._to_hiragana(reading, morph[1:3])
            pieces.append(normalized)
            warning_messages.extend(messages)

        return "".join(pieces), warning_messages
|
|
|
|
|
class CandidateScorer(BatchScorerInterface):
    """Score ASR hypotheses so that only allowed reading candidates can be produced.

    For a hypothesis prefix, every token that extends the prefix along at least
    one candidate scores 0.0 and every other token scores -inf. ``<sos/eos>`` is
    allowed once the prefix equals a complete candidate. If no candidate is
    compatible with the prefix, only ``<sos/eos>`` remains reachable, at -10000.
    """

    def __init__(self, token_list: List[str], device: str = "cpu"):
        super().__init__()
        self.token_list = token_list
        self.eos_id = token_list.index("<sos/eos>")
        self.device = device
        # Empty until set_candidates() is called; with no candidates every
        # prefix falls through to the <sos/eos>-only fallback in score().
        self.candidates: List[str] = []
        # Token spelling -> vocabulary ids. The tokens that prefix a string can
        # then be found with a few dict lookups instead of scanning the whole
        # vocabulary for every candidate on every beam step.
        self._token_ids: Dict[str, List[int]] = {}
        for index, token in enumerate(token_list):
            self._token_ids.setdefault(token, []).append(index)
        self._max_token_len = max(len(token) for token in token_list)

    def set_candidates(self, candidates: List[str]):
        """Set the allowed reading candidates (complete output strings)."""
        self.candidates = candidates

    def score(self, y: torch.Tensor, state, x=None):
        """Return per-token scores for extending the prefix ``y``.

        Args:
            y: prefix token ids; the first id must be ``<sos/eos>``.
            state: scorer state (unused; returned unchanged).
            x: encoder feature (unused).

        Returns:
            A 1-D tensor with one score per vocabulary entry, and ``state``.
        """
        token_ids = y.tolist()
        assert token_ids[0] == self.eos_id, token_ids
        prefix = "".join(self.token_list[i] for i in token_ids[1:])

        allowed = set()
        for candidate in self.candidates:
            if not candidate.startswith(prefix):
                continue
            remaining = candidate[len(prefix):]
            if not remaining:
                # The prefix already spells a complete candidate: allow ending.
                allowed.add(self.eos_id)
                continue
            # A token is allowed iff it is a prefix of ``remaining``; try every
            # possible prefix length (0 included, like str.startswith("")).
            for length in range(min(len(remaining), self._max_token_len) + 1):
                allowed.update(self._token_ids.get(remaining[:length], ()))

        scores = torch.full(
            (len(self.token_list),), float("-inf"), device=self.device
        )
        if allowed:
            scores[list(allowed)] = 0.0
        else:
            scores[self.eos_id] = -10000
        return scores, state
|
|
|
|
|
class MultiReadingASR:
    """ASR system that selects a reading for Japanese text among MeCab candidates.

    The ESPnet model is constrained (via CandidateScorer) to emit only hiragana
    readings derived from MeCab's n-best analyses of the text. Each surviving
    hypothesis is then re-ranked by a linear combination of its ASR score and
    its MeCab path cost.
    """

    # Divisor applied to the MeCab path cost before ensembling.
    MECAB_SCALE = 800
    # Linear ensemble weights. A higher ensembled score is better, so the MeCab
    # cost (lower is better; the cheapest analysis per reading is used) enters
    # with a negative weight.
    ASR_WEIGHT = 0.2528
    MECAB_WEIGHT = -0.4701
    # Single-character output tokens outside this set get their decoder and
    # CTC biases lowered by 100 so the model practically never emits them.
    ALLOWED_TOKENS = frozenset(
        "ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすずせぜそぞただちぢっつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろゎわゐゑをんゝゞー"
    )

    def __init__(self, device: str = DEVICE):
        """Initialize the MeCab candidate generator and the ASR model on ``device``."""
        print("Initializing models and dictionaries...")
        self.mecab_generator = MecabCandidateGenerator()
        self.asr = self._setup_asr_model(device)

    def _setup_asr_model(self, device: str = "cpu") -> Speech2Text:
        """Load the ReazonSpeech ESPnet model and bias it toward hiragana output."""
        asr = Speech2Text.from_pretrained(
            "reazon-research/reazonspeech-espnet-v2",
            lm_weight=0,
            device=device,
            nbest=10,
            normalize_length=True,
        )
        token_list = asr.asr_model.token_list
        decoder_bias = asr.asr_model.decoder.output_layer.bias
        ctc_bias = asr.asr_model.ctc.ctc_lo.bias
        assert len(decoder_bias) == len(token_list)
        assert len(ctc_bias) == len(token_list)
        for i, token in enumerate(token_list):
            if len(token) == 1 and token not in self.ALLOWED_TOKENS:
                decoder_bias.data[i] -= 100.0
                ctc_bias.data[i] -= 100.0
        return asr

    @staticmethod
    def _display_path(wav_file: Path) -> str:
        """Return ``wav_file`` relative to AUDIO_FILES_DIR, or just its file name."""
        try:
            return str(wav_file.relative_to(AUDIO_FILES_DIR))
        except ValueError:
            # Not located under AUDIO_FILES_DIR (e.g. an absolute upload path).
            return wav_file.name

    def _build_yomi_candidates(self, text: str):
        """Analyze ``text`` with MeCab and derive its distinct readings.

        Returns:
            (mecab_candidates, sorted distinct readings, reading -> indices into
            mecab_candidates, deduplicated warnings sorted by descending
            severity digit, then text).
        """
        mecab_candidates = self.mecab_generator.generate_candidates(text)
        yomi_to_indices: Dict[str, List[int]] = {}
        collected: List[str] = []
        for index, candidate in enumerate(mecab_candidates):
            yomi, morph_warnings = self.mecab_generator.candidate_to_yomi(candidate)
            collected.extend(morph_warnings)
            yomi_to_indices.setdefault(yomi, []).append(index)
        yomi_candidates = sorted(yomi_to_indices)
        # Messages start with "[<level>]"; character 1 is the level digit.
        warning_messages = sorted(set(collected), key=lambda m: (-ord(m[1]), m))
        return mecab_candidates, yomi_candidates, yomi_to_indices, warning_messages

    def _install_candidate_scorer(self, yomi_candidates: List[str]) -> None:
        """Register a CandidateScorer restricting beam search to ``yomi_candidates``."""
        scorer = CandidateScorer(self.asr.asr_model.token_list, device=self.asr.device)
        scorer.set_candidates(yomi_candidates)
        # A negative maxlenratio is used by ESPnet's beam search as an absolute
        # maximum output length: the longest candidate plus 2 extra steps.
        self.asr.maxlenratio = -max(map(len, yomi_candidates), default=0) - 2
        self.asr.beam_search.scorers["cand"] = scorer
        self.asr.beam_search.full_scorers["cand"] = scorer
        self.asr.beam_search.weights["cand"] = 1.0

    def _transcribe(self, wav, mecab_candidates, yomi_to_indices, verbose: bool) -> List[Dict]:
        """Run the constrained ASR and attach ASR and MeCab scores to each hypothesis."""
        transcriptions: List[Dict] = []
        for trans, _, _, info in self.asr(wav):
            if info.score == -math.inf:
                continue
            indices = yomi_to_indices.get(trans)
            if indices is None:
                # Not a known reading; the infinite cost ranks it last.
                mecab_score = math.inf
            else:
                mecab_score = min(
                    self.mecab_generator.candidate_to_score(mecab_candidates[i])[0]
                    for i in indices
                )
            mecab_score /= self.MECAB_SCALE
            # Per-step score; yseq includes the leading <sos/eos>.
            normalized_score = info.score.item() / (len(info.yseq) - 1)
            transcriptions.append(
                {
                    "text": trans,
                    "normalized_score": normalized_score,
                    "mecab_score": mecab_score,
                }
            )
            if verbose:
                print(
                    f"Transcription: {trans} {normalized_score:.2f} {info.score.item():.2f} {mecab_score:.2f}"
                )
        return transcriptions

    @classmethod
    def _ensemble(cls, transcriptions: List[Dict]):
        """Add ``ensembled_score`` to each transcription; return (best, confidence).

        Confidence is the softmax probability of the best ensembled score. With
        no transcriptions, returns (None, 0.0). If the best score is infinite,
        so that the softmax is undefined, confidence is 0.0.
        """
        if not transcriptions:
            return None, 0.0
        mean_length = sum(len(t["text"]) for t in transcriptions) / len(transcriptions)
        for t in transcriptions:
            adjusted_score = t["normalized_score"] * mean_length
            t["ensembled_score"] = (
                adjusted_score * cls.ASR_WEIGHT + t["mecab_score"] * cls.MECAB_WEIGHT
            )
        best = max(transcriptions, key=lambda t: t["ensembled_score"])
        best_score = best["ensembled_score"]
        if math.isinf(best_score):
            return best, 0.0
        confidence = 1.0 / sum(
            math.exp(t["ensembled_score"] - best_score) for t in transcriptions
        )
        return best, confidence

    def process_audio_with_candidates(
        self,
        wav_file: Path,
        text: str,
        verbose: bool = False,
    ) -> Dict:
        """Estimate the reading of ``text`` spoken in ``wav_file``.

        Args:
            wav_file: audio file (``Path`` or ``str``), resampled to 16 kHz mono.
            text: Japanese text whose reading is estimated.
            verbose: print intermediate results.

        Returns:
            Dict with keys "file", "text", "readings" (always empty),
            "yomi_candidates", "warnings", "transcriptions", "best_transcription"
            (None when no hypothesis survived decoding) and "confidence".
        """
        wav_file = Path(wav_file)
        wav, _ = librosa.load(wav_file, sr=16000, mono=True, dtype=np.float32)

        results = {"file": self._display_path(wav_file), "text": text, "readings": {}}
        if verbose:
            print(f"File: {wav_file.name}")
            print(f"Text: {text}")

        (
            mecab_candidates,
            yomi_candidates,
            yomi_to_indices,
            warning_messages,
        ) = self._build_yomi_candidates(text)
        results["yomi_candidates"] = yomi_candidates
        results["warnings"] = warning_messages
        if verbose:
            print(f"Warning messages: {warning_messages}")

        self._install_candidate_scorer(yomi_candidates)
        transcriptions = self._transcribe(wav, mecab_candidates, yomi_to_indices, verbose)
        results["transcriptions"] = transcriptions

        best, confidence = self._ensemble(transcriptions)
        results["best_transcription"] = best
        results["confidence"] = confidence
        return results
|
|
|
|
|
|
|
# Build the shared ASR system once, at import time, so that every Gradio
# request handled by process_audio_gradio reuses the loaded model and
# dictionary. The messages mean "Initializing models..." / "Initialization
# complete."
print("モデルを初期化しています...")
asr_system = MultiReadingASR()
print("初期化が完了しました。")
|
|
|
|
|
def process_audio_gradio(text, audio_file):
    """Gradio callback: estimate the hiragana reading of ``text`` from ``audio_file``.

    Args:
        text: Japanese text entered by the user.
        audio_file: path of the uploaded audio file, or None.

    Returns:
        (reading, confidence as a percentage string, warning text) for the
        three output textboxes. User-facing messages are in Japanese.
    """
    # Whitespace-only text is treated as missing input.
    if not text or not text.strip() or not audio_file:
        return "テキストと音声ファイルの両方を入力してください。", "", ""

    results = asr_system.process_audio_with_candidates(
        wav_file=Path(audio_file), text=text, verbose=False
    )

    warning_list = results.get("warnings", [])
    warning_text = "\n".join(warning_list) if warning_list else "警告なし"

    best = results.get("best_transcription")
    if best is None:
        # No hypothesis survived decoding ("could not estimate the reading").
        return "読みを推定できませんでした。", "", warning_text

    confidence = f"{results['confidence'] * 100:.1f}%"
    return best["text"], confidence, warning_text
|
|
|
|
|
|
|
# Web UI. The window title means "Reading estimation for kanji-kana mixed
# text". The first column holds the inputs and the second the outputs, side by
# side within one row.
with gr.Blocks(title="漢字仮名交じりテキストの読み推定") as demo:
    # Header (Japanese): "Tool that estimates furigana from audio and
    # kanji-kana mixed text" / "Estimates furigana by combining a speech
    # recognition model with MeCab."
    gr.Markdown(
        """
        # 音声と漢字仮名交じりテキストからふりがなを推定するツール

        音声認識モデルと MeCab の合わせ技でふりがなを推定します。
        """
    )

    with gr.Row():
        # Inputs: text, audio (handed to the callback as a file path), submit button.
        with gr.Column():
            text_input = gr.Textbox(
                label="テキスト",
                placeholder="読みを推定したいテキストを入力してください",
                lines=2,
            )
            audio_input = gr.Audio(label="音声ファイル", type="filepath")
            submit_btn = gr.Button("読みを推定", variant="primary")

        # Outputs: estimated hiragana reading, confidence, warning messages.
        with gr.Column():
            reading_output = gr.Textbox(label="推定された読み(ひらがな)", lines=2)
            confidence_output = gr.Textbox(label="信頼度")
            warnings_output = gr.Textbox(label="警告メッセージ", lines=3)

    # The three values returned by process_audio_gradio fill the three output
    # boxes in this order.
    submit_btn.click(
        fn=process_audio_gradio,
        inputs=[text_input, audio_input],
        outputs=[reading_output, confidence_output, warnings_output],
    )
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio web server only when this file is run as a script.
    demo.launch()
|
|