wadmjada commited on
Commit
a263833
·
verified ·
1 Parent(s): 33640cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -102
app.py CHANGED
@@ -1,6 +1,6 @@
1
  # ===================================================================
2
  #
3
- #   【最終修正版v2】AIペルソナ選択式 Gradioアプリ (環境自動判定対応)
4
  #
5
  # ===================================================================
6
 
@@ -26,40 +26,32 @@ warnings.filterwarnings("ignore")
26
 
27
  # --- グローバル変数 ---
28
  MODELS_CACHE = {}
29
- CURRENT_MODEL_NAME = None
30
- FINAL_MODEL = None
31
  TOKENIZER = None
32
  CLIP_MODEL = None
33
  CLIP_PROCESSOR = None
34
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
35
 
36
  # ★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
37
- # ★★★ エラー修正箇所 ★★★
38
- # ★★★ 実行環境を自動で判定し、モデルへのパスを正しく設定します ★★★
39
  # ★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
40
  if 'SPACE_ID' in os.environ:
41
  print("✅ Hugging Face Spaces環境で実行中")
42
- # Hugging Face Spacesでは、アップロードされたファイルはルートにあります
43
- DRIVE_BASE_PATH = '.'
44
- IS_COLAB = False
45
  else:
46
  try:
47
  from google.colab import drive
48
  print("\n🔗 Googleドライブをマウントします...")
49
  drive.mount('/content/drive', force_remount=True)
50
- DRIVE_BASE_PATH = '/content/drive/MyDrive'
51
- IS_COLAB = True
52
  print("✅ Google Colab環境で実行中")
53
  except ImportError:
54
  print("⚠️ ローカル環境として実行します。'./models'フォルダにモデルを配置してください。")
55
- # ローカルPCで実行する場合、このコードと同じ階層に'models'フォルダを作成し、
56
- # その中に`final_model_`で始まるフォルダ群を配置してください。
57
- DRIVE_BASE_PATH = './models'
58
- IS_COLAB = False
59
 
60
  print(f"🖥️ 使用デバイス: {DEVICE}")
61
- print(f"📂 モデルデータの検索パス: {DRIVE_BASE_PATH}")
62
-
63
 
64
  # --- モデル定義 ---
65
  class MultimodalModel(nn.Module):
@@ -98,78 +90,63 @@ class MultimodalModel(nn.Module):
98
 
99
  # --- モデルロード関連の関数 ---
100
  @torch.no_grad()
101
- def get_available_models():
102
- """利用可能なモデルのリストを取得"""
 
103
  try:
104
- if not os.path.exists(DRIVE_BASE_PATH):
105
- logger.warning(f"モデルディレクトリが存在しません: {DRIVE_BASE_PATH}")
106
- return []
107
-
108
- model_paths = glob.glob(os.path.join(DRIVE_BASE_PATH, "final_model_*"))
109
- if not model_paths:
110
- logger.warning(f"モデルファイルが見つかりません: {DRIVE_BASE_PATH}に final_model_* がありません。")
111
- return []
112
-
113
- model_names = sorted([os.path.basename(p).replace("final_model_", "").replace("_", " ") for p in model_paths])
114
- return model_names
115
- except Exception as e:
116
- logger.error(f"モデル検索エラー: {e}")
117
- return []
118
-
119
- @torch.no_grad()
120
- def load_model_and_dependencies(person_name):
121
- """全ての依存モデルと指定された人物モデルをロード"""
122
- global FINAL_MODEL, CURRENT_MODEL_NAME, TOKENIZER, CLIP_MODEL, CLIP_PROCESSOR
123
 
124
- try:
125
- # --- 共通モデル(Tokenizer, CLIP)の初期化 ---
126
  if TOKENIZER is None:
127
  logger.info("📝 Tokenizerをロード中...")
128
  TOKENIZER = AutoTokenizer.from_pretrained("rinna/nekomata-7b-instruction", trust_remote_code=True)
