import json
import os
import tempfile
import platform
import re
from pathlib import Path
from typing import Any, Union, Optional, List, Dict
import shutil
from datetime import datetime
import hashlib  # added for computing merge hashes

import gradio as gr
import numpy as np
import torch
from safetensors import safe_open
from safetensors.torch import save_file

from config import get_path_config
from style_bert_vits2.constants import DEFAULT_STYLE, GRADIO_THEME
from style_bert_vits2.logging import logger
from style_bert_vits2.tts_model import TTSModel, TTSModelHolder

# Prefixes of the weight keys targeted by the merge
voice_keys = ["dec"]
voice_pitch_keys = ["flow"]
speech_style_keys = ["enc_p"]
tempo_keys = ["sdp", "dp"]

# Constants
MAX_MODELS = 10
MAX_STYLES = 20
MAX_MODELS_TO_KEEP = 20  # Maximum number of merged models to keep in temporary storage (e.g. the 20 most recent)

device = "cuda" if torch.cuda.is_available() else "cpu"

path_config = get_path_config()
assets_root = path_config.assets_root


# Decide the default temporary save directory depending on the OS
def get_default_temp_dir():
    if platform.system() == "Windows":
        return "R:\\Temp"
    elif platform.system() == "Linux":
        # Use a generic path so environments such as Hugging Face Spaces also work
        return "/tmp/sbv2_merger_cache"
    else:
        return tempfile.gettempdir()


# Default temporary save directory
DEFAULT_TEMP_SAVE_DIR = get_default_temp_dir()


def sanitize_filename(name: str, replacement: str = "_") -> str:
    """
    Sanitize a string so that it is safe to use as a file or directory name.
    Allowed characters are alphanumerics, underscores, hyphens and spaces;
    everything else is converted to the replacement character.
    """
    # Regex pattern matching the characters that are not allowed
    pattern = re.compile(r'[^a-zA-Z0-9_\- ]+')
    sanitized = pattern.sub(replacement, name)
    # Collapse consecutive replacement characters into a single one
    while replacement * 2 in sanitized:
        sanitized = sanitized.replace(replacement * 2, replacement)
    # Strip leading and trailing replacement characters
    sanitized = sanitized.strip(replacement)
    # Fall back to a default name if the result is empty
    return sanitized if sanitized else "sanitized_name"
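# Example behaviour of sanitize_filename:
#   sanitize_filename("model@v1.0") -> "model_v1_0"
#   sanitize_filename("!!!")        -> "sanitized_name"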


def _manage_temp_space(base_dir: Path, required_mb: int, max_models_to_keep: int) -> None:
    """
    Manage the capacity of the temporary directory, deleting old directories when necessary.

    Args:
        base_dir (Path): Root directory of the temporary save location.
        required_mb (int): Minimum free space (in MB) required to save the new model.
        max_models_to_keep (int): Maximum number of merged models to keep temporarily.

    Raises:
        OSError: If there is not enough free space, or if a disk I/O error occurs.
    """
    if not base_dir.exists() or not base_dir.is_dir():
        logger.warning(f"Temporary base directory does not exist or is not a directory: {base_dir}")
        # If the directory does not exist, it will be created later inside merge_models_weighted_sum
        return

    # List the directories containing "recipe.json" as the managed merge outputs
    temp_dirs = []
    for temp_dir in base_dir.iterdir():
        # A directory is treated as a merged-model directory if it contains recipe.json
        if temp_dir.is_dir() and (temp_dir / "recipe.json").exists():
            try:
                # Store (ctime, path) tuples so the list can be sorted by creation time
                temp_dirs.append((temp_dir.stat().st_ctime, temp_dir))
            except OSError as e:
                logger.warning(f"Could not get stats for {temp_dir}: {e}")
                continue

    # Sort oldest first
    temp_dirs.sort()
    deleted_count = 0

    # First, delete directories exceeding the maximum number to keep
    while len(temp_dirs) >= max_models_to_keep:
        # Take the oldest one and remove it from the list
        _, oldest_dir_path = temp_dirs.pop(0)
        try:
            shutil.rmtree(oldest_dir_path)
            logger.info(f"Removed oldest temporary directory to enforce max_models_to_keep: {oldest_dir_path}")
            deleted_count += 1
        except OSError as e:
            logger.warning(f"Failed to remove old temporary directory {oldest_dir_path}: {e}")

    # Next, check whether the required free space is available and keep deleting if it is not
    try:
        total_b, used_b, free_b = shutil.disk_usage(base_dir)
        free_mb = free_b / (1024 * 1024)
        # required_mb already includes the model size plus a margin, so no extra buffer is added here
        required_free_mb = required_mb
        while free_mb < required_free_mb and temp_dirs:
            # Take the oldest one and remove it from the list
            _, oldest_dir_path = temp_dirs.pop(0)
            try:
                shutil.rmtree(oldest_dir_path)
                logger.info(f"Removed oldest temporary directory to free up space: {oldest_dir_path}")
                deleted_count += 1
                # Recompute the free space
                total_b, used_b, free_b = shutil.disk_usage(base_dir)
                free_mb = free_b / (1024 * 1024)
            except OSError as e:
                logger.warning(f"Failed to remove old temporary directory {oldest_dir_path}: {e}")
        if free_mb < required_free_mb:
            logger.error(f"Not enough free space in {base_dir} even after cleanup. "
                         f"Required: {required_free_mb:.2f} MB, Available: {free_mb:.2f} MB.")
            raise OSError(f"一時保存先ディレクトリ '{base_dir}' の空き容量が不足しています。手動で不要なファイルを削除してください。")
    except OSError as e:
        logger.error(f"Error checking disk usage for {base_dir}: {e}")
        # Also report an error if the disk usage check itself fails
        raise OSError(f"ディスク容量の確認中にエラーが発生しました: {e}")

    logger.info(f"Temporary space management completed. {deleted_count} directories removed.")


def load_safetensors(model_path: Union[str, Path]) -> dict[str, torch.Tensor]:
    result: dict[str, torch.Tensor] = {}
    with safe_open(model_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            result[k] = f.get_tensor(k)
    return result


def load_config(model_name: str) -> dict[str, Any]:
    with open(assets_root / model_name / "config.json", encoding="utf-8") as f:
        config = json.load(f)
    return config


def save_config(config: dict[str, Any], model_name: str):
    output_dir = assets_root / model_name
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)


def save_recipe(recipe: dict[str, Any], model_name: str):
    # This function saves under assets_root
    output_dir = assets_root / model_name
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / "recipe.json", "w", encoding="utf-8") as f:
        json.dump(recipe, f, indent=2, ensure_ascii=False)


def load_style_vectors(model_name: str) -> np.ndarray:
    return np.load(assets_root / model_name / "style_vectors.npy")


def save_style_vectors(style_vectors: np.ndarray, model_name: str):
    output_dir = assets_root / model_name
    output_dir.mkdir(parents=True, exist_ok=True)
    np.save(output_dir / "style_vectors.npy", style_vectors)
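

# Note on merge_style_weighted_sum below: each entry of style_tuple_list has the form
# (input_style_of_model_1, ..., input_style_of_model_N, output_style_name). For example, a
# hypothetical ("Happy", "Calm", "Mixed") takes "Happy" from the first model and "Calm" from the
# second and stores their weighted sum under the new style name "Mixed"; coeffs_list holds one
# list of per-model weights for each output style.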
def merge_style_weighted_sum(
    model_names: list[str],
    style_tuple_list: list[tuple],
    coeffs_list: list[list[float]],
    base_config: dict[str, Any],
) -> tuple[np.ndarray, dict[str, Any]]:
    model_configs = [load_config(name) for name in model_names]
    style_vectors_list = [load_style_vectors(name) for name in model_names]
    style2id_list = [config["data"]["style2id"] for config in model_configs]
    new_style_vecs = []
    new_style2id = {}
    for i, style_tuple in enumerate(style_tuple_list):
        output_style_name = style_tuple[-1]
        input_style_names = style_tuple[:-1]
        coeffs = coeffs_list[i]
        new_style = 0
        for j, style_name in enumerate(input_style_names):
            if style_name not in style2id_list[j]:
                raise ValueError(f"スタイル '{style_name}' はモデル '{model_names[j]}' にありません。")
            style_id = style2id_list[j][style_name]
            style_vector = style_vectors_list[j][style_id]
            new_style += coeffs[j] * style_vector
        new_style_vecs.append(new_style)
        new_style2id[output_style_name] = len(new_style_vecs) - 1
    new_style_vecs_np = np.array(new_style_vecs)
    new_config = base_config.copy()
    new_config["data"]["num_styles"] = len(new_style2id)
    new_config["data"]["style2id"] = new_style2id
    return new_style_vecs_np, new_config


def get_merge_hash(
    model_paths: List[str],
    voice_coeffs: List[float],
    voice_pitch_coeffs: List[float],
    speech_style_coeffs: List[float],
    tempo_coeffs: List[float],
) -> str:
    """Generate a deterministic hash from the merge contents (model paths and ratios)."""
    # The order matters, so the lists are not sorted
    data_to_hash = {
        # Convert the paths to strings before hashing
        "model_paths": [str(p) for p in model_paths],
        "voice_coeffs": voice_coeffs,
        "voice_pitch_coeffs": voice_pitch_coeffs,
        "speech_style_coeffs": speech_style_coeffs,
        "tempo_coeffs": tempo_coeffs,
    }
    # Serialize to a JSON string with sorted keys for a consistent representation
    data_str = json.dumps(data_to_hash, sort_keys=True, ensure_ascii=False)
    # Compute the SHA-256 hash and return it as a hex string
    return hashlib.sha256(data_str.encode('utf-8')).hexdigest()
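# Because the hash is deterministic, re-running a merge with identical model paths and ratios
# yields the same merge_hash, which merge_models_weighted_sum uses below to detect and reuse an
# already-merged temporary model instead of merging again.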


def merge_models_weighted_sum(
    model_paths: list[str],
    voice_coeffs: list[float],
    voice_pitch_coeffs: list[float],
    speech_style_coeffs: list[float],
    tempo_coeffs: list[float],
) -> tuple[dict[str, Any], Path]:
    resolved_temp_base_dir: Path
    default_path_obj = Path(DEFAULT_TEMP_SAVE_DIR)
    try:
        default_path_obj.mkdir(parents=True, exist_ok=True)
        resolved_temp_base_dir = default_path_obj
        logger.info(f"Using default temporary directory: {resolved_temp_base_dir}")
    except OSError:
        resolved_temp_base_dir = Path(tempfile.gettempdir())
        resolved_temp_base_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Default temporary directory '{DEFAULT_TEMP_SAVE_DIR}' not available or could not be created. "
                    f"Using default system temporary directory: {resolved_temp_base_dir}")

    # Generate a unique hash from the merge contents
    merge_hash = get_merge_hash(
        model_paths, voice_coeffs, voice_pitch_coeffs, speech_style_coeffs, tempo_coeffs
    )

    # Generate the model name and directory name used internally
    name_parts = []
    for i, model_path in enumerate(model_paths):
        model_folder_name = Path(model_path).parent.name
        # Make sure the index is within range for voice_coeffs
        if i < len(voice_coeffs) and voice_coeffs[i] > 0:
            coeff_val = voice_coeffs[i]
            # Multiply the ratio by 100 and round to an integer string (e.g. 0.75 -> "75", 1.0 -> "100")
            coeff_str = f"{int(round(coeff_val * 100)):d}"
            # Join the model name and the ratio with an underscore and append 'p' after the ratio
            name_parts.append(f"{model_folder_name}_{coeff_str}p")
    # Join the generated parts and finally sanitize the result
    internal_output_name = "_".join(name_parts)
    internal_output_name = sanitize_filename(internal_output_name)
    if not internal_output_name:  # e.g. when every ratio was 0
        internal_output_name = "merged_model"
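    # Naming example: merging hypothetical folders "ModelA" at 0.7 and "ModelB" at 0.3
    # yields internal_output_name "ModelA_70p_ModelB_30p".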
    merged_output_dir = resolved_temp_base_dir / internal_output_name

    # --- Check for an existing merged model ---
    if merged_output_dir.exists() and (merged_output_dir / "recipe.json").exists():
        try:
            with open(merged_output_dir / "recipe.json", "r", encoding="utf-8") as f:
                existing_recipe = json.load(f)
            if existing_recipe.get("merge_hash") == merge_hash:
                logger.info(f"同一内容のマージ済みモデルが既に存在します。既存のモデルを再利用します: {merged_output_dir}")
                with open(merged_output_dir / "config.json", "r", encoding="utf-8") as f:
                    config = json.load(f)
                styles = np.load(merged_output_dir / "style_vectors.npy")
                merged_data_info = {
                    "name": config.get("model_name", internal_output_name),
                    "merged_dir_path": str(merged_output_dir),
                    "config": config,
                    "styles": styles,
                    "recipe": existing_recipe,
                }
                return merged_data_info, merged_output_dir
            else:
                logger.warning(f"同じ名前の一時ディレクトリが存在しますが、マージ内容が異なります。古いディレクトリを削除して再マージします: {merged_output_dir}")
                shutil.rmtree(merged_output_dir)
        except (OSError, json.JSONDecodeError, KeyError) as e:
            logger.warning(f"既存の一時モデルの読み込み/削除に失敗しました。再マージを実行します。エラー: {e}")
            if merged_output_dir.exists():
                shutil.rmtree(merged_output_dir, ignore_errors=True)

    # --- New merge from here on ---
    # Estimate the required free space (based on the size of the largest input model)
    max_input_model_size_bytes = 0
    for p in model_paths:
        try:
            size = os.path.getsize(p)
            if size > max_input_model_size_bytes:
                max_input_model_size_bytes = size
        except OSError:
            logger.warning(f"Could not get size for model file: {p}")
            continue
    # Convert to MB and reserve at least 250 MB (the +100 MB covers overhead and temporary files)
    required_space_mb = max(int(max_input_model_size_bytes / (1024 * 1024)) + 100, 250)

    try:
        _manage_temp_space(resolved_temp_base_dir, required_space_mb, MAX_MODELS_TO_KEEP)
    except OSError as e:
        raise gr.Error(str(e))  # Surface the message as a Gradio error

    # Create the directory that will hold the merge result
    merged_output_dir.mkdir(exist_ok=True, parents=True)

    model_weights_list = [load_safetensors(p) for p in model_paths]
    merged_model_weight = model_weights_list[0].copy()
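    # Each tensor is blended according to its key prefix: keys starting with "dec" use
    # voice_coeffs, "flow" uses voice_pitch_coeffs, "enc_p" uses speech_style_coeffs, and
    # "sdp"/"dp" use tempo_coeffs. Any other key is copied unchanged from the first model.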
    for key in merged_model_weight.keys():
        new_tensor = torch.zeros_like(merged_model_weight[key])
        if any(key.startswith(prefix) for prefix in voice_keys):
            coeffs = voice_coeffs
        elif any(key.startswith(prefix) for prefix in voice_pitch_keys):
            coeffs = voice_pitch_coeffs
        elif any(key.startswith(prefix) for prefix in speech_style_keys):
            coeffs = speech_style_coeffs
        elif any(key.startswith(prefix) for prefix in tempo_keys):
            coeffs = tempo_coeffs
        else:
            # Keys that are not merge targets keep the weights of the base model (the first model in the list)
            merged_model_weight[key] = model_weights_list[0][key]
            continue
        for i, model_weights in enumerate(model_weights_list):
            if key in model_weights:
                if i < len(coeffs) and coeffs[i] is not None:
                    current_coeff = coeffs[i]
                else:
                    logger.warning(f"Coefficient for model {i+1} is missing or None for key {key}. Using 0.0.")
                    current_coeff = 0.0
                new_tensor += current_coeff * model_weights[key]
        merged_model_weight[key] = new_tensor

    # Convert the Path objects to strings before storing them in the recipe
    recipe = {
        "method": "weighted_sum",
        "model_paths": [str(p) for p in model_paths],
        "voice_coeffs": voice_coeffs,
        "voice_pitch_coeffs": voice_pitch_coeffs,
        "speech_style_coeffs": speech_style_coeffs,
        "tempo_coeffs": tempo_coeffs,
        "merge_hash": merge_hash,
        "temporary_merged_dir": str(merged_output_dir),
        "internal_model_name": internal_output_name,
    }

    model_names_from_paths = [Path(p).parent.name for p in model_paths]
    style_vectors_list = [load_style_vectors(name) for name in model_names_from_paths]
    new_neutral_vector = np.zeros_like(style_vectors_list[0][0])
    for i, style_vectors in enumerate(style_vectors_list):
        new_neutral_vector += speech_style_coeffs[i] * style_vectors[0]
    new_style_vectors = np.array([new_neutral_vector])

    base_model_name = Path(model_paths[0]).parent.name
    new_config = load_config(base_model_name)
    new_config["model_name"] = internal_output_name
    new_config["data"]["num_styles"] = 1
    new_config["data"]["style2id"] = {DEFAULT_STYLE: 0}

    # Write the merged files to disk
    save_file(merged_model_weight, merged_output_dir / f"{internal_output_name}.safetensors")
    with open(merged_output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(new_config, f, indent=2, ensure_ascii=False)
    np.save(merged_output_dir / "style_vectors.npy", new_style_vectors)
    with open(merged_output_dir / "recipe.json", "w", encoding="utf-8") as f:
        json.dump(recipe, f, indent=2, ensure_ascii=False)

    merged_data_info = {
        "name": internal_output_name,
        "merged_dir_path": str(merged_output_dir),
        "config": new_config,
        "styles": new_style_vectors,
        "recipe": recipe,
    }
    return merged_data_info, merged_output_dir


def merge_models_gr(
    model_count: int, ui_mode: str, input_mode: str,
    merged_data_state: Optional[dict],
    model_holder: TTSModelHolder,
    *args
) -> tuple[str, gr.Dropdown, Optional[dict]]:
    args_list = list(args)
    model_names = args_list[:model_count]
    if not all(model_names):
        return "Error: 必要なモデルフォルダが選択されていません。", \
            gr.Dropdown(choices=[DEFAULT_STYLE], value=DEFAULT_STYLE), None
    model_paths = []
    for name in model_names:
        files = model_holder.model_files_dict.get(name, [])
        if not files:
            return f"Error: モデル '{name}' に safetensors ファイルが見つかりません。", \
                gr.Dropdown(choices=[DEFAULT_STYLE], value=DEFAULT_STYLE), None
        # The files were sorted when model_holder was built, so pick the first one
        model_paths.append(files[0])
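    # The remaining *args are the flattened Gradio inputs, laid out as:
    #   [0, MAX_MODELS)               model folder dropdowns
    #   [MAX_MODELS, 2*MAX_MODELS)    bulk ratio sliders
    #   [2*MAX_MODELS, 3*MAX_MODELS)  bulk ratio number inputs
    #   [3*MAX_MODELS, 7*MAX_MODELS)  individual sliders (voice, pitch, speech style, tempo)
    #   [7*MAX_MODELS, 11*MAX_MODELS) individual number inputs (same order)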
    bulk_slider_start = MAX_MODELS
    bulk_num_start = bulk_slider_start + MAX_MODELS
    ind_slider_start = bulk_num_start + MAX_MODELS
    ind_num_start = ind_slider_start + MAX_MODELS * 4
    if ui_mode == "一括":
        coeffs = (args_list[bulk_slider_start : bulk_slider_start + model_count] if input_mode == "スライダー"
                  else args_list[bulk_num_start : bulk_num_start + model_count])
        coeffs = [c if c is not None else 0.0 for c in coeffs]
        voice_coeffs, voice_pitch_coeffs, speech_style_coeffs, tempo_coeffs = (coeffs,) * 4
    else:
        if input_mode == "スライダー":
            voice_coeffs = args_list[ind_slider_start : ind_slider_start + model_count]
            voice_pitch_coeffs = args_list[ind_slider_start + MAX_MODELS : ind_slider_start + MAX_MODELS + model_count]
            speech_style_coeffs = args_list[ind_slider_start + MAX_MODELS * 2 : ind_slider_start + MAX_MODELS * 2 + model_count]
            tempo_coeffs = args_list[ind_slider_start + MAX_MODELS * 3 : ind_slider_start + MAX_MODELS * 3 + model_count]
        else:
            voice_coeffs = args_list[ind_num_start : ind_num_start + model_count]
            voice_pitch_coeffs = args_list[ind_num_start + MAX_MODELS : ind_num_start + MAX_MODELS + model_count]
            speech_style_coeffs = args_list[ind_num_start + MAX_MODELS * 2 : ind_num_start + MAX_MODELS * 2 + model_count]
            tempo_coeffs = args_list[ind_num_start + MAX_MODELS * 3 : ind_num_start + MAX_MODELS * 3 + model_count]
        voice_coeffs = [c if c is not None else 0.0 for c in voice_coeffs]
        voice_pitch_coeffs = [c if c is not None else 0.0 for c in voice_pitch_coeffs]
        speech_style_coeffs = [c if c is not None else 0.0 for c in speech_style_coeffs]
        tempo_coeffs = [c if c is not None else 0.0 for c in tempo_coeffs]
    try:
        merged_data_info, merged_output_dir = merge_models_weighted_sum(
            model_paths, list(voice_coeffs), list(voice_pitch_coeffs),
            list(speech_style_coeffs), list(tempo_coeffs)
        )
    except Exception as e:
        logger.error(f"Error during model merge: {e}", exc_info=True)
        return f"Error: モデルマージ中にエラーが発生しました: {e}", \
            gr.Dropdown(choices=[DEFAULT_STYLE], value=DEFAULT_STYLE), None
    return "Success: モデルファイルをマージしました。", \
        gr.Dropdown(choices=[DEFAULT_STYLE], value=DEFAULT_STYLE), merged_data_info


def merge_style_gr_common(
    merged_data: dict, new_styles: np.ndarray, new_config: dict
) -> tuple[str, gr.Dropdown, dict]:
    if not merged_data:
        return "Error: 先にモデルファイルのマージを実行してください。", gr.Dropdown(), merged_data
    merged_data["styles"] = new_styles
    merged_data["config"] = new_config
    merged_dir_path = Path(merged_data["merged_dir_path"])
    np.save(merged_dir_path / "style_vectors.npy", new_styles)
    with open(merged_dir_path / "config.json", "w", encoding="utf-8") as f:
        json.dump(new_config, f, indent=2, ensure_ascii=False)
    style_names = list(new_config["data"]["style2id"].keys())
    # Refresh the display names of the styles
    style_map = {name: name for name in style_names}
    choices = [(disp, internal) for internal, disp in style_map.items()]
    default_value = style_names[0] if style_names else None
    return "Success: スタイルを更新し、一時ファイルに保存しました。", \
        gr.Dropdown(choices=choices, value=default_value), \
        merged_data


def merge_style_weighted_sum_gr(
    merged_data: Optional[dict],
    model_count: int,
    style_count: int,
    *args: Any
) -> tuple[str, gr.Dropdown, Optional[dict]]:
    if not merged_data:
        return "Error: 先にモデルファイルのマージを実行してください。", gr.Dropdown(), merged_data
    arg_list = list(args)
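    # *args is laid out as MAX_STYLES rows of (MAX_MODELS input-style dropdowns + 1 output-name
    # textbox), followed by the extended-mode checkbox value, followed by
    # MAX_STYLES * MAX_MODELS ratio slider values.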
    style_input_comps_len = MAX_STYLES * (MAX_MODELS + 1)
    flat_style_inputs = arg_list[:style_input_comps_len]
    is_extended = arg_list[style_input_comps_len]
    flat_style_ratio_inputs = arg_list[style_input_comps_len + 1:]
    # Get the model path information from the recipe
    if "recipe" not in merged_data or "model_paths" not in merged_data["recipe"]:
        return "Error: マージされたモデルのレシピから元のモデルパス情報が見つかりません。モデルマージを再実行してください。", \
            gr.Dropdown(choices=[(DEFAULT_STYLE, DEFAULT_STYLE)]), merged_data
    model_name_dropdown_values = [Path(p).parent.name for p in merged_data["recipe"]["model_paths"]][:model_count]
    # Build the list of ratios
    coeffs_list = []
    if is_extended:
        logger.info("拡張スタイルマージを実行します。UIから比率を取得します。")
        ratio_cursor = 0
        for i in range(style_count):
            all_ratios_for_row = flat_style_ratio_inputs[ratio_cursor : ratio_cursor + MAX_MODELS]
            current_ratios = all_ratios_for_row[:model_count]
            current_ratios = [r if r is not None else 0.0 for r in current_ratios]
            coeffs_list.append(current_ratios)
            ratio_cursor += MAX_MODELS
    else:
        logger.info("通常スタイルマージを実行します。「話し方」の比率を使用します。")
        if "recipe" not in merged_data or "speech_style_coeffs" not in merged_data["recipe"]:
            return "Error: マージされたモデルのレシピから話し方の比率情報が見つかりません。モデルマージを再実行してください。", \
                gr.Dropdown(choices=[(DEFAULT_STYLE, DEFAULT_STYLE)]), merged_data
        speech_style_coeffs_from_model_merge = merged_data["recipe"]["speech_style_coeffs"]
        if len(speech_style_coeffs_from_model_merge) < model_count:
            return "Error: 話し方比率の数がマージモデル数と一致しません。モデルマージを再実行してください。", \
                gr.Dropdown(choices=[(DEFAULT_STYLE, DEFAULT_STYLE)]), merged_data
        for _ in range(style_count):
            coeffs_list.append(speech_style_coeffs_from_model_merge[:model_count])
    # Build the list of style tuples
    style_tuple_list = []
    style_cursor = 0
    for i in range(style_count):
        all_styles_for_row = flat_style_inputs[style_cursor : style_cursor + MAX_MODELS]
        input_style_names = all_styles_for_row[:model_count]
        output_style_name = flat_style_inputs[style_cursor + MAX_MODELS]
        if not output_style_name or not output_style_name.strip():
            return f"Error: スタイル行 {i+1} の出力スタイル名が空です。入力してください。", \
                gr.Dropdown(choices=list(merged_data["config"]["data"]["style2id"].keys())), merged_data
        style_tuple_list.append(tuple(input_style_names + [output_style_name]))
        style_cursor += (MAX_MODELS + 1)
    try:
        new_styles, new_config = merge_style_weighted_sum(
            model_name_dropdown_values,
            style_tuple_list, coeffs_list, merged_data["config"],
        )
    except ValueError as e:
        return f"Error: {e}", gr.Dropdown(choices=list(merged_data["config"]["data"]["style2id"].keys())), merged_data
    except Exception as e:
        logger.error(f"Error during style merge: {e}", exc_info=True)
        return f"Error: スタイルマージ中にエラーが発生しました: {e}", \
            gr.Dropdown(choices=list(merged_data["config"]["data"]["style2id"].keys())), merged_data
    # Record the style-merge information in the recipe (only in extended mode)
    if is_extended:
        merged_data["recipe"]["extended_style_coeffs_list"] = coeffs_list
    elif "extended_style_coeffs_list" in merged_data["recipe"]:
        del merged_data["recipe"]["extended_style_coeffs_list"]
    return merge_style_gr_common(merged_data, new_styles, new_config)


def simple_tts(
    text: str,
    style: str = DEFAULT_STYLE,
    style_weight: float = 1.0,
    merged_data: Optional[dict] = None,
) -> tuple[str, Optional[tuple[int, np.ndarray]]]:
    if not merged_data:
        return "Error: 先にモデルをマージしてください。", None
    merged_dir_path_str = merged_data.get("merged_dir_path")
    if not merged_dir_path_str:
        return "Error: マージされたモデルのファイルパスが見つかりません。モデルマージを再実行してください。", None
    merged_dir_path = Path(merged_dir_path_str)
    model_name_for_file = merged_data["name"]
    tmp_model_path = merged_dir_path / f"{model_name_for_file}.safetensors"
    tmp_config_path = merged_dir_path / "config.json"
    tmp_style_path = merged_dir_path / "style_vectors.npy"
    if not all([tmp_model_path.exists(), tmp_config_path.exists(), tmp_style_path.exists()]):
        return f"Error: 一時モデルファイルが見つかりません。パス: {merged_dir_path_str}", None
    try:
        model = TTSModel(tmp_model_path, tmp_config_path, tmp_style_path, device)
        audio = model.infer(text, style=style, style_weight=style_weight)
        return "Success: マージモデルから音声を生成しました。", audio
    except Exception as e:
        logger.error(f"Error during TTS: {e}", exc_info=True)
        return f"Error: 音声合成中にエラーが発生しました: {e}", None


def _get_style_map_for_model(model_name: str) -> Dict[str, str]:
    """
    Get the mapping from internal style names to display names for a model.
    style_settings.json takes precedence if it exists; otherwise the map is built from config.json.
    """
    if not model_name:
        return {DEFAULT_STYLE: DEFAULT_STYLE}
    model_dir = assets_root / model_name
    style_settings_path = model_dir / "style_settings.json"
    config_path = model_dir / "config.json"
    # First get the base list of styles from config.json
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        # The keys of style2id are the internal names
        base_style_names = list(config["data"]["style2id"].keys())
        # Create the default map (display name == internal name)
        style_map = {name: name for name in base_style_names}
    except (FileNotFoundError, KeyError, json.JSONDecodeError) as e:
        logger.warning(f"モデル '{model_name}' のconfig.json読み込みに失敗: {e}")
        return {DEFAULT_STYLE: DEFAULT_STYLE}
    # If style_settings.json exists, override the display names
    if style_settings_path.exists():
        try:
            with open(style_settings_path, "r", encoding="utf-8") as f:
                settings = json.load(f)
            # Loop over the keys (internal names) of "styles" in the settings
            for internal_name, style_info in settings.get("styles", {}).items():
                # Override only if the internal name exists in config.json and a display_name is given
                if internal_name in style_map and "display_name" in style_info:
                    style_map[internal_name] = style_info["display_name"]
            logger.info(f"モデル '{model_name}' のstyle_settings.jsonから表示名を読み込みました。")
        except (json.JSONDecodeError, KeyError) as e:
            logger.warning(f"モデル '{model_name}' のstyle_settings.jsonの解析に失敗しました。config.jsonのスタイル名を使用します。エラー: {e}")
            # Even if this fails, the style_map built from config.json serves as a fallback
    return style_map
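# Assumed style_settings.json layout read by _get_style_map_for_model (only "display_name" is used):
#   {"styles": {"<internal_style_name>": {"display_name": "<name shown in the UI>"}, ...}}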


def get_styles_for_all_models(model_count: int, *model_names: str):
    all_styles_maps = []
    active_model_names = model_names[:model_count]
    for model_name in active_model_names:
        # Resolve the style map with the helper function above
        style_map = _get_style_map_for_model(model_name)
        all_styles_maps.append(style_map)
    updates = []
    for i in range(MAX_STYLES):
        for j in range(MAX_MODELS):
            if j < model_count:
                style_map = all_styles_maps[j]
                # Convert to a list of (display name, internal name) tuples
                choices = [(disp, internal) for internal, disp in style_map.items()]
                # The default value is given as an internal name
                default_value = DEFAULT_STYLE if DEFAULT_STYLE in style_map else (choices[0][1] if choices else None)
                updates.append(gr.Dropdown(choices=choices, value=default_value))
            else:
                # Slot for an inactive model
                updates.append(gr.Dropdown(choices=[(DEFAULT_STYLE, DEFAULT_STYLE)], value=DEFAULT_STYLE))
    return updates


def set_equal_ratio(model_count: int):
    if model_count <= 0:
        return [gr.update(value=0.0)] * (MAX_MODELS * 10)
    base_ratio = 1.0 / model_count
    ratios = [round(base_ratio, 2)] * model_count
    # Adjust the last element so that the total is exactly 1.0
    current_sum = sum(ratios)
    diff = 1.0 - current_sum
    if model_count > 0:
        ratios[-1] += diff
        ratios[-1] = round(ratios[-1], 2)
    full_ratios = ratios + [0.0] * (MAX_MODELS - model_count)
    updates = []
    for _ in range(10):  # bulk sliders/nums (2 groups) + individual sliders/nums (8 groups)
        updates.extend([gr.update(value=r) for r in full_ratios])
    return updates
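# Example: with model_count=3 each ratio group becomes [0.33, 0.33, 0.34, 0.0, ...],
# so the active ratios sum to exactly 1.0.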


def update_default_style_name(*args: Any) -> Any:
    """
    Generate a default output style name based on the selected input styles.
    Display names are used, and if only a single style is selected its name is used as-is.
    """
    # Unpack the arguments
    model_names = args[:MAX_MODELS]
    selected_internal_styles = args[MAX_MODELS : MAX_MODELS + MAX_MODELS]
    model_count = args[-1]
    if not isinstance(model_count, int) or model_count <= 0:
        return gr.update()
    active_model_names = model_names[:model_count]
    active_selected_styles = selected_internal_styles[:model_count]
    display_names_to_join = []
    for i, internal_style in enumerate(active_selected_styles):
        if internal_style and internal_style != DEFAULT_STYLE:
            # Get the corresponding model name
            model_name = active_model_names[i]
            # Get the style map
            style_map = _get_style_map_for_model(model_name)
            # Look up the display name from the internal name (fall back to the internal name)
            display_name = style_map.get(internal_style, internal_style)
            display_names_to_join.append(display_name)
    if len(display_names_to_join) == 1:
        # If exactly one valid style is selected, use its display name as-is
        new_name = display_names_to_join[0]
    elif len(display_names_to_join) > 1:
        # If there are several, join them with underscores
        new_name = "_".join(display_names_to_join)
    else:
        # If there is no valid style, fall back to the default
        new_name = DEFAULT_STYLE
    return gr.update(value=new_name)
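# Example: selecting styles displayed as "Happy" (model 1) and "Sad" (model 2) - both hypothetical
# names - yields "Happy_Sad"; slots left at the default style are skipped, and if nothing else is
# selected the default style name is used.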


def set_default_style_ratios(is_extended: bool, merged_data: Optional[dict], model_count: int):
    """When extended style merging is enabled, pre-fill the ratio inputs with the "speech style" ratios as defaults."""
    if not is_extended or not merged_data or model_count <= 0:
        # Return 0.0 when not in extended mode or when there is no merge data
        updates = [gr.update(value=0.0) for _ in range(MAX_STYLES * MAX_MODELS)]
        return updates
    try:
        # Get the "speech style" ratios from the recipe
        speech_style_coeffs = merged_data["recipe"]["speech_style_coeffs"]
        # Build the ratio list for the active models and pad the rest with 0.0
        coeffs = speech_style_coeffs[:model_count] + [0.0] * (MAX_MODELS - model_count)
    except (KeyError, TypeError, AttributeError):
        # If the recipe or the ratios do not exist, use 0.0 everywhere
        coeffs = [0.0] * MAX_MODELS
    # Apply the same default ratios to every style row
    updates = []
    for _ in range(MAX_STYLES):
        updates.extend([gr.update(value=c) for c in coeffs])
    return updates


# =========================================================================
# ★★★★★★★★★★★★★★★★★★★ Modified section ★★★★★★★★★★★★★★★★★★★
# =========================================================================
def fn_model_sort_key(name: str):
    """Sort key for natural ordering of FN models; non-FN models are handled as well."""
    match = re.match(r"FN(\d+)", name)
    if match:
        # FN models use a (0, number) tuple as the key
        return (0, int(match.group(1)))
    # Non-FN models (whisper and standard models) use (1, name) as the key
    return (1, name)


def get_fn_models(model_list: List[str]) -> List[str]:
    """Extract the FN series and the whisper model from the list and return them in natural order."""
    fn_pattern = re.compile(r"^FN([1-9]|10)$")
    target_models = [name for name in model_list if fn_pattern.match(name) or name == "whisper"]
    # Sort with the key function
    return sorted(target_models, key=fn_model_sort_key)
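# Example: get_fn_models(["FN10", "whisper", "FN2", "FN1", "SomeOtherModel"])
# returns ["FN1", "FN2", "FN10", "whisper"].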


# =========================================================================
# ★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
def get_standard_models(model_list: List[str]) -> List[str]:
    """Extract the models other than the FN series and whisper from the list."""
    fn_models_set = set(get_fn_models(model_list))
    return sorted([name for name in model_list if name not in fn_models_set])


def create_merge_app(model_holder: TTSModelHolder) -> gr.Blocks:
    all_model_names = model_holder.model_names
    if not all_model_names:
        with gr.Blocks() as app:
            gr.Markdown("モデルが見つかりません。`assets/models` フォルダにモデルを配置してください。")
        return app
    # Filter the model list for the initial display (standard mode is the default)
    initial_display_models = get_standard_models(all_model_names)
    initial_model_name = initial_display_models[0] if initial_display_models else None
    INITIAL_COLUMN_WIDTH = 345
    with gr.Blocks(theme=GRADIO_THEME) as app:
        merged_data_state = gr.State(None)
        model_name_comps, model_cols = [], []
        bulk_slider_comps, bulk_num_comps, bulk_ui_cols, bulk_slider_rows, bulk_num_rows = [], [], [], [], []
        voice_slider_comps, voice_pitch_slider_comps, speech_style_slider_comps, tempo_slider_comps = [], [], [], []
        voice_num_comps, voice_pitch_num_comps, speech_style_num_comps, tempo_num_comps = [], [], [], []
        individual_ui_cols, ind_slider_cols, ind_num_cols = [], [], []
        style_rows, style_input_comps, all_style_dropdowns = [], [], []
        style_row_input_dropdowns_list, style_row_output_textbox_list = [], []
        all_style_and_ratio_cols, all_style_ratio_sliders, style_ratio_rows = [], [], []
        with gr.Tabs():
            with gr.TabItem("モデル融☆合"):
                with gr.Row():
                    model_count_slider = gr.Slider(label="融☆合するモデル数(最大10個)", minimum=1, maximum=MAX_MODELS, step=1, value=2, scale=2)
                    fn_mode_checkbox = gr.Checkbox(label="FNシリーズ/whisperのみ表示", value=False, scale=1)
                    refresh_button = gr.Button("モデルリスト更新", scale=1)
                    equal_ratio_button = gr.Button("比率を平均化", scale=1)
                with gr.Blocks():
                    with gr.Row(equal_height=False):
                        for i in range(MAX_MODELS):
                            with gr.Column(scale=0, min_width=INITIAL_COLUMN_WIDTH, visible=i < 2) as model_col:
                                gr.Markdown(f"### モデル {i+1}")
                                name = gr.Dropdown(label="モデルフォルダ", choices=initial_display_models, value=initial_model_name if i < 2 else None)
                                with gr.Column(visible=True) as bulk_ui_col:
                                    with gr.Row(visible=True) as bulk_slider_row:
                                        bulk_slider = gr.Slider(label="比率", value=1.0 if i == 0 else 0.0, minimum=0.0, maximum=1.0, step=0.01)
                                    with gr.Row(visible=False) as bulk_num_row:
                                        bulk_num = gr.Number(label="比率(数値)", value=1.0 if i == 0 else 0.0, minimum=0.0, maximum=1.0, step=0.01)
                                with gr.Column(visible=False) as individual_ui_col:
                                    with gr.Column(visible=True) as ind_slider_col:
                                        voice_slider = gr.Slider(label="声質", value=1.0 if i == 0 else 0.0, minimum=0, maximum=1, step=0.01)
                                        voice_pitch_slider = gr.Slider(label="高さ", value=1.0 if i == 0 else 0.0, minimum=0, maximum=1, step=0.01)
                                        speech_style_slider = gr.Slider(label="話し方", value=1.0 if i == 0 else 0.0, minimum=0, maximum=1, step=0.01)
                                        tempo_slider = gr.Slider(label="速さ", value=1.0 if i == 0 else 0.0, minimum=0, maximum=1, step=0.01)
                                    with gr.Column(visible=False) as ind_num_col:
                                        voice_num = gr.Number(label="声質", value=1.0 if i == 0 else 0.0, minimum=0, maximum=1, step=0.01)
                                        voice_pitch_num = gr.Number(label="高さ", value=1.0 if i == 0 else 0.0, minimum=0, maximum=1, step=0.01)
                                        speech_style_num = gr.Number(label="話し方", value=1.0 if i == 0 else 0.0, minimum=0, maximum=1, step=0.01)
                                        tempo_num = gr.Number(label="速さ", value=1.0 if i == 0 else 0.0, minimum=0, maximum=1, step=0.01)
                            model_cols.append(model_col); model_name_comps.append(name)
                            bulk_ui_cols.append(bulk_ui_col); individual_ui_cols.append(individual_ui_col)
                            bulk_slider_rows.append(bulk_slider_row); bulk_num_rows.append(bulk_num_row)
                            bulk_slider_comps.append(bulk_slider); bulk_num_comps.append(bulk_num)
                            ind_slider_cols.append(ind_slider_col); ind_num_cols.append(ind_num_col)
                            voice_slider_comps.append(voice_slider); voice_num_comps.append(voice_num)
                            voice_pitch_slider_comps.append(voice_pitch_slider); voice_pitch_num_comps.append(voice_pitch_num)
                            speech_style_slider_comps.append(speech_style_slider); speech_style_num_comps.append(speech_style_num)
                            tempo_slider_comps.append(tempo_slider); tempo_num_comps.append(tempo_num)
                all_ratio_sliders = bulk_slider_comps + voice_slider_comps + voice_pitch_slider_comps + speech_style_slider_comps + tempo_slider_comps
                all_ratio_nums = bulk_num_comps + voice_num_comps + voice_pitch_num_comps + speech_style_num_comps + tempo_num_comps
                with gr.Row():
                    model_merge_button = gr.Button("融☆合", variant="primary", scale=1)
                    info_model_merge = gr.Textbox(label="情報", interactive=False)
                with gr.Accordion("設定", open=False):
                    with gr.Row():
                        ratio_mode_radio = gr.Radio(label="UIモード", choices=["一括", "個別"], value="一括")
                        input_mode_radio = gr.Radio(label="入力形式", choices=["スライダー", "数値入力"], value="スライダー")
                        column_width_slider = gr.Slider(
                            label="モデルカラムの幅 (px)",
                            minimum=180,
                            maximum=600,
                            value=INITIAL_COLUMN_WIDTH,
                            step=5,
                            interactive=True
                        )
                gr.Markdown("## 融☆合モデルから音声を生成")
                with gr.Row():
                    with gr.Column(variant="panel", scale=1):
                        text_input = gr.TextArea(label="テキスト", value="こんにちは、今日もいい天気ですね。", lines=2)
                        with gr.Row():
                            style = gr.Dropdown(label="スタイル", choices=[DEFAULT_STYLE], value=DEFAULT_STYLE)
                            emotion_weight = gr.Slider(minimum=0, maximum=20, value=1, step=0.1, label="スタイルの強さ")
                        tts_button = gr.Button("融☆合モデルで読み上げ", variant="primary")
                        tts_info = gr.Textbox(label="情報", interactive=False)
                    audio_output = gr.Audio(label="結果", scale=1)
            with gr.TabItem("スタイル融☆合"):
                gr.Markdown("スタイルは各モデルの「話し方」の比率(モデルマージ時に設定)でマージされます。拡張モードで比率の個別設定が可能です。", elem_id="style_merge_info_text")
                with gr.Column(variant="panel"):
                    with gr.Row():
                        style_count_slider = gr.Slider(label="作成するスタイル数", value=4, minimum=1, maximum=MAX_STYLES, step=1, scale=3)
                        get_style_btn = gr.Button("各モデルのスタイルを取得", variant="primary", scale=1)
                    extended_style_merge_checkbox = gr.Checkbox(
                        label="各スタイルの比率を個別に設定する",
                        value=False
                    )
                    for i in range(MAX_STYLES):
                        with gr.Column(visible=i < 4) as style_row:
                            gr.Markdown(f"#### 新しいスタイル {i+1}")
                            current_row_input_dropdowns = []
                            current_row_ratio_sliders = []
                            with gr.Row(equal_height=False):
                                for j in range(MAX_MODELS):
                                    with gr.Column(visible=j < 2, scale=0, min_width=INITIAL_COLUMN_WIDTH) as style_and_ratio_col:
                                        s = gr.Dropdown(label=f"モデル{j+1}の入力スタイル", choices=[(DEFAULT_STYLE, DEFAULT_STYLE)], value=DEFAULT_STYLE)
                                        with gr.Row(visible=False) as style_ratio_row:
                                            r = gr.Slider(
                                                label=f"モデル{j+1}の比率",
                                                value=0.0,
                                                minimum=0.0,
                                                maximum=1.0,
                                                step=0.01,
                                                interactive=True
                                            )
                                    current_row_input_dropdowns.append(s)
                                    all_style_dropdowns.append(s)
                                    current_row_ratio_sliders.append(r)
                                    all_style_ratio_sliders.append(r)
                                    all_style_and_ratio_cols.append(style_and_ratio_col)
                                    style_ratio_rows.append(style_ratio_row)
                                with gr.Column(scale=0, min_width=INITIAL_COLUMN_WIDTH):
                                    o = gr.Textbox(label="出力スタイル名", value=DEFAULT_STYLE)
                        style_row_input_dropdowns_list.append(current_row_input_dropdowns)
                        style_row_output_textbox_list.append(o)
                        style_input_comps.extend(current_row_input_dropdowns)
                        style_input_comps.append(o)
                        style_rows.append(style_row)
                    with gr.Row():
                        style_merge_btn = gr.Button("スタイル融☆合", variant="primary")
                        info_style_merge = gr.Textbox(label="情報", interactive=False)

        # --- Event listeners ---
        column_width_slider.input(
            lambda width: [gr.update(min_width=width) for _ in range(MAX_MODELS)] +
                          [gr.update(min_width=width) for _ in range(MAX_STYLES * MAX_MODELS)],
            inputs=[column_width_slider],
            outputs=model_cols + all_style_and_ratio_cols
        )
        model_count_slider.change(
            lambda c: [gr.update(visible=i < c) for i in range(MAX_MODELS)] +
                      [gr.update(visible=(i % MAX_MODELS) < c) for i in range(MAX_STYLES * MAX_MODELS)],
            inputs=[model_count_slider],
            outputs=model_cols + all_style_and_ratio_cols
        )
        ratio_mode_radio.change(
            lambda mode: [gr.Column(visible=mode == "一括")] * MAX_MODELS + [gr.Column(visible=mode != "一括")] * MAX_MODELS,
            inputs=[ratio_mode_radio],
            outputs=bulk_ui_cols + individual_ui_cols
        )
        input_mode_radio.change(
            lambda mode: [gr.Row(visible=mode == "スライダー")] * MAX_MODELS + [gr.Row(visible=mode != "スライダー")] * MAX_MODELS +
                         [gr.Column(visible=mode == "スライダー")] * MAX_MODELS + [gr.Column(visible=mode != "スライダー")] * MAX_MODELS,
            inputs=[input_mode_radio],
            outputs=bulk_slider_rows + bulk_num_rows + ind_slider_cols + ind_num_cols
        )
        style_count_slider.change(
            lambda c: [gr.Column(visible=i < c) for i in range(MAX_STYLES)],
            inputs=[style_count_slider],
            outputs=style_rows
        )

        # =========================================================================
        # ★★★★★★★★★★★★★★★★★★★ Modified section ★★★★★★★★★★★★★★★★★★★
        # =========================================================================
        def refresh_model_list(is_fn_mode: bool, model_count: int, *current_model_names: str):
            logger.info("モデルリストを更新しています...")
            new_model_names = []
            new_model_files_dict = {}
            if assets_root.exists() and assets_root.is_dir():
                temp_dir_abs_path = Path(DEFAULT_TEMP_SAVE_DIR).resolve()
                sys_temp_dir_abs_path = Path(tempfile.gettempdir()).resolve()
                for p in sorted(list(assets_root.iterdir())):
                    if p.is_dir():
                        if p.name in ["bert", "prompt_histories", "__pycache__"] or p.resolve() == temp_dir_abs_path or p.resolve() == sys_temp_dir_abs_path:
                            continue
                        config_path = p / "config.json"
                        if config_path.exists():
                            model_name = p.name
                            safetensors_files = sorted(list(p.glob("*.safetensors")))
                            if safetensors_files:
                                new_model_names.append(model_name)
                                new_model_files_dict[model_name] = [str(f) for f in safetensors_files]
            model_holder.model_names = new_model_names
            model_holder.model_files_dict = new_model_files_dict
            logger.info(f"{len(new_model_names)}個のモデルが見つかりました: {new_model_names}")
            if is_fn_mode:
                display_choices = get_fn_models(model_holder.model_names)
            else:
                display_choices = get_standard_models(model_holder.model_names)
            updates = []
            for i in range(MAX_MODELS):
                current_value = current_model_names[i]
                final_choices = list(display_choices)
                # Check whether the previously selected value still exists in the refreshed model list
                if current_value and current_value in model_holder.model_names:
                    # If it does, keep the selection even if it is not among the current display mode's choices
                    if current_value not in final_choices:
                        final_choices.append(current_value)
                        if is_fn_mode:
                            final_choices.sort(key=fn_model_sort_key)
                        else:
                            final_choices.sort()
                    updates.append(gr.update(choices=final_choices, value=current_value))
                else:
                    # If it no longer exists (e.g. the model was deleted), reset the dropdown
                    default_model_name = display_choices[0] if display_choices else None
                    is_active = i < model_count
                    updates.append(gr.update(
                        choices=display_choices,
                        value=default_model_name if is_active and default_model_name else None
                    ))
            return updates

        refresh_button.click(
            refresh_model_list,
            inputs=[fn_mode_checkbox, model_count_slider] + model_name_comps,
            outputs=model_name_comps
        )

        def update_model_choices(is_fn_mode: bool, *current_model_names: str):
            """Update the dropdown choices when the FN mode toggle changes, keeping the current selection."""
            if is_fn_mode:
                base_choices = get_fn_models(model_holder.model_names)
            else:
                base_choices = get_standard_models(model_holder.model_names)
            updates = []
            for i in range(MAX_MODELS):
                current_value = current_model_names[i]
                final_choices = list(base_choices)
                # If there is a current value that is not in the base choices,
                if current_value and current_value not in final_choices:
                    # add it temporarily so the user's selection is preserved
                    final_choices.append(current_value)
                    if is_fn_mode:
                        final_choices.sort(key=fn_model_sort_key)
                    else:
                        final_choices.sort()
                # Gradio keeps the value automatically, so only the choices are updated
                updates.append(gr.update(choices=final_choices))
            return updates

        fn_mode_checkbox.change(
            fn=update_model_choices,
            inputs=[fn_mode_checkbox] + model_name_comps,
            outputs=model_name_comps,
        )
        # =========================================================================
        # ★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
        equal_ratio_button.click(set_equal_ratio, inputs=[model_count_slider], outputs=all_ratio_sliders + all_ratio_nums)

        def merge_models_gr_closure(
            model_count: int, ui_mode: str, input_mode: str,
            merged_data_state: Optional[dict], *args
        ):
            return merge_models_gr(
                model_count, ui_mode, input_mode, merged_data_state, model_holder, *args
            )

        model_merge_button.click(
            merge_models_gr_closure,
            inputs=[model_count_slider, ratio_mode_radio, input_mode_radio, merged_data_state] + model_name_comps +
                   bulk_slider_comps + bulk_num_comps +
                   voice_slider_comps + voice_pitch_slider_comps + speech_style_slider_comps + tempo_slider_comps +
                   voice_num_comps + voice_pitch_num_comps + speech_style_num_comps + tempo_num_comps,
            outputs=[info_model_merge, style, merged_data_state]
        )
        tts_button.click(
            simple_tts,
            inputs=[text_input, style, emotion_weight, merged_data_state],
            outputs=[tts_info, audio_output]
        )
        get_style_btn.click(
            get_styles_for_all_models,
            inputs=[model_count_slider] + model_name_comps,
            outputs=all_style_dropdowns,
        )
        for i in range(MAX_STYLES):
            input_dropdowns_for_row = style_row_input_dropdowns_list[i]
            output_textbox_for_row = style_row_output_textbox_list[i]
            for dropdown in input_dropdowns_for_row:
                dropdown.change(
                    fn=update_default_style_name,
                    inputs=model_name_comps + input_dropdowns_for_row + [model_count_slider],
                    outputs=[output_textbox_for_row]
                )
        extended_style_merge_checkbox.change(
            lambda is_extended: [gr.Row(visible=is_extended)] * len(style_ratio_rows),
            inputs=[extended_style_merge_checkbox],
            outputs=style_ratio_rows,
        ).then(
            set_default_style_ratios,
            inputs=[extended_style_merge_checkbox, merged_data_state, model_count_slider],
            outputs=all_style_ratio_sliders
        )

        def merge_style_gr_closure(merged_data, model_count, style_count, *args):
            return merge_style_weighted_sum_gr(merged_data, model_count, style_count, *args)

        style_merge_btn.click(
            merge_style_gr_closure,
            inputs=[
                merged_data_state,
                model_count_slider,
                style_count_slider,
                *style_input_comps,
                extended_style_merge_checkbox,
                *all_style_ratio_sliders,
            ],
            outputs=[info_style_merge, style, merged_data_state]
        )
    return app


if __name__ == "__main__":
    # Create the temporary save directory once at application startup
    Path(DEFAULT_TEMP_SAVE_DIR).mkdir(parents=True, exist_ok=True)
    model_holder = TTSModelHolder(assets_root, device=device)
    logger.info("初期モデルリストをフィルタリングしています...")
    original_model_names = model_holder.model_names
    temp_dir_abs_path = Path(DEFAULT_TEMP_SAVE_DIR).resolve()
    sys_temp_dir_abs_path = Path(tempfile.gettempdir()).resolve()
    filtered_model_names = []
    for name in original_model_names:
        model_path = assets_root / name
        if not model_path.is_dir():
            continue
        model_abs_path = model_path.resolve()
        if name in ["bert", "prompt_histories", "__pycache__"] or \
                model_abs_path == temp_dir_abs_path or \
                model_abs_path == sys_temp_dir_abs_path:
            continue
        filtered_model_names.append(name)
    filtered_model_files_dict = {
        name: files
        for name, files in model_holder.model_files_dict.items()
        if name in filtered_model_names
    }
    model_holder.model_names = filtered_model_names
    model_holder.model_files_dict = filtered_model_files_dict
    logger.info(f"フィルタリング後のモデルリスト: {model_holder.model_names}")
    app = create_merge_app(model_holder)
    app.launch(inbrowser=True)