Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
import spaces | |
# PyTorch設定(パフォーマンスと再現性向上のため) | |
torch.backends.cudnn.deterministic = True | |
torch.backends.cudnn.benchmark = False | |
torch.backends.cuda.matmul.allow_tf32 = True | |
HF_TOKEN = os.getenv("HF_TOKEN") | |
# モデルのキャッシュ用辞書(ロード済みなら再利用) | |
loaded_models = {} | |
def get_model_and_tokenizer(model_name): | |
if model_name in loaded_models: | |
return loaded_models[model_name] | |
else: | |
tokenizer = AutoTokenizer.from_pretrained(model_name, attn_implementation="flash_attention_2", use_auth_token=HF_TOKEN) | |
model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=HF_TOKEN) | |
loaded_models[model_name] = (model, tokenizer) | |
return model, tokenizer | |
def generate_text( | |
model_name, | |
input_text, | |
max_length=150, | |
temperature=0.7, | |
top_k=50, | |
top_p=0.95, | |
repetition_penalty=1.0 | |
): | |
"""ユーザー入力に基づいてテキストを生成し、元のテキストに追加する関数""" | |
try: | |
if not input_text.strip(): | |
return "" | |
# 選択されたモデルとトークナイザーを取得 | |
model, tokenizer = get_model_and_tokenizer(model_name) | |
# GPUが利用可能ならGPUへ移動。bf16がサポートされている場合はbf16を使用 | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
if device == "cuda" and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported(): | |
model.to(device, dtype=torch.bfloat16) | |
else: | |
model.to(device) | |
# 入力テキストのトークン化 | |
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device) | |
input_token_count = input_ids.shape[1] | |
# 総トークン数の上限を入力トークン数 + max_length(max_lengthは追加するトークン数として扱う) | |
total_max_length = input_token_count + max_length | |
# テキスト生成 | |
output_ids = model.generate( | |
input_ids, | |
max_length=total_max_length, | |
do_sample=True, | |
temperature=temperature, | |
top_k=top_k, | |
top_p=top_p, | |
repetition_penalty=repetition_penalty, | |
pad_token_id=tokenizer.eos_token_id, | |
num_return_sequences=1 | |
) | |
# 生成されたテキストをデコードし、入力部分を除いた生成分を抽出 | |
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) | |
new_text = generated_text[len(input_text):] | |
# 入力テキストに生成したテキストを追加して返す | |
return input_text + new_text | |
except Exception as e: | |
return f"{input_text}\n\nエラーが発生しました: {str(e)}" | |
# Gradioインターフェースの作成 | |
with gr.Blocks() as demo: | |
gr.Markdown("# テキスト続き生成アシスタント") | |
gr.Markdown("モデルを選択し、テキストボックスに文章を入力してパラメータを調整後、「続きを生成」ボタンをクリックすると、選択したモデルがその続きを生成します。") | |
# モデル選択用プルダウンメニュー(候補は必要に応じて変更してください) | |
model_dropdown = gr.Dropdown( | |
choices=[ | |
"Local-Novel-LLM-project/Vecteus-v1-abliterated", | |
"Local-Novel-LLM-project/Ninja-V3", | |
"Local-Novel-LLM-project/kagemusya-7B-v1" | |
], | |
label="モデルを選択してください", | |
value="Local-Novel-LLM-project/Vecteus-v1-abliterated" | |
) | |
# テキスト入力ボックス | |
input_text = gr.Textbox(label="テキストを入力してください", placeholder="ここにテキストを入力...", lines=10) | |
# 生成パラメータの設定UI(縦に一列で配置) | |
max_length_slider = gr.Slider(minimum=10, maximum=1000, value=100, step=10, label="追加するトークン数") | |
temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="創造性(温度)") | |
top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="top_k") | |
top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="top_p") | |
repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, value=1.0, step=0.1, label="繰り返しペナルティ") | |
# ボタン(縦に配置) | |
generate_btn = gr.Button("続きを生成", variant="primary") | |
clear_btn = gr.Button("クリア") | |
# イベントの設定:入力としてモデル選択とテキスト、パラメータを渡す | |
generate_btn.click( | |
fn=generate_text, | |
inputs=[model_dropdown, input_text, max_length_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider], | |
outputs=input_text | |
) | |
clear_btn.click(lambda: "", None, input_text) | |
# 使い方の説明 | |
gr.Markdown(""" | |
## 使い方 | |
1. 上部のプルダウンメニューから使用するモデルを選択します | |
2. テキストボックスに続きを生成したい文章を入力します | |
3. 生成パラメータ(追加するトークン数、創造性、top_k、top_p、繰り返しペナルティ)を調整します | |
4. 「続きを生成」ボタンをクリックすると、入力したテキストの続きが生成され、元のテキストに追加されます | |
5. 「クリア」ボタンを押すと、テキストボックスの内容がクリアされます | |
6. 満足のいく結果が得られるまで、繰り返し「続きを生成」ボタンを押して文章を発展させることができます | |
## ヒント | |
- 短い文章から始め、徐々に発展させると良い結果が得られます | |
- 創造性(温度)を高くすると予測不可能な生成結果に、低くすると安定した結果になります | |
- top_k や top_p、繰り返しペナルティも状況に応じて調整してみてください | |
""") | |
# アプリの起動 | |
demo.launch() |