# TransKP / app.py
# Last updated by KangjieXu — commit 90be0b9 ("Update app.py", verified)
import io
import os
import re
import traceback
from datetime import datetime, timezone

import pandas as pd
import torch
from flask import Flask, request, jsonify, render_template, Response
from huggingface_hub import HfApi, HfFolder
from transformers import AutoTokenizer
from werkzeug.utils import secure_filename

from model import DeepFusionKcatPredictor
# --- 1. Initialisation and configuration ---
app = Flask(__name__)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
ESM_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
DATASET_REPO_ID = "KangjieXu/TransKP-usage-logs"  # usage logs are uploaded to this dataset repo
HF_TOKEN = os.getenv("HF_TOKEN")

# Hugging Face Hub API client used to upload usage logs. It is bound
# unconditionally (with the token passed explicitly) so the module-level name
# `api` always exists. Callers only use it when HF_TOKEN is set.
api = HfApi(token=HF_TOKEN)
if HF_TOKEN:
    # Persist the token, presumably so later Hub downloads in this process are
    # authenticated as well.
    HfFolder.save_token(HF_TOKEN)
    print("Hugging Face Hub API客户端初始化成功。")
# --- 2. Model loading ---
MODEL_REPO_ID = "KangjieXu/TransKP-model"
MODEL_FILENAME = "deep_fusion_kcat_pretrained.pt"
print(f"正在从Hub下载权重: {MODEL_REPO_ID}/{MODEL_FILENAME}...")
try:
    # Use a pre-placed local copy at "<repo_id>/<filename>" if one exists.
    model_weights_path = os.path.join(MODEL_REPO_ID, MODEL_FILENAME)
    if not os.path.exists(model_weights_path):
        from huggingface_hub import hf_hub_download

        # Download into the working directory. Use the path the library
        # returns rather than assuming where the file was written.
        model_weights_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=MODEL_FILENAME,
            local_dir=".",
            local_dir_use_symlinks=False,
        )
except Exception as e:
    # Chain the original error so the real cause stays in the traceback.
    raise RuntimeError(f"从Hub下载模型权重失败: {e}") from e

# Architecture hyper-parameters. These presumably must match the ones the
# checkpoint was trained with, otherwise load_state_dict below will reject it.
model = DeepFusionKcatPredictor(
    esm_model_name=ESM_MODEL_NAME, gnn_input_dim=18, gnn_hidden_dim=256, gnn_heads=4,
    d_model=256, num_fusion_blocks=3, num_attn_heads=8, dim_feedforward=1024, dropout=0.1,
).to(DEVICE)
print("正在加载模型权重...")
# NOTE(review): torch.load unpickles the checkpoint, so it is only as trustworthy
# as the Hub repo above. Consider weights_only=True if the installed torch
# version supports it.
model.load_state_dict(torch.load(model_weights_path, map_location=DEVICE))
model.eval()  # inference mode (e.g. dropout disabled)
print("正在加载分词器...")
tokenizer = AutoTokenizer.from_pretrained(ESM_MODEL_NAME)
print(f"模型和分词器加载成功,运行在 {DEVICE} 设备上。")
# --- 3. Helper functions ---
def clean_protein_sequence(sequence):
    """Normalise a user-supplied protein sequence before tokenisation.

    Removes FASTA header text (everything from a ``>`` to the end of its line)
    and all whitespace, so multi-line FASTA input becomes one contiguous string.

    Args:
        sequence: Raw cell value; non-strings are converted with ``str``.

    Returns:
        The cleaned sequence string.
    """
    text = str(sequence)
    # `.` does not match "\n", so `>.*` stops at the line break. This also
    # removes a header that is the last line and has no trailing newline. The
    # line break left behind is dropped by the whitespace removal below.
    text = re.sub(r'>.*', '', text)
    return "".join(text.split())
# --- 4. Flask routes ---
@app.route('/')
def home():
    """Render and return the ``index.html`` template for the site root."""
    page_html = render_template('index.html')
    return page_html
# Columns the uploaded table must contain (listed in this order in error messages).
_REQUIRED_COLUMNS = ('protein_sequence', 'substrate_smiles')


def _read_uploaded_table(file_storage):
    """Parse an uploaded CSV/Excel file into a DataFrame.

    The parser is chosen from the extension of the *original* client filename,
    compared case-insensitively. The ``secure_filename`` output is not used for
    this because it strips non-ASCII characters and can lose the extension
    (e.g. ``数据.csv`` becomes ``csv``).

    Returns:
        The parsed DataFrame, or ``None`` if the extension is not supported.
    """
    extension = os.path.splitext(file_storage.filename)[1].lower()
    if extension == '.csv':
        return pd.read_csv(file_storage)
    if extension in ('.xls', '.xlsx'):
        return pd.read_excel(file_storage)
    return None


