wadmjada's picture
Upload 43 files
cef6dfd verified
raw
history blame
8.9 kB
# ===================================================================
#
# 【最終成果物】app.py for Hugging Face Spaces
#
# ===================================================================
# --- ライブラリのインポート ---
import torch
import torch.nn as nn
import warnings
import os
import glob
from PIL import Image
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, CLIPProcessor, CLIPModel
from peft import PeftModel
from huggingface_hub import login
# --- 初期設定 ---
warnings.filterwarnings("ignore")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"✅ デバイスを検出: {device}")
# --- グローバル変数とキャッシュ ---
MODELS_CACHE = {}
CURRENT_MODEL_NAME = None
FINAL_MODEL = None
TOKENIZER = None
CLIP_MODEL = None
CLIP_PROCESSOR = None
# -------------------------------------------------------------------
# 1. モデル定義
# -------------------------------------------------------------------
class MultimodalModel(nn.Module):
def __init__(self, base_model, face_embedding_dim):
super().__init__()
self.base_model = base_model
self.face_injector = nn.Linear(face_embedding_dim, self.base_model.config.hidden_size)
# face_injectorはモデルロード時に正しいデバイスとdtypeに設定
self.face_injector.to(device=device, dtype=torch.bfloat16)
def forward(self, input_ids, attention_mask, face_embedding, **kwargs):
inputs_embeds = self.base_model.get_input_embeddings()(input_ids)
target_dtype = self.face_injector.weight.dtype
injected_face_features = self.face_injector(face_embedding.to(dtype=target_dtype))
conditioned_embeds = torch.cat([injected_face_features.unsqueeze(1), inputs_embeds.to(dtype=target_dtype)], dim=1)
new_attention_mask = torch.cat([torch.ones(attention_mask.shape[0], 1, device=attention_mask.device, dtype=attention_mask.dtype), attention_mask], dim=1)
return self.base_model.generate(inputs_embeds=conditioned_embeds, attention_mask=new_attention_mask, **kwargs)
# -------------------------------------------------------------------
# 2. モデルロード関連の関数
# -------------------------------------------------------------------
@torch.no_grad()
def load_base_models_and_tokenizer():
"""アプリケーション起動時に一度だけ、重いベースモデルをロードする"""
global TOKENIZER, CLIP_MODEL, CLIP_PROCESSOR
if TOKENIZER is not None: return
print("🔄 ベースモデルとCLIPモデルをロード中...")
# HFトークンを環境変数から読み込む
hf_token = os.getenv("HF_TOKEN")
if hf_token:
login(token=hf_token)
print("✅ Hugging Faceにトークンで正常にログインしました。")
TOKENIZER = AutoTokenizer.from_pretrained("rinna/nekomata-7b-instruction", trust_remote_code=True)
if TOKENIZER.pad_token is None: TOKENIZER.pad_token = TOKENIZER.eos_token
CLIP_MODEL = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
CLIP_PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
print("✅ ベースモデル(Tokenizer, CLIP)のロード完了。")
@torch.no_grad()
def load_adapter_model(person_name):
"""ドロップダウンで選択されたAI人格(アダプター)をロードする"""
global FINAL_MODEL, CURRENT_MODEL_NAME
if person_name == CURRENT_MODEL_NAME:
print(f"✅ AI人格「{person_name}」は既にロード済みです。")
return FINAL_MODEL
print(f"🔄 AI人格「{person_name}」をロード中...")
if person_name in MODELS_CACHE:
FINAL_MODEL = MODELS_CACHE[person_name]
CURRENT_MODEL_NAME = person_name
print(f"✅ キャッシュからAI人格「{person_name}」をロードしました。")
return FINAL_MODEL
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)
base_llm = AutoModelForCausalLM.from_pretrained("rinna/nekomata-7b-instruction", quantization_config=quantization_config, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
# アダプター(学習パッチ)のパス
adapter_path = f"./final_model_{person_name.replace(' ', '_')}"
peft_model = PeftModel.from_pretrained(base_llm, adapter_path)
face_emb_dim = 768
model = MultimodalModel(peft_model, face_emb_dim)
injector_path = os.path.join(adapter_path, "face_injector.pth")
if os.path.exists(injector_path):
model.face_injector.load_state_dict(torch.load(injector_path, map_location=device))
print("✅ face_injectorの重みをロードしました。")
model.eval()
FINAL_MODEL = model
MODELS_CACHE[person_name] = model
CURRENT_MODEL_NAME = person_name
print(f"✅ AI人格「{person_name}」のロード完了。")
return FINAL_MODEL
# -------------------------------------------------------------------
# 3. Gradioのメインロジック
# -------------------------------------------------------------------
def predict(person_name, image, instruction, max_len, temp, top_p):
if not person_name: raise gr.Error("まず「AI人格」を選択してください。")
if image is None: raise gr.Error("画像をアップロードしてください。")
if not instruction.strip(): raise gr.Error("指示(プロンプト)を入力してください。")
model = load_adapter_model(person_name)
pil_image = Image.fromarray(image)
with torch.no_grad():
inputs = CLIP_PROCESSOR(images=pil_image, return_tensors="pt").to(device)
face_embedding = CLIP_MODEL.get_image_features(**inputs)
prompt = f"USER: {instruction}\nASSISTANT: "
prompt_inputs = TOKENIZER(prompt, return_tensors="pt").to(device)
output_ids = model.forward(
input_ids=prompt_inputs.input_ids,
attention_mask=prompt_inputs.attention_mask,
face_embedding=face_embedding,
max_new_tokens=int(max_len),
do_sample=True, temperature=float(temp), top_p=float(top_p),
pad_token_id=TOKENIZER.pad_token_id,
eos_token_id=TOKENIZER.eos_token_id,
bad_words_ids=[[TOKENIZER.unk_token_id]] if TOKENIZER.unk_token_id is not None else None
)
full_text = TOKENIZER.decode(output_ids[0], skip_special_tokens=True)
assistant_response = full_text.split("ASSISTANT: ")[-1].strip() if "ASSISTANT: " in full_text else full_text
return assistant_response
# -------------------------------------------------------------------
# 4. Gradio UIの構築と起動
# -------------------------------------------------------------------
load_base_models_and_tokenizer()
available_models = [os.path.basename(p).replace("final_model_", "").replace("_", " ") for p in glob.glob("./final_model_*")]
with gr.Blocks(theme=gr.themes.Soft(), title="顔から文章を生成するAI") as demo:
gr.Markdown("<h1>🎨 顔から文章を生成するAI</h1>")
gr.Markdown("①使いたいAIの「人格」を選択し、②顔写真をアップロード、③指示を与えると、AIがその人らしい文章を生成します。")
with gr.Row():
with gr.Column(scale=1):
model_selector = gr.Dropdown(choices=available_models, label="🤖 STEP1: AI人格を選択")
image_input = gr.Image(type="numpy", label="📷 STEP2: 顔写真を入力")
instruction_input = gr.Textbox(lines=3, label="✍️ STEP3: 指示プロンプトを入力", value="あなたらしく、自己紹介をしてください。")
with gr.Accordion("詳細設定", open=False):
max_len_slider = gr.Slider(minimum=30, maximum=300, value=150, step=10, label="最大文字数")
temp_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.9, step=0.05, label="多様性 (Temperature)")
top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="単語の絞り込み (Top-p)")
submit_btn = gr.Button("文章を生成する", variant="primary")
with gr.Column(scale=2):
text_output = gr.Textbox(lines=17, label="生成された文章", interactive=False)
submit_btn.click(fn=predict, inputs=[model_selector, image_input, instruction_input, max_len_slider, temp_slider, top_p_slider], outputs=text_output, api_name="predict")
# .queue()でリクエストを順番に処理し、タイムアウトを防ぎます
demo.queue().launch()