Project Beatrice
Initial commit
d65cea0
#!/usr/bin/env python3
"""
ASR with multiple reading candidates using MeCab and Sudachi.
This module provides functionality to generate and score multiple reading candidates
for Japanese text and use them in speech recognition.
"""
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import math
import tempfile
from pathlib import Path
from typing import List, Tuple, Dict
import fugashi
import gradio as gr
import librosa
import numpy as np
import torch
from espnet.nets.scorer_interface import BatchScorerInterface
from espnet2.bin.asr_inference import Speech2Text
# MeCab dictionary settings: the mecab-ipadic-neologd dictionary directory is
# expected to sit next to this file.
MECAB_DIC_DIR = str(Path(__file__).parent / "mecab-ipadic-neologd")
# Base directory against which processed audio paths are reported (falls back
# to the bare file name when a file is not under it).
AUDIO_FILES_DIR = Path(".")
# Run the ASR model on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
class MecabCandidateGenerator:
    """Generate reading candidates using MeCab with the mecab-ipadic-neologd dictionary.

    Two taggers are built over the same dictionary:

    * ``tagger`` prints ``-Odump`` output and is used to enumerate n-best
      morphological analyses (see :meth:`generate_candidates`).
    * ``tagger_p`` runs in partial-parsing mode (``-p``) with node format
      ``0`` and EOS format ``%pc``, so replaying a fixed analysis through it
      prints only that analysis' path cost (see :meth:`candidate_to_score`).
    """

    KATAKANA_LIST = set(
        "ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヽヾー・"
    )

    # Punctuation whose presence only warrants a low-severity ([1]) warning.
    # Both full-width (escaped) and half-width forms of ! ? [ ] are listed.
    _PUNCTUATION = "、。\uff01\uff1f!?「」『』【】〔〕\uff3b\uff3d[]〈〉《》・"

    # Full-width digits (U+FF10..U+FF19) and full-width ! ? . , mapped to
    # ASCII. Escapes are deliberate: writing the characters literally lets
    # NFKC-normalising tools collapse the table into an identity mapping.
    _FULLWIDTH_TO_HALFWIDTH = str.maketrans(
        {
            **{chr(0xFF10 + d): str(d) for d in range(10)},
            "\uff01": "!",
            "\uff1f": "?",
            "\uff0e": ".",
            "\uff0c": ",",
        }
    )

    def __init__(self):
        # Point MeCab at an explicit rc file (-r) naming the bundled
        # dictionary. The temporary file only needs to exist while the two
        # taggers are constructed inside this ``with`` block.
        with tempfile.NamedTemporaryFile(mode="w+", delete=True) as tmp:
            tmp.write(f"""dicdir = {MECAB_DIC_DIR}\n""")
            tmp.flush()
            self.tagger = fugashi.GenericTagger(
                f"-r {tmp.name} -d {MECAB_DIC_DIR} -Odump"
            )
            self.tagger_p = fugashi.GenericTagger(
                f"-r {tmp.name} -d {MECAB_DIC_DIR} -p -F 0 -E %pc"
            )

    def _normalize_text(self, text: str) -> str:
        """Return *text* with full-width digits and full-width ! ? . , made ASCII."""
        return text.translate(self._FULLWIDTH_TO_HALFWIDTH)

    def generate_candidates(self, text: str, nbest: int = 512) -> List[List[List[str]]]:
        """Enumerate up to *nbest* morphological analyses of *text*.

        The text is normalised with :meth:`_normalize_text` first.

        Returns:
            One list per analysis. Each morph is the whitespace-split
            ``-Odump`` line for that node: ``fields[1]`` is the surface form
            and ``fields[2]`` the comma-separated feature string. BOS nodes
            are skipped; an EOS node closes the current analysis.
        """
        text = self._normalize_text(text)
        res = self.tagger.nbest(text, nbest)
        candidates = []
        candidate = []
        for line in res.split("\n"):
            fields = line.split()
            if len(fields) < 2:
                # Blank or malformed line: there is no surface field to read.
                continue
            word = fields[1]
            if word == "BOS":
                continue
            if word == "EOS":
                candidates.append(candidate)
                candidate = []
            else:
                candidate.append(fields)
        return candidates

    @staticmethod
    def _parse_cost(output: str) -> int:
        """Extract the path cost from ``tagger_p`` output.

        With ``-F 0 -E %pc`` every node prints a literal ``0`` and the EOS
        line prints the cost, so the output is ``"00...0<cost>"``. Leading
        zeros are stripped; if nothing remains the cost itself was 0.
        """
        digits = output.strip().lstrip("0")
        return int(digits) if digits else 0

    def candidate_to_score(self, candidate: List[List[str]]) -> Tuple[float, List[str]]:
        """Return the MeCab path cost of *candidate* (lower is better).

        Each morph is replayed through the partial-parsing tagger as
        ``surface<TAB>features``, constraining MeCab to exactly this analysis.
        The second element of the returned tuple is always an empty list.

        Raises:
            AssertionError: if a morph has fewer than 3 fields.
        """
        query = []
        for morph in candidate:
            assert len(morph) >= 3, (
                f"Expected morph to have at least 3 fields, got {len(morph)}: {morph}"
            )
            _, original_form, features, *_ = morph
            query.append(f"{original_form}\t{features}\n")
        output = self.tagger_p.parse("".join(query) + "EOS")
        return self._parse_cost(output), []

    @classmethod
    def _reading_to_hiragana(
        cls, reading: str, context: List[str]
    ) -> Tuple[str, List[str]]:
        """Convert a katakana/hiragana reading to hiragana.

        Args:
            reading: the reading string taken from the feature field (or the
                kana surface form).
            context: the ``[surface, features]`` slice quoted in warnings.

        Returns:
            ``(hiragana, warnings)``. Characters that cannot be converted are
            dropped and reported as ``[1]`` (punctuation) or ``[3]`` warnings.
        """
        out: List[str] = []
        issues: List[str] = []
        i = 0
        while i < len(reading):
            char = reading[i]
            if "ァ" <= char <= "ン" or char in "ヽヾ":
                # The katakana block sits 0x60 (96) code points above hiragana.
                out.append(chr(ord(char) - 96))
            elif char == "ー" or "ぁ" <= char <= "ん":
                out.append(char)
            elif char == "ヴ":
                if i + 1 < len(reading) and reading[i + 1] in "ァィェォ":
                    # ヴァ/ヴィ/ヴェ/ヴォ -> ば/び/べ/ぼ; consume the small vowel.
                    out.append("ばびべぼ"["ァィェォ".index(reading[i + 1])])
                    i += 1
                else:
                    out.append("ぶ")
            else:
                warning_level = 1 if char in cls._PUNCTUATION else 3
                issues.append(
                    f"[{warning_level}] Unhandled character in reading: {context}"
                )
            i += 1
        return "".join(out), issues

    def candidate_to_yomi(
        self, candidate: List[List[str]], yomi_index: int = 7
    ) -> Tuple[str, List[str]]:
        """Convert a morphological-analysis candidate to a hiragana reading.

        Args:
            candidate: morphs as produced by :meth:`generate_candidates`.
            yomi_index: index of the reading inside the comma-separated
                feature string.

        Returns:
            ``(reading, warnings)``: the concatenated hiragana reading and
            messages tagged ``[1]`` (punctuation only) or ``[3]`` (suspicious
            or malformed morph). Morphs without a usable reading contribute
            nothing to the reading.
        """
        parts: List[str] = []
        warning_messages: List[str] = []
        for morph in candidate:
            if len(morph) < 2:
                warning_messages.append(
                    f"[3] Morph has less than 2 fields: {morph[1:3]}"
                )
                continue
            original_form = morph[1]
            warning_message = ""
            if len(morph) < 3:
                warning_message = f"[3] Morph has less than 3 fields: {morph[1:3]}"
                # Placeholder feature field so the lookup below still works.
                morph = morph + ["*"]
            features = morph[2].split(",")
            if len(features) > yomi_index:
                reading = features[yomi_index]
            elif all(
                "ぁ" <= c <= "ん" or c in self.KATAKANA_LIST for c in original_form
            ):
                # No reading field, but a kana-only surface is its own reading.
                reading = original_form
            else:
                if not warning_message:
                    warning_level = (
                        1
                        if all(c in self._PUNCTUATION for c in original_form)
                        else 3
                    )
                    warning_message = (
                        f"[{warning_level}] Morph has no reading: {morph[1:3]}"
                    )
                warning_messages.append(warning_message)
                continue
            if warning_message:
                # Report a malformed morph even when a reading was recovered.
                warning_messages.append(warning_message)
            hiragana, issues = self._reading_to_hiragana(reading, morph[1:3])
            warning_messages.extend(issues)
            parts.append(hiragana)
        return "".join(parts), warning_messages