def _predict_log10_kcat(protein_seqs, smiles_list):
    """Run the model on paired protein sequences and substrate SMILES.

    Returns:
        A flat list of Python floats, one predicted log10(kcat) per input row.
    """
    inputs = tokenizer(
        protein_seqs, return_tensors="pt", padding=True, truncation=True, max_length=1024
    ).to(DEVICE)
    with torch.no_grad():
        log_kcat_preds = model(inputs['input_ids'], inputs['attention_mask'], smiles_list)
    # Flatten once so (N,) and (N, 1) outputs both give N scalars. Every
    # consumer below then sees the same 1-D list of floats.
    return log_kcat_preds.cpu().reshape(-1).tolist()


def _upload_usage_log(protein_seqs, smiles_list, log_kcat_list):
    """Best-effort upload of inputs and predictions to the usage-log dataset repo.

    Any failure is printed and swallowed, so logging never breaks a prediction.
    """
    print("用户同意记录数据,正在准备上传日志...")
    try:
        # One timestamp for both the column and the file name keeps them consistent.
        now = datetime.now(timezone.utc)
        log_df = pd.DataFrame({
            'timestamp_utc': [now.isoformat()] * len(protein_seqs),
            'protein_sequence': protein_seqs,
            'substrate_smiles': smiles_list,
            'predicted_log10_kcat': log_kcat_list,
        })
        log_bytes = log_df.to_csv(index=False).encode("utf-8")
        timestamp_str = now.strftime('%Y-%m-%d_%H-%M-%S-%f')
        log_filename_in_repo = f"logs/log_{timestamp_str}.csv"
        api.upload_file(
            path_or_fileobj=log_bytes,
            path_in_repo=log_filename_in_repo,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            commit_message=f"Log data from {timestamp_str}",
        )
        print(f"成功将日志上传到: {DATASET_REPO_ID}/{log_filename_in_repo}")
    except Exception as e:
        # Deliberately best-effort: a logging failure must not affect the response.
        print(f"【警告】数据日志记录失败: {e}")


def _csv_download_response(df, download_name):
    """Serialise ``df`` to CSV and return it as a file-download response."""
    return Response(
        df.to_csv(index=False),
        mimetype="text/csv",
        headers={"Content-Disposition": f"attachment; filename=\"{download_name}\""},
    )


@app.route('/predict', methods=['POST'])
def predict():
    """Predict kcat for every row of an uploaded CSV/Excel table.

    Expects a multipart upload in form field ``data_file`` with at least the
    columns ``protein_sequence`` and ``substrate_smiles``. If form field
    ``consent_given`` equals ``'true'`` and ``HF_TOKEN`` is set, the inputs and
    predictions are also uploaded to ``DATASET_REPO_ID``.

    Returns:
        The uploaded table as a CSV attachment, with two added columns
        (``predicted_log10_kcat`` and ``predicted_kcat_s_neg1``, formatted to
        4 decimals). On error, a JSON ``{'error': ...}`` body is returned
        instead: status 400 for client errors, 500 for unexpected failures.
    """
    if 'data_file' not in request.files:
        return jsonify({'error': '请求中未找到文件部分。'}), 400
    file = request.files['data_file']
    if not file.filename:
        return jsonify({'error': '未选择文件。'}), 400
    # Sanitised stem, used only to name the returned file. Falls back to a
    # fixed name when sanitising removes everything (e.g. all-non-ASCII names).
    output_stem = secure_filename(os.path.splitext(file.filename)[0]) or 'upload'
    try:
        df_full = _read_uploaded_table(file)
        if df_full is None:
            return jsonify({'error': '不支持的文件类型。请上传 .csv 或 .xlsx 文件。'}), 400
        if not set(_REQUIRED_COLUMNS).issubset(df_full.columns):
            return jsonify({'error': f'文件缺少必需的列。请确保文件包含以下列: {list(_REQUIRED_COLUMNS)}'}), 400
        if df_full.empty:
            return jsonify({'error': '文件中没有数据行。'}), 400

        protein_seqs = df_full['protein_sequence'].apply(clean_protein_sequence).tolist()
        smiles_list = df_full['substrate_smiles'].astype(str).tolist()
        log_kcat_list = _predict_log10_kcat(protein_seqs, smiles_list)
        df_full['predicted_log10_kcat'] = [f"{v:.4f}" for v in log_kcat_list]
        df_full['predicted_kcat_s_neg1'] = [f"{10 ** v:.4f}" for v in log_kcat_list]

        if HF_TOKEN and request.form.get('consent_given') == 'true':
            _upload_usage_log(protein_seqs, smiles_list, log_kcat_list)

        return _csv_download_response(df_full, f"predictions_{output_stem}.csv")
    except Exception as e:
        traceback.print_exc()
        return jsonify({'error': f'服务器处理文件时发生内部错误: {str(e)}'}), 500
# --- 5. Start the application ---
if __name__ == '__main__':
    # Listen on all interfaces. NOTE(review): port 7860 is presumably the port
    # the hosting platform expects — confirm before changing it.
    app.run(port=7860, host='0.0.0.0')