Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -4,19 +4,28 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
4 |
import torch
|
5 |
import spaces
|
6 |
|
7 |
-
# PyTorch
|
8 |
torch.backends.cudnn.deterministic = True
|
9 |
torch.backends.cudnn.benchmark = False
|
10 |
torch.backends.cuda.matmul.allow_tf32 = True
|
11 |
|
12 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
13 |
|
14 |
-
#
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
@spaces.GPU
|
19 |
def generate_text(
|
|
|
20 |
input_text,
|
21 |
max_length=150,
|
22 |
temperature=0.7,
|
@@ -29,6 +38,9 @@ def generate_text(
|
|
29 |
if not input_text.strip():
|
30 |
return ""
|
31 |
|
|
|
|
|
|
|
32 |
# GPUが利用可能ならGPUへ移動。bf16がサポートされている場合はbf16を使用
|
33 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
34 |
if device == "cuda" and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
|
@@ -36,11 +48,11 @@ def generate_text(
|
|
36 |
else:
|
37 |
model.to(device)
|
38 |
|
39 |
-
#
|
40 |
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
|
41 |
input_token_count = input_ids.shape[1]
|
42 |
|
43 |
-
# 総トークン数の上限を入力トークン数 + max_length(max_length
|
44 |
total_max_length = input_token_count + max_length
|
45 |
|
46 |
# テキスト生成
|
@@ -56,7 +68,7 @@ def generate_text(
|
|
56 |
num_return_sequences=1
|
57 |
)
|
58 |
|
59 |
-
#
|
60 |
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
61 |
new_text = generated_text[len(input_text):]
|
62 |
|
@@ -65,29 +77,40 @@ def generate_text(
|
|
65 |
except Exception as e:
|
66 |
return f"{input_text}\n\nエラーが発生しました: {str(e)}"
|
67 |
|
68 |
-
# Gradio
|
69 |
with gr.Blocks() as demo:
|
70 |
gr.Markdown("# テキスト続き生成アシスタント")
|
71 |
-
gr.Markdown("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
-
#
|
74 |
input_text = gr.Textbox(label="テキストを入力してください", placeholder="ここにテキストを入力...", lines=10)
|
75 |
|
76 |
-
#
|
77 |
max_length_slider = gr.Slider(minimum=10, maximum=1000, value=100, step=10, label="追加するトークン数")
|
78 |
temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="創造性(温度)")
|
79 |
top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="top_k")
|
80 |
top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="top_p")
|
81 |
repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, value=1.0, step=0.1, label="繰り返しペナルティ")
|
82 |
|
83 |
-
#
|
84 |
generate_btn = gr.Button("続きを生成", variant="primary")
|
85 |
clear_btn = gr.Button("クリア")
|
86 |
|
87 |
-
#
|
88 |
generate_btn.click(
|
89 |
fn=generate_text,
|
90 |
-
inputs=[input_text, max_length_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
|
91 |
outputs=input_text
|
92 |
)
|
93 |
|
@@ -96,17 +119,18 @@ with gr.Blocks() as demo:
|
|
96 |
# 使い方の説明
|
97 |
gr.Markdown("""
|
98 |
## 使い方
|
99 |
-
1.
|
100 |
-
2.
|
101 |
-
3.
|
102 |
-
4.
|
103 |
-
5.
|
|
|
104 |
|
105 |
## ヒント
|
106 |
-
-
|
107 |
- 創造性(温度)を高くすると予測不可能な生成結果に、低くすると安定した結果になります
|
108 |
-
- top_k や top_p
|
109 |
""")
|
110 |
|
111 |
# アプリの起動
|
112 |
-
demo.launch()
|
|
|
4 |
import torch
|
5 |
import spaces
|
6 |
|
7 |
+
# PyTorch設定(パフォーマンスと再現性向上のため)
|
8 |
torch.backends.cudnn.deterministic = True
|
9 |
torch.backends.cudnn.benchmark = False
|
10 |
torch.backends.cuda.matmul.allow_tf32 = True
|
11 |
|
12 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
13 |
|
14 |
+
# モデルのキャッシュ用辞書(ロード済みなら再利用)
|
15 |
+
loaded_models = {}
|
16 |
+
|
17 |
+
def get_model_and_tokenizer(model_name):
|
18 |
+
if model_name in loaded_models:
|
19 |
+
return loaded_models[model_name]
|
20 |
+
else:
|
21 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, attn_implementation="flash_attention_2", use_auth_token=HF_TOKEN)
|
22 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=HF_TOKEN)
|
23 |
+
loaded_models[model_name] = (model, tokenizer)
|
24 |
+
return model, tokenizer
|
25 |
|
26 |
@spaces.GPU
|
27 |
def generate_text(
|
28 |
+
model_name,
|
29 |
input_text,
|
30 |
max_length=150,
|
31 |
temperature=0.7,
|
|
|
38 |
if not input_text.strip():
|
39 |
return ""
|
40 |
|
41 |
+
# 選択されたモデルとトークナイザーを取得
|
42 |
+
model, tokenizer = get_model_and_tokenizer(model_name)
|
43 |
+
|
44 |
# GPUが利用可能ならGPUへ移動。bf16がサポートされている場合はbf16を使用
|
45 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
46 |
if device == "cuda" and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
|
|
|
48 |
else:
|
49 |
model.to(device)
|
50 |
|
51 |
+
# 入力テキストのトークン化
|
52 |
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
|
53 |
input_token_count = input_ids.shape[1]
|
54 |
|
55 |
+
# 総トークン数の上限を入力トークン数 + max_length(max_lengthは追加するトークン数として扱う)
|
56 |
total_max_length = input_token_count + max_length
|
57 |
|
58 |
# テキスト生成
|
|
|
68 |
num_return_sequences=1
|
69 |
)
|
70 |
|
71 |
+
# 生成されたテキストをデコードし、入力部分を除いた生成分を抽出
|
72 |
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
73 |
new_text = generated_text[len(input_text):]
|
74 |
|
|
|
77 |
except Exception as e:
|
78 |
return f"{input_text}\n\nエラーが発生しました: {str(e)}"
|
79 |
|
80 |
+
# Gradioインターフェースの作成
|
81 |
with gr.Blocks() as demo:
|
82 |
gr.Markdown("# テキスト続き生成アシスタント")
|
83 |
+
gr.Markdown("モデルを選択し、テキストボックスに文章を入力してパラメータを調整後、「続きを生成」ボタンをクリックすると、選択したモデルがその続きを生成します。")
|
84 |
+
|
85 |
+
# モデル選択用プルダウンメニュー(候補は必要に応じて変更してください)
|
86 |
+
model_dropdown = gr.Dropdown(
|
87 |
+
choices=[
|
88 |
+
"Local-Novel-LLM-project/Vecteus-v1-abliterated",
|
89 |
+
"gpt2",
|
90 |
+
"EleutherAI/gpt-neo-125M"
|
91 |
+
],
|
92 |
+
label="モデルを選択してください",
|
93 |
+
value="Local-Novel-LLM-project/Vecteus-v1-abliterated"
|
94 |
+
)
|
95 |
|
96 |
+
# テキスト入力ボックス
|
97 |
input_text = gr.Textbox(label="テキストを入力してください", placeholder="ここにテキストを入力...", lines=10)
|
98 |
|
99 |
+
# 生成パラメータの設定UI(縦に一列で配置)
|
100 |
max_length_slider = gr.Slider(minimum=10, maximum=1000, value=100, step=10, label="追加するトークン数")
|
101 |
temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="創造性(温度)")
|
102 |
top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="top_k")
|
103 |
top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="top_p")
|
104 |
repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, value=1.0, step=0.1, label="繰り返しペナルティ")
|
105 |
|
106 |
+
# ボタン(縦に配置)
|
107 |
generate_btn = gr.Button("続きを生成", variant="primary")
|
108 |
clear_btn = gr.Button("クリア")
|
109 |
|
110 |
+
# イベントの設定:入力としてモデル選択とテキスト、パラメータを渡す
|
111 |
generate_btn.click(
|
112 |
fn=generate_text,
|
113 |
+
inputs=[model_dropdown, input_text, max_length_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
|
114 |
outputs=input_text
|
115 |
)
|
116 |
|
|
|
119 |
# 使い方の説明
|
120 |
gr.Markdown("""
|
121 |
## 使い方
|
122 |
+
1. 上部のプルダウンメニューから使用するモデルを選択します
|
123 |
+
2. テキストボックスに続きを生成したい文章を入力します
|
124 |
+
3. 生成パラメータ(追加するトークン数、創造性、top_k、top_p、繰り返しペナルティ)を調整します
|
125 |
+
4. 「続きを生成」ボタンをクリックすると、入力したテキストの続きが生成され、元のテキストに追加されます
|
126 |
+
5. 「クリア」ボタンを押すと、テキストボックスの内容がクリアされます
|
127 |
+
6. 満足のいく結果が得られるまで、繰り返し「続きを生成」ボタンを押して文章を発展させることができます
|
128 |
|
129 |
## ヒント
|
130 |
+
- 短い文章から始め、徐々に発展させると良い結果が得られます
|
131 |
- 創造性(温度)を高くすると予測不可能な生成結果に、低くすると安定した結果になります
|
132 |
+
- top_k や top_p、繰り返しペナルティも状況に応じて調整してみてください
|
133 |
""")
|
134 |
|
135 |
# アプリの起動
|
136 |
+
demo.launch()
|