129
- if TOKENIZER.pad_token is None:
130
- TOKENIZER.pad_token = TOKENIZER.eos_token
131
  logger.info("✅ Tokenizerのロード完了")
132
 
133
- if CLIP_MODEL is None or CLIP_PROCESSOR is None:
134
  logger.info("📷 CLIP画像エンコーダーをロード中...")
135
  CLIP_MODEL = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(DEVICE)
136
  CLIP_PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
137
  logger.info("✅ CLIPモデルのロード完了")
 
 
 
 
 
138
 
139
- # --- ペルソナモデルのロード ---
140
- if person_name == CURRENT_MODEL_NAME and FINAL_MODEL is not None:
141
- logger.info(f"✅ モデル「{person_name}」は既にロード済みです。")
 
 
 
 
 
142
  return True
143
 
144
- logger.info(f"🔄 モデル「{person_name}」をロード中...")
145
-
146
  if person_name in MODELS_CACHE:
147
  FINAL_MODEL = MODELS_CACHE[person_name]
148
- CURRENT_MODEL_NAME = person_name
149
- logger.info(f"✅ キャッシュからモデル「{person_name}」をロードしました。")
150
  return True
151
 
152
- quantization_config = BitsAndBytesConfig(
153
- load_in_4bit=True,
154
- bnb_4bit_quant_type="nf4",
155
- bnb_4bit_compute_dtype=torch.bfloat16,
156
- bnb_4bit_use_double_quant=True
157
- )
158
-
159
- base_model = AutoModelForCausalLM.from_pretrained(
160
- "rinna/nekomata-7b-instruction",
161
- quantization_config=quantization_config,
162
- torch_dtype=torch.bfloat16,
163
- trust_remote_code=True,
164
- device_map="auto"
165
- )
166
-
167
- adapter_path = os.path.join(DRIVE_BASE_PATH, f"final_model_{person_name.replace(' ', '_')}")
168
  if not os.path.exists(adapter_path):
169
  logger.error(f"モデルパスが存在しません: {adapter_path}")
170
  return False
171
 
172
- peft_model = PeftModel.from_pretrained(base_model, adapter_path)
 
173
 
174
  face_emb_dim = 768
175
  model = MultimodalModel(peft_model, face_emb_dim)
@@ -177,7 +154,7 @@ def load_model_and_dependencies(person_name):
177
  injector_path = os.path.join(adapter_path, "face_injector.pth")
178
  if os.path.exists(injector_path):
179
  model.face_injector.load_state_dict(torch.load(injector_path, map_location=DEVICE))
180
- logger.info("✅ face_injectorの重みをロードしました。")
181
  else:
182
  logger.warning(f"⚠️ face_injectorの重みファイルが見つかりません: {injector_path}")
183
  return False
@@ -185,12 +162,11 @@ def load_model_and_dependencies(person_name):
185
  model.eval()
186
  FINAL_MODEL = model
187
  MODELS_CACHE[person_name] = model
188
- CURRENT_MODEL_NAME = person_name
189
- logger.info(f"✅ モデル「{person_name}」のロード完了。")
190
  return True
191
-
192
  except Exception as e:
193
- logger.error(f"モデルロードおよび依存関係の初期化エラー: {e}")
194
  traceback.print_exc()
195
  return False
196
 
@@ -198,18 +174,14 @@ def load_model_and_dependencies(person_name):
198
  def predict(person_name, image, instruction, max_len, temp, top_p, progress=gr.Progress()):
199
  """メイン予測関数"""
200
  try:
201
- progress(0, desc="🔄 AI人格モデルの準備中...")
202
- if person_name is None:
203
- return "❌ エラー: まず「AI人格」を選択してください。"
204
- if image is None:
205
- return "❌ エラー: 画像をアップロードしてください。"
206
- if not instruction.strip():
207
- return "❌ エラー: 指示(プロンプト)を入力してください。"
208
 