class CandidateScorer(BatchScorerInterface):
    """Score ASR hypotheses so that only allowed reading candidates survive.

    A hypothesis prefix may be extended only by tokens that keep its text a
    prefix of at least one candidate string. ``<sos/eos>`` is allowed once the
    prefix equals a complete candidate. Allowed tokens score 0 and every other
    token scores ``-inf``. If nothing is allowed, only EOS stays open, with a
    large finite penalty (-10000).
    """

    def __init__(self, token_list: List[str], device: str = "cpu"):
        super().__init__()
        self.token_list = token_list
        self.eos_id = token_list.index("<sos/eos>")
        self.device = device
        # Defined up front so ``score`` works (everything but a penalised EOS
        # is blocked) even if set_candidates() has not been called yet.
        self.candidates: List[str] = []
        # Token text -> every id carrying that text, plus the longest token
        # length. Allowed continuations can then be found by slicing the
        # remaining candidate text instead of scanning the whole vocabulary.
        self._ids_by_token: Dict[str, List[int]] = {}
        for i, token in enumerate(token_list):
            self._ids_by_token.setdefault(token, []).append(i)
        self._max_token_len = max(map(len, token_list))

    def set_candidates(self, candidates: List[str]):
        """Set the allowed reading candidates (complete hiragana strings)."""
        self.candidates = candidates

    def score(
        self, y: torch.Tensor, state, x=None
    ):  # x is unused but required by interface
        """
        Score function for beam search.
        Args:
            y: prefix token sequence; its first id must be ``<sos/eos>``
            state: scorer state (unused, returned unchanged)
            x: encoder feature (unused)
        Returns:
            scores: token scores of shape (len(token_list),)
            state: updated state
        """
        prefix_ids = y.tolist()
        assert prefix_ids[0] == self.eos_id, prefix_ids
        prefix = "".join(self.token_list[i] for i in prefix_ids[1:])

        allowed = set()
        for candidate in self.candidates:
            if not candidate.startswith(prefix):
                continue
            remaining = candidate[len(prefix) :]
            if not remaining:
                # The prefix is a complete candidate, so the hypothesis may end.
                allowed.add(self.eos_id)
                continue
            # A token t is allowed iff remaining.startswith(t). Try every
            # length a token could have; length 0 covers an empty-string token.
            for length in range(min(len(remaining), self._max_token_len) + 1):
                allowed.update(self._ids_by_token.get(remaining[:length], ()))

        vocab = len(self.token_list)
        scores = torch.full((vocab,), float("-inf"), device=self.device)
        if allowed:
            scores[sorted(allowed)] = 0.0
        else:
            scores[self.eos_id] = -10000
        return scores, state
