TextGeneration / app.py
2z299's picture
Update app.py
c14f2e8 verified
raw
history blame
6.93 kB
import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces
# PyTorch設定(パフォーマンスと再現性向上のため)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True
HF_TOKEN = os.getenv("HF_TOKEN")
# モデルのキャッシュ用辞書(ロード済みなら再利用)
loaded_models = {}
def get_model_and_tokenizer(model_name):
# 既にロード済みならそのまま返す
if model_name in loaded_models:
return loaded_models[model_name]
# ロードされていなければロードする
tokenizer = AutoTokenizer.from_pretrained(
model_name, attn_implementation="flash_attention_2", use_auth_token=HF_TOKEN
)
model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=HF_TOKEN)
loaded_models[model_name] = (model, tokenizer)
return model, tokenizer
def disable_generate_button():
# 生成ボタンを無効化し、テキストを「モデルをロード中……」に変更する
return gr.update(interactive=False, value="モデルをロード中……")
def load_model(model_name):
"""
プルダウン変更時や起動時に呼ばれ、モデルをロードして生成ボタンを有効化する。
"""
tokenizer = AutoTokenizer.from_pretrained(
model_name, attn_implementation="flash_attention_2", use_auth_token=HF_TOKEN
)
model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=HF_TOKEN)
loaded_models[model_name] = (model, tokenizer)
status_message = f"Model '{model_name}' loaded successfully."
# ロード完了後、生成ボタンを有効化し、テキストを「続きを生成」に戻す
return status_message, gr.update(interactive=True, value="続きを生成")
@spaces.GPU
def generate_text(
model_name,
input_text,
max_length=150,
temperature=0.7,
top_k=50,
top_p=0.95,
repetition_penalty=1.0
):
"""ユーザー入力に基づいてテキストを生成し、元のテキストに追加する関数"""
try:
if not input_text.strip():
return ""
# 既にロード済みのモデルとトークナイザーを使用
model, tokenizer = get_model_and_tokenizer(model_name)
# GPUが利用可能ならGPUへ移動。bf16がサポートされている場合はbf16を使用
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda" and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
model.to(device, dtype=torch.bfloat16)
else:
model.to(device)
# 入力テキストのトークン化
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
input_token_count = input_ids.shape[1]
# 総トークン数の上限を入力トークン数 + max_length(max_lengthは追加するトークン数として扱う)
total_max_length = input_token_count + max_length
# テキスト生成
output_ids = model.generate(
input_ids,
max_length=total_max_length,
do_sample=True,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
pad_token_id=tokenizer.eos_token_id,
num_return_sequences=1
)
# 生成されたテキストをデコードし、入力部分を除いた生成分を抽出
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
new_text = generated_text[len(input_text):]
# 入力テキストに生成したテキストを追加して返す
return input_text + new_text
except Exception as e:
return f"{input_text}\n\nエラーが発生しました: {str(e)}"
# Gradioインターフェースの作成
with gr.Blocks() as demo:
gr.Markdown("# テキスト続き生成アシスタント")
gr.Markdown("モデルを選択し、テキストボックスに文章を入力してパラメータを調整後、「続きを生成」ボタンをクリックすると、選択したモデルがその続きを生成します。")
# モデル選択用プルダウンメニュー
model_dropdown = gr.Dropdown(
choices=[
"Local-Novel-LLM-project/Vecteus-v1-abliterated",
"Local-Novel-LLM-project/Ninja-V3",
"Local-Novel-LLM-project/kagemusya-7B-v1"
],
label="モデルを選択してください",
value="Local-Novel-LLM-project/Vecteus-v1-abliterated"
)
# 隠しコンポーネント:モデルロードの状況を表示(ユーザーには見せなくても良い)
load_status = gr.Textbox(visible=False)
# テキスト入力ボックス
input_text = gr.Textbox(label="テキストを入力してください", placeholder="ここにテキストを入力...", lines=10)
# 生成パラメータの設定UI
max_length_slider = gr.Slider(minimum=10, maximum=1000, value=100, step=10, label="追加するトークン数")
temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="創造性(温度)")
top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="top_k")
top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="top_p")
repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, value=1.0, step=0.1, label="繰り返しペナルティ")
# 生成ボタンは初期状態で無効化
generate_btn = gr.Button("モデルをロード中……", variant="primary", interactive=False)
clear_btn = gr.Button("クリア")
# プルダウン変更時に、まず生成ボタンを無効化(テキストを「モデルをロード中……」に変更)し、その後モデルをロードして生成ボタンを再有効化するイベントチェーンを設定
model_dropdown.change(
fn=disable_generate_button,
inputs=None,
outputs=generate_btn
).then(
fn=load_model,
inputs=model_dropdown,
outputs=[load_status, generate_btn]
)
# 起動時にも load_model を実行する(初期値のモデルでロード)
demo.load(fn=load_model, inputs=model_dropdown, outputs=[load_status, generate_btn])
# 生成ボタン押下時のイベント設定
generate_btn.click(
fn=generate_text,
inputs=[model_dropdown, input_text, max_length_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
outputs=input_text
)
clear_btn.click(lambda: "", None, input_text)
# アプリの起動
demo.launch()