|
import os
import re
import io
from datetime import datetime

import pandas as pd
import torch
from flask import Flask, request, jsonify, render_template, Response
from transformers import AutoTokenizer
from huggingface_hub import HfApi, HfFolder, hf_hub_download
from werkzeug.utils import secure_filename

from model import DeepFusionKcatPredictor
|
|
app = Flask(__name__)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
ESM_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
DATASET_REPO_ID = "KangjieXu/TransKP-usage-logs"
HF_TOKEN = os.getenv("HF_TOKEN")

# The HF_TOKEN environment variable enables authenticated Hub access; without it,
# usage logging to the dataset repo is skipped.
if HF_TOKEN:
    HfFolder.save_token(HF_TOKEN)
    api = HfApi()
    print("Hugging Face Hub API client initialized successfully.")
|
|
MODEL_REPO_ID = "KangjieXu/TransKP-model"
MODEL_FILENAME = "deep_fusion_kcat_pretrained.pt"
print(f"Downloading weights from the Hub: {MODEL_REPO_ID}/{MODEL_FILENAME}...")
try:
    # Reuse a previously downloaded copy of the weights if present; otherwise fetch it from the Hub.
    model_weights_path = MODEL_FILENAME
    if not os.path.exists(model_weights_path):
        model_weights_path = hf_hub_download(
            repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME,
            local_dir=".", local_dir_use_symlinks=False
        )
except Exception as e:
    raise RuntimeError(f"Failed to download model weights from the Hub: {e}")
|
|
model = DeepFusionKcatPredictor(
    esm_model_name=ESM_MODEL_NAME, gnn_input_dim=18, gnn_hidden_dim=256, gnn_heads=4,
    d_model=256, num_fusion_blocks=3, num_attn_heads=8, dim_feedforward=1024, dropout=0.1
).to(DEVICE)

print("Loading model weights...")
model.load_state_dict(torch.load(model_weights_path, map_location=DEVICE))
model.eval()

print("Loading the tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(ESM_MODEL_NAME)
print(f"Model and tokenizer loaded successfully; running on {DEVICE}.")
|
|
def clean_protein_sequence(sequence):
    """Remove FASTA header lines and all whitespace from a protein sequence."""
    sequence = re.sub(r'>.*\n', '', str(sequence))
    return "".join(sequence.split())
|
|
|
|
|
@app.route('/')
def home():
    return render_template('index.html')
|
|
@app.route('/predict', methods=['POST'])
def predict():
    if 'data_file' not in request.files:
        return jsonify({'error': 'No file part found in the request.'}), 400

    file = request.files['data_file']
    if file.filename == '':
        return jsonify({'error': 'No file selected.'}), 400

    filename = secure_filename(file.filename)
|
    try:
        if filename.endswith('.csv'):
            df_full = pd.read_csv(file)
        elif filename.endswith(('.xls', '.xlsx')):
            df_full = pd.read_excel(file)
        else:
            return jsonify({'error': 'Unsupported file type. Please upload a .csv or .xlsx file.'}), 400

        required_columns = {'protein_sequence', 'substrate_smiles'}
        if not required_columns.issubset(df_full.columns):
            return jsonify({'error': f'The file is missing required columns. Please make sure it contains: {list(required_columns)}'}), 400

        df = df_full[list(required_columns)].copy()
|
        protein_seqs = df['protein_sequence'].apply(clean_protein_sequence).tolist()
        smiles_list = df['substrate_smiles'].astype(str).tolist()

        inputs = tokenizer(protein_seqs, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(DEVICE)

        with torch.no_grad():
            log_kcat_preds = model(inputs['input_ids'], inputs['attention_mask'], smiles_list)

        # Flatten to a plain list of floats (the model may return predictions of shape (N, 1)).
        log_kcat_list = log_kcat_preds.cpu().numpy().ravel().tolist()
        kcat_list = [10 ** log_k for log_k in log_kcat_list]

        df_full['predicted_log10_kcat'] = [f"{v:.4f}" for v in log_kcat_list]
        df_full['predicted_kcat_s_neg1'] = [f"{v:.4f}" for v in kcat_list]
|
        consent_given = request.form.get('consent_given') == 'true'
        if HF_TOKEN and consent_given:
            print("User consented to data logging; preparing to upload the log...")
            try:
                log_data = {
                    'timestamp_utc': [datetime.utcnow().isoformat()] * len(protein_seqs),
                    'protein_sequence': protein_seqs,
                    'substrate_smiles': smiles_list,
                    'predicted_log10_kcat': log_kcat_list
                }
                log_df = pd.DataFrame(log_data)

                # Serialize the log to CSV in memory before uploading.
                log_buffer = io.StringIO()
                log_df.to_csv(log_buffer, index=False)
                log_bytes = log_buffer.getvalue().encode("utf-8")

                timestamp_str = datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S-%f')
                log_filename_in_repo = f"logs/log_{timestamp_str}.csv"

                api.upload_file(
                    path_or_fileobj=log_bytes,
                    path_in_repo=log_filename_in_repo,
                    repo_id=DATASET_REPO_ID,
                    repo_type="dataset",
                    commit_message=f"Log data from {timestamp_str}"
                )
                print(f"Log uploaded successfully to: {DATASET_REPO_ID}/{log_filename_in_repo}")

            except Exception as e:
                # Logging is best-effort: never let it break the prediction response.
                print(f"[WARNING] Data logging failed: {e}")
|
        # Return the annotated table as a downloadable CSV.
        output_buffer = io.BytesIO()
        output_filename = f"predictions_{os.path.splitext(filename)[0]}.csv"
        df_full.to_csv(output_buffer, index=False)
        output_buffer.seek(0)

        return Response(
            output_buffer.getvalue(),
            mimetype="text/csv",
            headers={"Content-Disposition": f"attachment; filename=\"{output_filename}\""}
        )

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': f'An internal server error occurred while processing the file: {str(e)}'}), 500
|
|
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
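
# Expected input for /predict: a .csv or .xlsx file containing at least the columns
# 'protein_sequence' and 'substrate_smiles'. A hypothetical example request
# (assumes the server is running locally on port 7860; file name and values are
# illustrative only):
#
#   curl -X POST http://localhost:7860/predict \
#        -F "data_file=@enzymes.csv" \
#        -F "consent_given=false" \
#        -o predictions.csv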