class MultiReadingASR:
    """ASR system that picks the reading of a text from MeCab candidates.

    MeCab enumerates hiragana reading candidates for the input text. The
    ReazonSpeech ESPnet model then decodes the audio, constrained to those
    candidates. Finally the ASR score and the MeCab cost are linearly
    ensembled to choose one reading.
    """

    # Divisor applied to MeCab path costs before ensembling.
    MECAB_SCALE = 800
    # Linear ensemble weights for the length-adjusted ASR score and the scaled
    # MeCab cost. MeCab cost is "lower is better", hence the negative weight.
    ASR_WEIGHT = 0.2528
    MECAB_WEIGHT = -0.4701

    def __init__(self, device: str = DEVICE):
        """Initialize ASR components and dictionaries."""
        print("Initializing models and dictionaries...")
        self.mecab_generator = MecabCandidateGenerator()
        self.asr = self._setup_asr_model(device)

    def _setup_asr_model(self, device: str = "cpu") -> Speech2Text:
        """Load ReazonSpeech ESPnet v2 and suppress single-char non-hiragana tokens.

        For every one-character token outside the hiragana set, 100 is
        subtracted from both the attention decoder's output-layer bias and
        the CTC projection bias. Multi-character tokens are left untouched.
        """
        asr = Speech2Text.from_pretrained(
            "reazon-research/reazonspeech-espnet-v2",
            lm_weight=0,
            device=device,
            nbest=10,
            normalize_length=True,
        )
        allowed_tokens = set(
            "ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすずせぜそぞただちぢっつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろゎわゐゑをんゝゞー"
        )
        token_list = asr.asr_model.token_list
        decoder_bias = asr.asr_model.decoder.output_layer.bias
        ctc_bias = asr.asr_model.ctc.ctc_lo.bias
        assert len(decoder_bias) == len(token_list)
        assert len(ctc_bias) == len(token_list)
        for i, token in enumerate(token_list):
            if len(token) == 1 and token not in allowed_tokens:
                decoder_bias.data[i] -= 100.0
                ctc_bias.data[i] -= 100.0
        return asr

    def _collect_yomi_candidates(
        self, text: str
    ) -> Tuple[List, List[str], Dict[str, List[int]], List[str]]:
        """Run MeCab n-best on *text* and group the analyses by reading.

        Returns:
            ``(mecab_candidates, yomi_candidates, yomi_to_indices, warnings)``:
            the raw analyses; the sorted unique readings; a map from reading
            to the indices of analyses that produce it; and the deduplicated
            warnings, sorted by descending severity digit, then by text.
        """
        mecab_candidates = self.mecab_generator.generate_candidates(text)
        yomi_to_indices: Dict[str, List[int]] = {}
        collected: List[str] = []
        for i, candidate in enumerate(mecab_candidates):
            yomi, candidate_warnings = self.mecab_generator.candidate_to_yomi(candidate)
            collected.extend(candidate_warnings)
            yomi_to_indices.setdefault(yomi, []).append(i)
        yomi_candidates = sorted(yomi_to_indices)
        # m[1] is the severity digit inside each message's leading "[N]" tag.
        warning_messages = sorted(set(collected), key=lambda m: (-ord(m[1]), m))
        return mecab_candidates, yomi_candidates, yomi_to_indices, warning_messages

    def _decode_with_candidates(
        self,
        wav: np.ndarray,
        yomi_candidates: List[str],
        mecab_candidates: List,
        yomi_to_indices: Dict[str, List[int]],
        verbose: bool,
    ) -> List[Dict]:
        """Decode *wav* constrained to *yomi_candidates* and score each hypothesis.

        Hypotheses with a ``-inf`` ASR score are discarded. A hypothesis
        whose text is not a known reading gets an infinite MeCab cost.
        """
        scorer = CandidateScorer(self.asr.asr_model.token_list, device=self.asr.device)
        scorer.set_candidates(yomi_candidates)
        # A negative maxlenratio is taken by ESPnet beam search as an absolute
        # maximum output length: the longest candidate plus a margin of 2.
        self.asr.maxlenratio = -max(map(len, yomi_candidates)) - 2
        self.asr.beam_search.scorers["cand"] = scorer
        self.asr.beam_search.full_scorers["cand"] = scorer
        self.asr.beam_search.weights["cand"] = 1.0

        transcriptions: List[Dict] = []
        for trans, _, _, info in self.asr(wav):
            if info.score == -math.inf:
                continue
            candidate_indices = yomi_to_indices.get(trans)
            if candidate_indices is None:
                mecab_score = math.inf
            else:
                mecab_score = min(
                    self.mecab_generator.candidate_to_score(mecab_candidates[i])[0]
                    for i in candidate_indices
                )
            mecab_score /= self.MECAB_SCALE
            # Average over the sequence, excluding the leading <sos/eos> token.
            normalized_score = info.score.item() / (len(info.yseq) - 1)
            transcriptions.append(
                {
                    "text": trans,
                    "normalized_score": normalized_score,
                    "mecab_score": mecab_score,
                }
            )
            if verbose:
                print(
                    f"Transcription: {trans} {normalized_score:.2f} {info.score.item():.2f} {mecab_score:.2f}"
                )
        return transcriptions

    def _ensemble(self, transcriptions: List[Dict]) -> Tuple[Dict, float]:
        """Add ``ensembled_score`` to every transcription; return (best, confidence).

        Per-token ASR scores are rescaled by the mean hypothesis length (in
        characters) and then combined linearly with the scaled MeCab cost.
        Confidence is the softmax probability of the best hypothesis.
        """
        mean_length = sum(len(t["text"]) for t in transcriptions) / len(transcriptions)
        for t in transcriptions:
            adjusted_score = t["normalized_score"] * mean_length
            t["ensembled_score"] = (
                adjusted_score * self.ASR_WEIGHT + t["mecab_score"] * self.MECAB_WEIGHT
            )
        best = max(transcriptions, key=lambda t: t["ensembled_score"])
        best_score = best["ensembled_score"]
        # Hypotheses tied with the best contribute exactly 1. This also keeps
        # an all -inf list (no hypothesis matched a MeCab reading) from
        # evaluating exp(-inf - -inf) = exp(nan).
        denominator = sum(
            1.0
            if t["ensembled_score"] == best_score
            else math.exp(t["ensembled_score"] - best_score)
            for t in transcriptions
        )
        return best, 1.0 / denominator

    def process_audio_with_candidates(
        self,
        wav_file: Path,
        text: str,
        verbose: bool = False,
    ) -> Dict:
        """Estimate the reading of *text* as spoken in *wav_file*.

        Returns:
            A dict with ``file``, ``text``, ``readings`` (always empty),
            ``yomi_candidates``, ``warnings``, ``transcriptions`` (each with
            ``text``, ``normalized_score``, ``mecab_score`` and
            ``ensembled_score``), ``best_transcription`` and ``confidence``.

        Raises:
            ValueError: if MeCab yields no reading candidate, or if no
                hypothesis with a finite ASR score survives decoding.
        """
        wav, _ = librosa.load(wav_file, sr=16000, mono=True, dtype=np.float32)
        try:
            file_path = str(wav_file.relative_to(AUDIO_FILES_DIR))
        except ValueError:
            # Not under AUDIO_FILES_DIR: fall back to the bare file name.
            file_path = wav_file.name
        results = {"file": file_path, "text": text, "readings": {}}
        if verbose:
            print(f"File: {wav_file.name}")
            print(f"Text: {text}")

        (
            mecab_candidates,
            yomi_candidates,
            yomi_to_indices,
            warning_messages,
        ) = self._collect_yomi_candidates(text)
        results["yomi_candidates"] = yomi_candidates
        results["warnings"] = warning_messages
        if verbose:
            print(f"Warning messages: {warning_messages}")
        if not yomi_candidates:
            raise ValueError(f"MeCab produced no reading candidates for {text!r}")

        transcriptions = self._decode_with_candidates(
            wav, yomi_candidates, mecab_candidates, yomi_to_indices, verbose
        )
        results["transcriptions"] = transcriptions
        if not transcriptions:
            raise ValueError(
                "ASR produced no finite-score hypothesis for the reading candidates"
            )

        best_transcription, confidence = self._ensemble(transcriptions)
        results["best_transcription"] = best_transcription
        results["confidence"] = confidence
        return results