209
- if not load_model_and_dependencies(person_name):
210
- return f"❌ エラー: モデル「{person_name}」のロードに失敗しました。"
211
-
212
- progress(0.3, desc="🖼️ 画像を解析中...")
213
  pil_image = image.convert("RGB")
214
 
215
  with torch.no_grad():
@@ -240,7 +212,7 @@ def predict(person_name, image, instruction, max_len, temp, top_p, progress=gr.P
240
  return assistant_response
241
 
242
  except Exception as e:
243
- error_msg = f"❌ 予測エラー: {str(e)}"
244
  logger.error(error_msg)
245
  traceback.print_exc()
246
  return error_msg
@@ -254,15 +226,11 @@ def create_gradio_interface():
254
  if not available_models:
255
  with gr.Blocks(title="エラー") as demo:
256
  gr.Markdown("# ❌ 起動エラー: モデルが見つかりません")
257
- gr.Markdown(f"AI人格のモデルデータ(`final_model_`で始まるフォルダ)が見つかりませんでした。\n\n**プログラムが探した場所:** `{os.path.abspath(DRIVE_BASE_PATH)}`\n\nこの場所にモデルデータを正しく配置してください。")
258
  return demo
259
 
260
- custom_css = """
261
- .gradio-container { max-width: 1200px !important; margin: auto; }
262
- .gr-button { background: linear-gradient(45deg, #667eea, #764ba2) !important; border: none !important; color: white !important; font-weight: bold !important; border-radius: 10px !important; padding: 15px 30px !important; }
263
- .header-text { text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 15px; margin-bottom: 20px; }
264
- """
265
-
266
  with gr.Blocks(css=custom_css, title="AIペルソナ文章生成") as demo:
267
  gr.HTML("""
268
  <div class="header-text">
@@ -289,7 +257,7 @@ def create_gradio_interface():
289
  gr.Markdown("### 📝 STEP 2: 生成結果")
290
  text_output = gr.Textbox(label="生成された文章", lines=20, interactive=False, placeholder="ここに生成された文章が表示されます...", show_copy_button=True)
291
 
292
- gr.Markdown("--- \n ### 💡 ヒント\n- **画像**: 顔がはっきり写った写真を使用してください\n- **指示**: 具体的で明確な指示を与えると良い結果が得られます\n- **設定**: 創造性を上げると面白い文章、下げると安定した文章が生成されます")
293
 
294
  submit_btn.click(
295
  fn=predict,
@@ -304,15 +272,14 @@ def create_gradio_interface():
304
  traceback.print_exc()
305
  return None
306
 
307
- # --- メイン実行 ---
308
- def main():
309
- logger.info("🌟 アプリケーション開始")
310
-
311
- demo = create_gradio_interface()
312
- if demo:
313
- logger.info("🌐 Gradioアプリケーションを起動します...")
314
- # Hugging Face Spacesではshare=Trueは不要で、自動的に公開されます
315
- demo.launch(debug=False)
316
-
317
  if __name__ == "__main__":
318
- main()
 
 
 
 
 
 
 
 
 
1
  # ===================================================================
2
  #
3
+ #   【最終版v3】AIペルソナ選択式 Gradioアプリ (メモリ効率化・安定版)
4
  #
5
  # ===================================================================
6
 
 
26
 
27
  # --- グローバル変数 ---
28
  MODELS_CACHE = {}
29
+ CURRENT_PERSONA = None
30
+ BASE_MODEL = None # ベースモデルをグローバルに保持
31
  TOKENIZER = None
32
  CLIP_MODEL = None
33
  CLIP_PROCESSOR = None
34
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
35
 
36
  # ★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
37
+ # ★★★ 環境判定とパス設定 ★★★
 
38
  # ★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
39
  if 'SPACE_ID' in os.environ:
40
  print("✅ Hugging Face Spaces環境で実行中")
41
+ MODEL_BASE_PATH = '.' # スペースのルートディレクトリ
 
 
42
  else:
43
  try:
44
  from google.colab import drive
45
  print("\n🔗 Googleドライブをマウントします...")
46
  drive.mount('/content/drive', force_remount=True)
47
+ MODEL_BASE_PATH = '/content/drive/MyDrive'
 
48
  print("✅ Google Colab環境で実行中")
49
  except ImportError:
50
  print("⚠️ ローカル環境として実行します。'./models'フォルダにモデルを配置してください。")
51
+ MODEL_BASE_PATH = './models'
 
 
 
52
 
53
  print(f"🖥️ 使用デバイス: {DEVICE}")
54
+ print(f"📂 モデルデータの検索パス: {MODEL_BASE_PATH}")
 
55
 
56
  # --- モデル定義 ---
57
  class MultimodalModel(nn.Module):
 
90
 
91
  # --- モデルロード関連の関数 ---
92
  @torch.no_grad()
93
+ def initialize_core_models():
94
+ """アプリ起動時に一度だけ、重いモデルをロードする"""
95
+ global BASE_MODEL, TOKENIZER, CLIP_MODEL, CLIP_PROCESSOR
96
  try:
97
+ if BASE_MODEL is None:
98
+ logger.info("🔄 [初回起動] ベースLLMをロード中... (時間がかかります)")
99
+ quantization_config = BitsAndBytesConfig(
100
+ load_in_4bit=True, bnb_4bit_quant_type="nf4",
101
+ bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True
102
+ )
103
+ BASE_MODEL = AutoModelForCausalLM.from_pretrained(
104
+ "rinna/nekomata-7b-instruction",
105
+ quantization_config=quantization_config, torch_dtype=torch.bfloat16,
106
+ trust_remote_code=True, device_map="auto"
107
+ )
108
+ logger.info("✅ ベースLLMのロード完了")
 
 
 
 
 
 
 
109
 
 
 
110
  if TOKENIZER is None:
111
  logger.info("📝 Tokenizerをロード中...")
112
  TOKENIZER = AutoTokenizer.from_pretrained("rinna/nekomata-7b-instruction", trust_remote_code=True)
113
+ TOKENIZER.pad_token = TOKENIZER.eos_token
 
114
  logger.info("✅ Tokenizerのロード完了")
115
 
116
+ if CLIP_MODEL is None:
117
  logger.info("📷 CLIP画像エンコーダーをロード中...")
118
  CLIP_MODEL = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(DEVICE)
119
  CLIP_PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
120
  logger.info("✅ CLIPモデルのロード完了")
121
+ return True
122
+ except Exception as e:
123
+ logger.error(f"コアモデルの初期化エラー: {e}")
124
+ traceback.print_exc()
125
+ return False
126
 
127
+ @torch.no_grad()
128
+ def load_persona_model(person_name, progress=gr.Progress()):
129
+ """指定されたAI人格(アダプター)をロードする"""
130
+ global FINAL_MODEL, CURRENT_PERSONA
131
+ try:
132
+ progress(0.1, desc=f"🔄 AI人格「{person_name}」を準備中...")
133
+ if person_name == CURRENT_PERSONA and FINAL_MODEL is not None:
134
+ logger.info(f"✅ AI人格「{person_name}」は既に準備完了です。")
135
  return True
136
 
 
 
137
  if person_name in MODELS_CACHE:
138
  FINAL_MODEL = MODELS_CACHE[person_name]
139
+ CURRENT_PERSONA = person_name
140
+ logger.info(f"✅ キャッシュからAI人格「{person_name}」をロードしました。")
141
  return True
142
 
143
+ adapter_path = os.path.join(MODEL_BASE_PATH, f"final_model_{person_name.replace(' ', '_')}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  if not os.path.exists(adapter_path):
145
  logger.error(f"モデルパスが存在しません: {adapter_path}")
146
  return False
147
 
148
+ logger.info(f"🧠 アダプターをロード中: {adapter_path}")
149
+ peft_model = PeftModel.from_pretrained(BASE_MODEL, adapter_path)
150
 
151
  face_emb_dim = 768
152
  model = MultimodalModel(peft_model, face_emb_dim)
 
154
  injector_path = os.path.join(adapter_path, "face_injector.pth")
155
  if os.path.exists(injector_path):
156
  model.face_injector.load_state_dict(torch.load(injector_path, map_location=DEVICE))
157
+ logger.info("✅ 顔特徴インジェクターの重みをロードしました。")
158
  else:
159
  logger.warning(f"⚠️ face_injectorの重みファイルが見つかりません: {injector_path}")
160
  return False
 
162
  model.eval()
163
  FINAL_MODEL = model
164
  MODELS_CACHE[person_name] = model
165
+ CURRENT_PERSONA = person_name
166
+ logger.info(f"✅ AI人格「{person_name}」の準備完了。")
167
  return True
 
168
  except Exception as e:
169
+ logger.error(f"AI人格のロードエラー: {e}")
170
  traceback.print_exc()
171
  return False
172
 
 
174
  def predict(person_name, image, instruction, max_len, temp, top_p, progress=gr.Progress()):
175
  """メイン予測関数"""
176
  try:
177
+ if person_name is None: return " エラー: まず「AI人格」を選択してください。"
178
+ if image is None: return "❌ エラー: 画像をアップロードしてください。"
179
+ if not instruction.strip(): return "❌ エラー: 指示(プロンプト)を入力してください。"
 
 
 
 
180
 
181
+ if not load_persona_model(person_name, progress):
182
+ return f"❌ エラー: AI人格「{person_name}」のロードに失敗しました。"
183
+
184
+ progress(0.4, desc="🖼️ 顔の雰囲気を分析中...")
185
  pil_image = image.convert("RGB")
186
 
187
  with torch.no_grad():
 
212
  return assistant_response
213
 
214
  except Exception as e:
215
+ error_msg = f"❌ 予測中にエラーが発生しました: {str(e)}"
216
  logger.error(error_msg)
217
  traceback.print_exc()
218
  return error_msg
 
226
  if not available_models:
227
  with gr.Blocks(title="エラー") as demo:
228
  gr.Markdown("# ❌ 起動エラー: モデルが見つかりません")
229
+ gr.Markdown(f"AI人格のモデルデータ(`final_model_`で始まるフォルダ)が見つかりませんでした。\n\n**プログラムが探した場所:** `{os.path.abspath(MODEL_BASE_PATH)}`\n\nこの場所にモデルデータを正しく配置してください。")
230
  return demo
231
 
232
+ custom_css = "..." # (CSSは省略)
233
+
 
 
 
 
234
  with gr.Blocks(css=custom_css, title="AIペルソナ文章生成") as demo:
235
  gr.HTML("""
236
  <div class="header-text">
 
257
  gr.Markdown("### 📝 STEP 2: 生成結果")
258
  text_output = gr.Textbox(label="生成された文章", lines=20, interactive=False, placeholder="ここに生成された文章が表示されます...", show_copy_button=True)
259
 
260
+ demo.load(lambda: " アプリ準備完了!AI人格を選択して開始してください。", [], text_output)
261
 
262
  submit_btn.click(
263
  fn=predict,
 
272
  traceback.print_exc()
273
  return None
274
 
275
+ # --- メイン実行ブロック ---
 
 
 
 
 
 
 
 
 
276
  if __name__ == "__main__":
277
+ logger.info("🌟 アプリケーション起動プロセス開始")
278
+
279
+ if initialize_core_models():
280
+ demo = create_gradio_interface()
281
+ if demo:
282
+ logger.info("🌐 Gradioアプリケーションを起動します...")
283
+ demo.launch(debug=False) # share=TrueはHugging Face Spacesでは不要
284
+ else:
285
+ logger.error("❌ コアモデルの初期化に失敗したため、アプリを起動できません。")