# Initialize the ASR system once at module level, at import time, so the
# first user request does not pay the model/dictionary loading cost.
print("モデルを初期化しています...")
asr_system = MultiReadingASR()
print("初期化が完了しました。")
def process_audio_gradio(text, audio_file):
    """Gradio callback: estimate the reading of *text* from *audio_file*.

    Returns a ``(reading, confidence, warnings)`` triple of strings for the
    three output boxes. If either input is missing, the first element is a
    prompt asking for both inputs and the other two are empty strings.
    """
    if text and audio_file:
        # Run the candidate-constrained ASR on the uploaded file path.
        outcome = asr_system.process_audio_with_candidates(
            wav_file=Path(audio_file), text=text, verbose=False
        )
        reading = outcome["best_transcription"]["text"]
        # Confidence as a percentage with one decimal place.
        confidence_pct = f"{outcome['confidence'] * 100:.1f}%"
        messages = outcome.get("warnings", [])
        warning_text = "\n".join(messages) if messages else "警告なし"
        return reading, confidence_pct, warning_text
    return "テキストと音声ファイルの両方を入力してください。", "", ""
# Build the Gradio interface: text and audio inputs plus a submit button in
# the left column; the estimated reading, confidence and warnings on the right.
with gr.Blocks(title="漢字仮名交じりテキストの読み推定") as demo:
    gr.Markdown(
        """
# 音声と漢字仮名交じりテキストからふりがなを推定するツール
音声認識モデルと MeCab の合わせ技でふりがなを推定します。
"""
    )
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="テキスト",
                placeholder="読みを推定したいテキストを入力してください",
                lines=2,
            )
            # type="filepath": the callback receives the upload as a path
            # (process_audio_gradio wraps it in Path()).
            audio_input = gr.Audio(label="音声ファイル", type="filepath")
            submit_btn = gr.Button("読みを推定", variant="primary")
        with gr.Column():
            reading_output = gr.Textbox(label="推定された読み(ひらがな)", lines=2)
            confidence_output = gr.Textbox(label="信頼度")
            warnings_output = gr.Textbox(label="警告メッセージ", lines=3)
    # On button click, run process_audio_gradio. Its three return values fill
    # the three output boxes in order.
    submit_btn.click(
        fn=process_audio_gradio,
        inputs=[text_input, audio_input],
        outputs=[reading_output, confidence_output, warnings_output],
    )
# Launch the app when run as a script (the entry point used on Hugging Face Spaces).
if __name__ == "__main__":
    demo.launch()