Upload 13 files
- antioxidant_predictor_5.py +100 -0
- app.py +364 -0
- checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth +3 -0
- checkpoints/scaler_FINETUNED_PROTT5.pkl +3 -0
- feature_extract.py +341 -0
- generator_checkpoints_v3.6/final_generator_model.pth +3 -0
- prott5/model/config.json +25 -0
- prott5/model/finetuned_prott5.bin +3 -0
- prott5/model/gitattributes +16 -0
- prott5/model/special_tokens_map.json +1 -0
- prott5/model/spiece.model +3 -0
- prott5/model/tokenizer_config.json +1 -0
- requirements.txt +172 -0
antioxidant_predictor_5.py
ADDED
@@ -0,0 +1,100 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

class AntioxidantPredictor(nn.Module):
    def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
        super(AntioxidantPredictor, self).__init__()
        self.prott5_dim = 1024
        self.handcrafted_dim = input_dim - self.prott5_dim
        self.seq_len = 16
        self.prott5_feature_dim = 64  # 16 * 64 = 1024

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.prott5_feature_dim,
            nhead=transformer_heads,
            dropout=transformer_dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)

        fused_dim = self.prott5_feature_dim + self.handcrafted_dim
        self.fusion_fc = nn.Sequential(
            nn.Linear(fused_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

        # Temperature-scaling parameter T.
        # Initialized to 1.0 so the logits are unchanged before calibration.
        # requires_grad=False because T is normally optimized separately, after the model is trained.
        self.temperature = nn.Parameter(torch.ones(1), requires_grad=False)

    def forward(self, x, *args):
        batch_size = x.size(0)
        prot_t5_features = x[:, :self.prott5_dim]
        handcrafted_features = x[:, self.prott5_dim:]

        prot_t5_seq = prot_t5_features.view(batch_size, self.seq_len, self.prott5_feature_dim)
        encoded_seq = self.transformer_encoder(prot_t5_seq)
        refined_prott5 = encoded_seq.mean(dim=1)

        fused_features = torch.cat([refined_prott5, handcrafted_features], dim=1)
        fused_features = self.fusion_fc(fused_features)

        logits = self.classifier(fused_features)

        # Apply temperature scaling: logits / T.
        # Scaling happens on the raw logits, before the external sigmoid.
        # If T has been optimized this yields calibrated logits; if T is still 1, it is a no-op.
        logits_scaled = logits / self.temperature

        return logits_scaled  # calibrated logits (or raw logits if T=1)

    def set_temperature(self, temp_value, device):
        """Set the optimized temperature value."""
        self.temperature = nn.Parameter(torch.tensor([temp_value], device=device), requires_grad=False)
        print(f"Model temperature T set to: {self.temperature.item()}")

    def get_temperature(self):
        """Return the current temperature value."""
        return self.temperature.item()

if __name__ == "__main__":
    dummy_input = torch.randn(8, 1914)
    model = AntioxidantPredictor(input_dim=1914)

    print(f"Initial temperature: {model.get_temperature()}")
    logits_output_initial = model(dummy_input)
    print("Initial logits shape:", logits_output_initial.shape)
    probs_initial = torch.sigmoid(logits_output_initial)
    print("Initial probabilities (T=1.0):", probs_initial.detach().cpu().numpy()[:2])

    # Simulate setting an optimized temperature
    model.set_temperature(1.5, device='cpu')  # assume optimization yielded T=1.5
    print(f"Temperature after setting: {model.get_temperature()}")
    logits_output_scaled = model(dummy_input)  # the model applies T internally
    print("Scaled logits shape:", logits_output_scaled.shape)
    probs_scaled = torch.sigmoid(logits_output_scaled)  # sigmoid is still applied externally
    print("Scaled probabilities (T=1.5):", probs_scaled.detach().cpu().numpy()[:2])

    # Verify the effect of logits / T
    # logits_manual_scale = logits_output_initial / 1.5
    # probs_manual_scale = torch.sigmoid(logits_manual_scale)
    # print("Manually scaled probabilities (T=1.5):", probs_manual_scale.detach().cpu().numpy()[:2])
    # assert torch.allclose(probs_scaled, probs_manual_scale)  # should be equal
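The temperature T above stays at 1.0 until it is fitted after training; the calibration script itself is not part of this upload. A minimal sketch of how such a fit could look, assuming a held-out set of raw logits (collected with T=1) and float binary labels; the LBFGS optimizer, learning rate, and iteration count are illustrative assumptions, not the project's actual calibration code:

import torch

def calibrate_temperature(model, val_logits, val_labels, max_iter=100):
    """Fit T by minimizing BCE of sigmoid(logits / T) on a validation set.

    val_logits: uncalibrated logits of shape (N, 1), collected with T=1
    val_labels: float tensor of 0/1 labels with the same shape
    """
    temperature = torch.ones(1, requires_grad=True)
    optimizer = torch.optim.LBFGS([temperature], lr=0.01, max_iter=max_iter)
    criterion = torch.nn.BCEWithLogitsLoss()

    def closure():
        optimizer.zero_grad()
        loss = criterion(val_logits / temperature, val_labels)
        loss.backward()
        return loss

    optimizer.step(closure)
    # Store the fitted value in the model so forward() applies it from now on.
    model.set_temperature(temperature.item(), device="cpu")
    return temperature.item()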
app.py
ADDED
@@ -0,0 +1,364 @@
# app.py - RLAnOxPeptide Gradio Web Application
# This script integrates both the predictor and generator into a user-friendly web UI.

import os
import torch
import torch.nn as nn
import pandas as pd
import joblib
import numpy as np
import gradio as gr
from sklearn.cluster import KMeans
from tqdm import tqdm
import transformers

# Suppress verbose logging from transformers
transformers.logging.set_verbosity_error()

# --------------------------------------------------------------------------
# SECTION 1: CORE CLASS AND FUNCTION DEFINITIONS
# To make this app self-contained, we copy necessary class definitions here.
# These should match the versions used during training.
# --------------------------------------------------------------------------

# --- Vocabulary Definition (from both scripts) ---
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
token2id["<PAD>"] = 0
token2id["<EOS>"] = 1
id2token = {i: t for t, i in token2id.items()}
VOCAB_SIZE = len(token2id)

# --- Predictor Model Architecture (from antioxidant_predictor_5.py) ---
class AntioxidantPredictor(nn.Module):
    # This class definition should be an exact copy from your project
    def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
        super(AntioxidantPredictor, self).__init__()
        self.input_dim = input_dim
        self.t5_dim = 1024
        self.hand_crafted_dim = self.input_dim - self.t5_dim

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.t5_dim, nhead=transformer_heads,
            dropout=transformer_dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)

        self.mlp = nn.Sequential(
            nn.Linear(self.input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 1)
        )
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, fused_features):
        tr_features = fused_features[:, :self.t5_dim]
        hand_features = fused_features[:, self.t5_dim:]
        tr_features_unsqueezed = tr_features.unsqueeze(1)
        transformer_output = self.transformer_encoder(tr_features_unsqueezed)
        transformer_output_pooled = transformer_output.mean(dim=1)
        combined_features = torch.cat((transformer_output_pooled, hand_features), dim=1)
        logits = self.mlp(combined_features)
        return logits / self.temperature

    def get_temperature(self):
        return self.temperature.item()

# --- Generator Model Architecture (from generator.py) ---
class ProtT5Generator(nn.Module):
    # This class definition should be an exact copy from your project
    def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
        super(ProtT5Generator, self).__init__()
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=token2id["<PAD>"])
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.lm_head = nn.Linear(embed_dim, vocab_size)
        self.vocab_size = vocab_size
        self.eos_token_id = token2id["<EOS>"]
        self.pad_token_id = token2id["<PAD>"]

    def forward(self, input_ids):
        embeddings = self.embed_tokens(input_ids)
        encoder_output = self.encoder(embeddings)
        logits = self.lm_head(encoder_output)
        return logits

    def sample(self, batch_size, max_length=20, device="cpu", temperature=2.5, min_decoded_length=3):
        start_token = torch.randint(2, self.vocab_size, (batch_size, 1), device=device)
        generated = start_token
        for _ in range(max_length - 1):
            logits = self.forward(generated)
            next_logits = logits[:, -1, :] / temperature
            if generated.size(1) < min_decoded_length:
                next_logits[:, self.eos_token_id] = -float("inf")

            probs = torch.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)

            if (next_token == self.eos_token_id).all():
                break
        return generated

    def decode(self, token_ids_batch):
        seqs = []
        for ids_tensor in token_ids_batch:
            seq = ""
            # Start from index 1 to skip the initial random start token
            for token_id in ids_tensor.tolist()[1:]:
                if token_id == self.eos_token_id: break
                if token_id == self.pad_token_id: continue
                seq += id2token.get(token_id, "?")
            seqs.append(seq)
        return seqs

# --- Feature Extraction Logic (from feature_extract.py) ---
# Note: You need the actual ProtT5Model and extract_features here.
# Assuming they are in a file named `feature_extract.py` in the same directory.
try:
    from feature_extract import ProtT5Model as FeatureProtT5Model, extract_features
except ImportError:
    raise gr.Error("Failed to import feature_extract.py. Please ensure the file is in the same directory as app.py.")

# --- Clustering Logic (from generator.py) ---
def cluster_sequences(generator, sequences, num_clusters, device):
    if not sequences or len(sequences) < num_clusters:
        return sequences[:num_clusters]
    with torch.no_grad():
        token_ids_list = []
        max_len = max(len(seq) for seq in sequences) + 2  # Start token + EOS
        for seq in sequences:
            # Recreate encoding to match how generator sees it (with start token)
            ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
            ids = [np.random.randint(2, VOCAB_SIZE)] + ids  # Add a dummy start token
            ids += [token2id["<PAD>"]] * (max_len - len(ids))
            token_ids_list.append(ids)

        input_ids = torch.tensor(token_ids_list, dtype=torch.long, device=device)
        embeddings = generator.embed_tokens(input_ids)
        mask = (input_ids != token2id["<PAD>"]).unsqueeze(-1).float()
        embeddings = embeddings * mask
        lengths = mask.sum(dim=1)
        seq_embeds = embeddings.sum(dim=1) / (lengths + 1e-9)
        seq_embeds_np = seq_embeds.cpu().numpy()

    kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init='auto').fit(seq_embeds_np)
    representatives = []
    for i in range(int(num_clusters)):
        indices = np.where(kmeans.labels_ == i)[0]
        if len(indices) == 0: continue
        cluster_center = kmeans.cluster_centers_[i]
        cluster_embeddings = seq_embeds_np[indices]
        distances = np.linalg.norm(cluster_embeddings - cluster_center, axis=1)
        representative_index = indices[np.argmin(distances)]
        representatives.append(sequences[representative_index])
    return representatives


# --------------------------------------------------------------------------
# SECTION 2: GLOBAL MODEL LOADING
# Load all models and dependencies once when the app starts.
# --------------------------------------------------------------------------
print("Loading all models and dependencies. Please wait...")
DEVICE = "cpu"  # Use CPU for compatibility with Hugging Face free tier

try:
    # --- Define all required file paths here ---
    # !! IMPORTANT: Ensure these are relative paths to the files in your Space !!
    PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
    SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
    GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
    PROTT5_BASE_MODEL_PATH = "prott5/model/"
    FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"

    # --- Load Predictor Components ---
    print("Loading Predictor Model...")
    PREDICTOR_MODEL = AntioxidantPredictor(
        input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1
    )
    PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
    PREDICTOR_MODEL.to(DEVICE)
    PREDICTOR_MODEL.eval()
    print("✅ Predictor model loaded.")

    print("Loading Scaler...")
    SCALER = joblib.load(SCALER_PATH)
    print("✅ Scaler loaded.")

    print("Loading ProtT5 Feature Extractor...")
    # This extractor must use the fine-tuned model for features, as per your training logic
    PROTT5_EXTRACTOR = FeatureProtT5Model(
        model_path=PROTT5_BASE_MODEL_PATH,
        finetuned_model_file=FINETUNED_PROTT5_FOR_FEATURES_PATH
    )
    print("✅ ProtT5 Feature Extractor loaded.")

    # --- Load Generator Model ---
    print("Loading Generator Model...")
    GENERATOR_MODEL = ProtT5Generator(
        vocab_size=VOCAB_SIZE, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1
    )
    GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
    GENERATOR_MODEL.to(DEVICE)
    GENERATOR_MODEL.eval()
    print("✅ Generator model loaded.")

    print("\n--- All models loaded successfully! Gradio app is ready. ---\n")

except Exception as e:
    print(f"💥 FATAL ERROR: Failed to load a model or dependency file: {e}")
    raise gr.Error(f"Model or dependency loading failed! Check file paths and integrity. Error: {e}")

# --------------------------------------------------------------------------
# SECTION 3: WRAPPER FUNCTIONS FOR GRADIO
# These functions connect the UI to our model's logic.
# --------------------------------------------------------------------------

def predict_peptide_wrapper(sequence_str):
    """Takes a peptide sequence string and returns its predicted probability and class."""
    if not sequence_str or not isinstance(sequence_str, str) or any(c not in AMINO_ACIDS for c in sequence_str.upper()):
        return "0.0000", "Error: Please enter a valid sequence with standard amino acids."

    try:
        # 1. Extract features using the same logic as training/prediction scripts
        features = extract_features(sequence_str, PROTT5_EXTRACTOR)

        # 2. Scale features
        scaled_features = SCALER.transform(features.reshape(1, -1))

        # 3. Predict with the model
        with torch.no_grad():
            features_tensor = torch.tensor(scaled_features, dtype=torch.float32).to(DEVICE)
            logits = PREDICTOR_MODEL(features_tensor)
            probability = torch.sigmoid(logits).squeeze().item()

        classification = "Antioxidant" if probability >= 0.5 else "Non-Antioxidant"
        return f"{probability:.4f}", classification

    except Exception as e:
        print(f"Prediction error for sequence '{sequence_str}': {e}")
        return "N/A", f"An error occurred during processing: {e}"

def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
    """Generates, validates, and clusters sequences."""
    num_to_generate = int(num_to_generate)
    min_len = int(min_len)
    max_len = int(max_len)

    try:
        # STEP 1: Generate an initial pool of unique sequences
        target_pool_size = int(num_to_generate * diversity_factor)
        unique_seqs = set()
        progress(0, desc="Generating initial peptide pool...")

        max_attempts = 10
        attempts = 0
        while len(unique_seqs) < target_pool_size and attempts < max_attempts:
            batch_size = (target_pool_size - len(unique_seqs)) * 2  # Generate extra to account for duplicates/short ones
            with torch.no_grad():
                generated_tokens = GENERATOR_MODEL.sample(
                    batch_size=batch_size,
                    max_length=max_len,
                    device=DEVICE,
                    temperature=temperature,
                    min_decoded_length=min_len
                )
            decoded = GENERATOR_MODEL.decode(generated_tokens.cpu())
            for seq in decoded:
                if min_len <= len(seq) <= max_len:
                    unique_seqs.add(seq)
            attempts += 1
            progress(len(unique_seqs) / target_pool_size, desc=f"Generated {len(unique_seqs)} unique sequences...")

        candidate_seqs = list(unique_seqs)
        if not candidate_seqs:
            return pd.DataFrame({"Sequence": ["Failed to generate valid sequences."], "Predicted Probability": ["N/A"]})

        # STEP 2: Validate the generated sequences
        validated_pool = {}
        for seq in tqdm(candidate_seqs, desc="Validating generated sequences"):
            prob_str, _ = predict_peptide_wrapper(seq)
            try:
                prob = float(prob_str)
                if prob > 0.90:  # Filter for high-quality peptides as in generator.py
                    validated_pool[seq] = prob
            except (ValueError, TypeError):
                continue

        if not validated_pool:
            return pd.DataFrame({"Sequence": ["No high-activity peptides (>0.9 prob) were generated."], "Predicted Probability": ["N/A"]})

        high_quality_sequences = list(validated_pool.keys())

        # STEP 3: Cluster to ensure diversity
        progress(1.0, desc="Clustering for diversity...")
        final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)

        # STEP 4: Format final results
        final_results = [(seq, f"{validated_pool[seq]:.4f}") for seq in final_diverse_seqs]
        final_results.sort(key=lambda x: float(x[1]), reverse=True)

        return pd.DataFrame(final_results, columns=["Sequence", "Predicted Probability"])

    except Exception as e:
        print(f"Generation error: {e}")
        return pd.DataFrame({"Sequence": [f"An error occurred during generation: {e}"], "Predicted Probability": ["N/A"]})


# --------------------------------------------------------------------------
# SECTION 4: GRADIO UI CONSTRUCTION
# Building the web interface. All text is in English.
# --------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
    gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction Platform")
    gr.Markdown("An integrated framework combining reinforcement learning and a Transformer model for the efficient prediction and innovative design of antioxidant peptides.")

    with gr.Tabs():
        with gr.TabItem("Peptide Activity Predictor"):
            gr.Markdown("### Enter an amino acid sequence to predict its antioxidant activity.")
            with gr.Row():
                peptide_input = gr.Textbox(label="Peptide Sequence", placeholder="e.g., WHYHDYKY", scale=3)
                predict_button = gr.Button("Predict", variant="primary", scale=1)
            with gr.Row():
                probability_output = gr.Textbox(label="Predicted Probability")
                class_output = gr.Textbox(label="Predicted Class")

            predict_button.click(
                fn=predict_peptide_wrapper,
                inputs=peptide_input,
                outputs=[probability_output, class_output]
            )
            gr.Examples(
                examples=[["WHYHDYKY"], ["YPGG"], ["LVLHEHGGN"], ["INVALIDSEQUENCE"]],
                inputs=peptide_input,
                outputs=[probability_output, class_output],
                fn=predict_peptide_wrapper,
                cache_examples=False,
            )

        with gr.TabItem("Novel Sequence Generator"):
            gr.Markdown("### Set parameters to generate novel, high-activity antioxidant peptides.")
            with gr.Column():
                with gr.Row():
                    num_input = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Number of Final Peptides to Generate")
                    min_len_input = gr.Slider(minimum=2, maximum=10, value=3, step=1, label="Minimum Length")
                    max_len_input = gr.Slider(minimum=10, maximum=20, value=20, step=1, label="Maximum Length")
                with gr.Row():
                    temp_input = gr.Slider(minimum=0.5, maximum=3.0, value=2.5, step=0.1, label="Temperature (Higher = More random)")
                    diversity_input = gr.Slider(minimum=1.0, maximum=3.0, value=1.2, step=0.1, label="Diversity Factor (Higher = Larger initial pool for clustering)")

                generate_button = gr.Button("Generate Peptides", variant="primary")
                results_output = gr.DataFrame(headers=["Sequence", "Predicted Probability"], label="Generated & Validated Peptides", wrap=True)

                generate_button.click(
                    fn=generate_peptide_wrapper,
                    inputs=[num_input, min_len_input, max_len_input, temp_input, diversity_input],
                    outputs=results_output
                )

if __name__ == "__main__":
    demo.launch()
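For a quick check outside the web UI, the prediction wrapper can be called directly once the globals in SECTION 2 have loaded. A minimal sketch, assuming the checkpoints and ProtT5 files are present at the relative paths defined above (the file name smoke_test.py and the example sequence are illustrative):

# smoke_test.py - minimal sketch of calling the predictor without launching Gradio.
# Importing app runs SECTION 2, so all model files must be in place.
from app import predict_peptide_wrapper

probability, label = predict_peptide_wrapper("WHYHDYKY")  # example sequence from the UI examples
print(f"Predicted probability: {probability} -> {label}")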
checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1cb2c7cf6a59a5d028af4678ac5aec83305d3638bf9f43fee6df14f162f4a4e6
size 9930081
checkpoints/scaler_FINETUNED_PROTT5.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50cfe7b0204a4e1dd5f22dc1440d6e8448f7ee323bb20957c70ce1058ca78475
size 31127
feature_extract.py
ADDED
@@ -0,0 +1,341 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import numpy as np
import random
import pandas as pd
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from sklearn.model_selection import train_test_split
# from sklearn.preprocessing import StandardScaler  # StandardScaler is no longer used
from sklearn.preprocessing import RobustScaler  # use RobustScaler instead
import torch
from transformers import T5EncoderModel, T5Tokenizer

# ProtT5Model, load_fasta, load_fasta_with_labels,
# compute_amino_acid_composition, compute_reducing_aa_ratio,
# compute_physicochemical_properties, compute_electronic_features,
# compute_dimer_frequency, positional_encoding, perturb_sequence,
# generate_adversarial_samples and extract_features are identical to the versions provided earlier.
# Their code is abbreviated here; make sure they are complete in your own file.
# You can copy them from earlier logs or from your local copy.
# Below is the modified prepare_features function plus placeholders for the other functions.

class ProtT5Model:
    """
    Loads a ProtT5 model from a local path. If finetuned_model_file is given,
    the fine-tuned weights are loaded on top (with strict=False).
    """
    def __init__(self, model_path, finetuned_model_file=None):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Try the local files first; if that fails, the transformers library may
        # fall back to downloading from the Hub (depending on configuration).
        try:
            self.tokenizer = T5Tokenizer.from_pretrained(model_path, do_lower_case=False, local_files_only=True)
            self.model = T5EncoderModel.from_pretrained(model_path, local_files_only=True)
        except OSError:  # raised when the directory does not contain all files required for a T5 tokenizer/model
            print(f"Warning: could not load the ProtT5 model/tokenizer from local path {model_path}. Trying to download from the HuggingFace Hub (if the transformers configuration allows it).")
            self.tokenizer = T5Tokenizer.from_pretrained(model_path.split('/')[-1] if '/' in model_path else model_path, do_lower_case=False)  # try downloading by model name
            self.model = T5EncoderModel.from_pretrained(model_path.split('/')[-1] if '/' in model_path else model_path)

        if finetuned_model_file is not None and os.path.exists(finetuned_model_file):
            try:
                state_dict = torch.load(finetuned_model_file, map_location=self.device)
                missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
                print(f"Loaded fine-tuned weights {finetuned_model_file}: missing keys {missing_keys}, unexpected keys {unexpected_keys}")
            except Exception as e:
                print(f"Failed to load fine-tuned weights {finetuned_model_file}: {e}")

        self.model.to(self.device)
        self.model.eval()

    def encode(self, sequence):
        if not sequence or not isinstance(sequence, str):  # guard against empty or non-string input
            print(f"Warning: ProtT5Model.encode received an invalid sequence: {sequence}")
            # encode() normally returns an array of shape (seq_len, hidden_dim),
            # so return a dummy single-position zero embedding of width 1024.
            return np.zeros((1, 1024), dtype=np.float32)  # (1, hidden_dim)

        seq_spaced = " ".join(list(sequence))  # renamed to avoid shadowing the caller's seq
        try:
            encoded_input = self.tokenizer(seq_spaced, return_tensors='pt', padding=True, truncation=True, max_length=1022)  # ProtT5 usually caps at 1024 positions; tokenized input may be longer
        except Exception as e:
            print(f"Tokenization failed for sequence '{sequence[:30]}...': {e}")
            return np.zeros((1, 1024), dtype=np.float32)

        encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
        with torch.no_grad():
            try:
                embedding = self.model(**encoded_input).last_hidden_state  # (batch_size, seq_len, hidden_dim)
            except Exception as e:
                print(f"ProtT5 inference failed for sequence '{sequence[:30]}...': {e}")
                return np.zeros((1, 1024), dtype=np.float32)

        emb = embedding.squeeze(0).cpu().numpy()  # (seq_len, hidden_dim)
        if emb.shape[0] == 0:  # in case the sequence length ended up being 0
            return np.zeros((1, 1024), dtype=np.float32)
        return emb

# --- (All other feature-extraction helper functions from your previous version belong here) ---
# load_fasta, load_fasta_with_labels, compute_amino_acid_composition, ... extract_features
# For completeness, copy these functions here from your local feature_extract.py.
# Below are simplified placeholders; replace them with the real implementations.

def load_fasta(fasta_file):
    # (your load_fasta implementation)
    sequences = []
    try:
        with open(fasta_file, 'r') as f:
            current_seq_lines = []
            for line in f:
                line = line.strip()
                if not line: continue
                if line.startswith(">"):
                    if current_seq_lines: sequences.append("".join(current_seq_lines))
                    current_seq_lines = []
                else: current_seq_lines.append(line)
            if current_seq_lines: sequences.append("".join(current_seq_lines))
    except FileNotFoundError: print(f"File not found: {fasta_file}"); return []
    return sequences

def load_fasta_with_labels(fasta_file):
    # (your load_fasta_with_labels implementation)
    sequences, labels = [], []
    try:
        with open(fasta_file, 'r') as f:
            current_seq_lines, current_label = [], None
            for line in f:
                line = line.strip()
                if not line: continue
                if line.startswith(">"):
                    if current_seq_lines:
                        sequences.append("".join(current_seq_lines))
                        labels.append(current_label if current_label is not None else 0)  # Default label 0
                    current_seq_lines = []
                    current_label = int(line[1]) if len(line) > 1 and line[1] in ['0', '1'] else 0
                else: current_seq_lines.append(line)
            if current_seq_lines:
                sequences.append("".join(current_seq_lines))
                labels.append(current_label if current_label is not None else 0)
    except FileNotFoundError: print(f"File not found: {fasta_file}"); return [], []
    return sequences, labels


def compute_amino_acid_composition(seq):
    if not seq: return {aa: 0.0 for aa in "ACDEFGHIKLMNPQRSTVWY"}
    # (your compute_amino_acid_composition implementation)
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    seq_len = len(seq)
    return {aa: seq.upper().count(aa) / seq_len for aa in amino_acids}


def compute_reducing_aa_ratio(seq):
    if not seq: return 0.0
    # (your compute_reducing_aa_ratio implementation)
    reducing = ['C', 'M', 'W']
    return sum(seq.upper().count(aa) for aa in reducing) / len(seq) if len(seq) > 0 else 0.0

def compute_physicochemical_properties(seq):
    if not seq or not all(c.upper() in "ACDEFGHIKLMNPQRSTVWYXUBZ" for c in seq):  # ProteinAnalysis might fail on invalid chars
        return 0.0, 0.0, 0.0  # Default values
    try:
        analysis = ProteinAnalysis(str(seq).upper().replace('X', 'A').replace('U', 'C').replace('B', 'N').replace('Z', 'Q'))  # Replace non-standard with common ones for analysis
        return analysis.gravy(), analysis.isoelectric_point(), analysis.molecular_weight()
    except Exception:  # Catch any error from ProteinAnalysis
        return 0.0, 7.0, 110.0 * len(seq)  # Rough defaults

def compute_electronic_features(seq):
    if not seq: return 0.0, 0.0
    # (your compute_electronic_features implementation)
    electronegativity = {'A': 1.8, 'C': 2.5, 'D': 3.0, 'E': 3.2, 'F': 2.8, 'G': 1.6, 'H': 2.4, 'I': 4.5, 'K': 3.0, 'L': 4.2, 'M': 4.5, 'N': 2.0, 'P': 3.5, 'Q': 3.5, 'R': 2.5, 'S': 1.8, 'T': 2.5, 'V': 4.0, 'W': 5.0, 'Y': 4.0}
    values = [electronegativity.get(aa.upper(), 2.5) for aa in seq]
    avg_val = sum(values) / len(values) if values else 2.5
    return avg_val + 0.1, avg_val - 0.1


def compute_dimer_frequency(seq):
    if len(seq) < 2: return np.zeros(400)  # 20*20
    # (your compute_dimer_frequency implementation)
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    dimer_counts = {aa1 + aa2: 0 for aa1 in amino_acids for aa2 in amino_acids}
    for i in range(len(seq) - 1):
        dimer = seq[i:i+2].upper()
        if dimer in dimer_counts: dimer_counts[dimer] += 1
    total = max(len(seq) - 1, 1)
    for key in dimer_counts: dimer_counts[key] /= total
    return np.array([dimer_counts[d] for d in sorted(dimer_counts.keys())])


def positional_encoding(seq_len_actual, L_fixed=29, d_model=16):  # Pass actual sequence length or use L_fixed
    # (your positional_encoding implementation)
    # This PE has a fixed length and does not depend on the actual sequence length when L_fixed is used.
    # For very short random sequences this fixed PE might be an issue;
    # a more dynamic PE (or none) might be better. However, to match the current
    # model input we keep it as is.
    pos_enc = np.zeros((L_fixed, d_model))
    for pos in range(L_fixed):
        for i in range(d_model):
            angle = pos / (10000 ** (2 * (i // 2) / d_model))
            pos_enc[pos, i] = np.sin(angle) if i % 2 == 0 else np.cos(angle)
    return pos_enc.flatten()


def perturb_sequence(seq, perturb_rate=0.1, critical=['C', 'M', 'W']):
    # (your perturb_sequence implementation)
    if not seq: return ""
    seq_list = list(seq)
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    for i, aa in enumerate(seq_list):
        if aa.upper() not in critical and random.random() < perturb_rate:
            seq_list[i] = random.choice([x for x in amino_acids if x != aa.upper()])
    return "".join(seq_list)


def extract_features(seq, prott5_model_instance, L_fixed=29, d_model_pe=16):  # Renamed d_model to d_model_pe
    if not seq or not isinstance(seq, str) or len(seq) == 0:
        print("Warning: extract_features received an empty or invalid sequence. Returning zero features.")
        # Return a zero vector matching the expected feature dimension:
        # 1024 (ProtT5) + 20 (aac) + 1 (red_ratio) + 3 (phys) + 2 (elec) + 400 (dimer) + L_fixed*d_model_pe (pos_enc)
        # Example: 1024 + 20 + 1 + 3 + 2 + 400 + 29*16 = 1024 + 20 + 1 + 3 + 2 + 400 + 464 = 1914
        return np.zeros(1024 + 20 + 1 + 3 + 2 + 400 + (L_fixed * d_model_pe))

    embedding = prott5_model_instance.encode(seq)  # prott5_model is now an instance
    prot_embed = np.mean(embedding, axis=0) if embedding.shape[0] > 0 else np.zeros(embedding.shape[1] if embedding.ndim > 1 else 1024)  # Handle empty embedding
    if prot_embed.shape[0] != 1024:  # Ensure consistent ProtT5 embedding dim
        # print(f"Warning: unexpected ProtT5 embedding dimension ({prot_embed.shape[0]}) for seq '{seq[:20]}...'. Using a zero vector.")
        prot_embed = np.zeros(1024)

    aa_comp = compute_amino_acid_composition(seq)
    aa_comp_vector = np.array([aa_comp[aa] for aa in "ACDEFGHIKLMNPQRSTVWY"])
    red_ratio = np.array([compute_reducing_aa_ratio(seq)])
    gravy, pI, mol_weight = compute_physicochemical_properties(seq)
    phys_props = np.array([gravy, pI, mol_weight])
    HOMO, LUMO = compute_electronic_features(seq)
    electronic = np.array([HOMO, LUMO])
    dimer_vector = compute_dimer_frequency(seq)
    pos_enc = positional_encoding(len(seq), L_fixed, d_model_pe)  # Pass actual length, though the current PE uses L_fixed

    features = np.concatenate([prot_embed, aa_comp_vector, red_ratio, phys_props, electronic, dimer_vector, pos_enc])
    return features

##############################################
# Main interface function: prepare_features
##############################################
def prepare_features(neg_fasta, pos_fasta, prott5_model_path, additional_params=None):
    neg_seqs = load_fasta(neg_fasta)
    pos_seqs = load_fasta(pos_fasta)

    if not neg_seqs and not pos_seqs:
        raise ValueError("No sequences could be loaded from the FASTA files. Check the file paths and contents.")

    neg_labels = [0] * len(neg_seqs)
    pos_labels = [1] * len(pos_seqs)
    sequences = neg_seqs + pos_seqs
    labels = neg_labels + pos_labels

    combined = list(zip(sequences, labels))
    random.shuffle(combined)
    sequences, labels = zip(*combined)
    sequences = list(sequences)
    labels = list(labels)

    train_seqs, val_seqs, train_labels, val_labels = train_test_split(
        sequences, labels, test_size=0.1, random_state=42, stratify=labels if len(np.unique(labels)) > 1 else None
    )
    print("Number of raw training samples:", len(train_seqs))
    print("Number of raw validation samples:", len(val_seqs))

    if additional_params is not None and additional_params.get("augment", False):
        # (data-augmentation logic, if enabled)
        augmented_seqs, augmented_labels = [], []
        perturb_rate = additional_params.get("perturb_rate", 0.1)
        for seq, label in zip(train_seqs, train_labels):
            aug_seq = perturb_sequence(seq, perturb_rate=perturb_rate)
            augmented_seqs.append(aug_seq)
            augmented_labels.append(label)
        train_seqs.extend(augmented_seqs)
        train_labels.extend(augmented_labels)
        print("Number of training samples after augmentation:", len(train_seqs))

    finetuned_model_file = additional_params.get("finetuned_model_file") if additional_params else None
    # Create the ProtT5Model instance
    prott5_model_instance = ProtT5Model(prott5_model_path, finetuned_model_file=finetuned_model_file)

    def process_data(seqs_list):  # Renamed seqs to seqs_list
        feature_list = []
        for s_item in seqs_list:  # Renamed s to s_item
            # Pass the ProtT5Model instance to extract_features
            features = extract_features(s_item, prott5_model_instance)
            feature_list.append(features)
        return np.array(feature_list)

    X_train = process_data(train_seqs)
    X_val = process_data(val_seqs)

    if X_train.shape[0] == 0 or X_val.shape[0] == 0:
        raise ValueError("The training or validation set is empty after feature extraction. Check the sequence data and the feature-extraction step.")

    # --- Key change: use RobustScaler ---
    # scaler = StandardScaler()  # the previous StandardScaler
    scaler = RobustScaler()
    print("Using RobustScaler for feature normalization.")

    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    return X_train_scaled, X_val_scaled, np.array(train_labels), np.array(val_labels), scaler

if __name__ == "__main__":
    # Make sure the test paths are valid, or create dummy files.
    neg_fasta_test = "dummy_data/test_neg.fasta"
    pos_fasta_test = "dummy_data/test_pos.fasta"
    prott5_path_test = "dummy_prott5_model/"  # needs a directory containing config.json, pytorch_model.bin, etc.

    os.makedirs("dummy_data", exist_ok=True)
    os.makedirs(prott5_path_test, exist_ok=True)  # create the dummy model directory

    if not os.path.exists(neg_fasta_test):
        with open(neg_fasta_test, "w") as f: f.write(">neg1\nKALKALKALK\n>neg2\nPEPTPEPT\n")
    if not os.path.exists(pos_fasta_test):
        with open(pos_fasta_test, "w") as f: f.write(">pos1\nAOPPEPTIDE\n>pos2\nTRYTRYTRY\n")

    # For ProtT5Model to load, a minimal transformers model directory structure is needed:
    # usually at least config.json, pytorch_model.bin (or tf_model.h5), tokenizer_config.json and spiece.model.
    # Here we only create the directory; loading may fail unless transformers can download by model name,
    # or you provide a real local ProtT5 model path.
    if not os.listdir(prott5_path_test):  # if the directory is empty
        print(f"Warning: {prott5_path_test} is empty. ProtT5Model may try to download the model from the HuggingFace Hub.")
        print("Make sure you have downloaded Rostlab/ProstT5-XL-UniRef50 (or a similar model) to that path, or use its HuggingFace name.")
        # For this demo we assume the user supplies a valid path or that transformers can handle it.
        # To run fully offline without downloading, the directory must be populated.

    additional_params_test = {
        "augment": False,
        "perturb_rate": 0.1,
        "finetuned_model_file": None  # do not use the fine-tuned model during testing
    }

    print("Starting the prepare_features test (with RobustScaler)...")
    try:
        X_train_t, X_val_t, y_train_t, y_val_t, scaler_t = prepare_features(
            neg_fasta_test,
            pos_fasta_test,
            "Rostlab/ProstT5-XL-UniRef50",  # use the HuggingFace model name if the local path is invalid
            additional_params_test
        )
        print("prepare_features test finished.")
        print("Number of training samples:", X_train_t.shape[0])
        print("Number of validation samples:", X_val_t.shape[0])
        if X_train_t.shape[0] > 0:
            print("Training feature dimension:", X_train_t.shape[1])
            print("One scaled training sample (first 5 features):", X_train_t[0, :5])
        if scaler_t:
            print("Scaler type:", type(scaler_t))
    except Exception as e:
        print(f"prepare_features test failed: {e}")
        print("This is likely caused by a failure to load the ProtT5 model or a problem with the FASTA files. Check the paths and file contents.")
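As the comments in extract_features note, the concatenated feature vector is 1914-dimensional (1024 ProtT5 + 20 AAC + 1 reducing ratio + 3 physicochemical + 2 electronic + 400 dimer + 29*16 positional encoding), which is exactly the predictor's input_dim. A quick sanity-check sketch, assuming the ProtT5 files uploaded in this Space can be loaded from their relative paths; the test sequence is illustrative:

# Sanity check: the feature vector should be 1914-dimensional,
# matching AntioxidantPredictor(input_dim=1914).
expected_dim = 1024 + 20 + 1 + 3 + 2 + 400 + 29 * 16  # = 1914

prott5 = ProtT5Model("prott5/model/", finetuned_model_file="prott5/model/finetuned_prott5.bin")
features = extract_features("WHYHDYKY", prott5)
assert features.shape[0] == expected_dim, features.shape
print("Feature dimension:", features.shape[0])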
generator_checkpoints_v3.6/final_generator_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d878934b93514142df36cb04c812af25ad53a79bdfa0a3edd9f65c22ba4e1d48
size 75775354
prott5/model/config.json
ADDED
@@ -0,0 +1,25 @@
{
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 16384,
  "d_kv": 128,
  "d_model": 1024,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 24,
  "num_heads": 32,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "use_cache": true,
  "vocab_size": 128
}
prott5/model/finetuned_prott5.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3dbe43382d7c4aa0ba6de0ec02a9680155486a3e8984c82199ae8868223eba9e
size 75774969
prott5/model/gitattributes
ADDED
@@ -0,0 +1,16 @@
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tar.gz filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
prott5/model/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"]}
prott5/model/spiece.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:74da7b4afcde53faa570114b530c726135bdfcdb813dec3abfb27f9d44db7324
size 237990
prott5/model/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
{"do_lower_case": false}
requirements.txt
ADDED
@@ -0,0 +1,172 @@
absl-py @ file:///home/conda/feedstock_root/build_artifacts/absl-py_1733730548347/work
accelerate==1.2.1
aiohappyeyeballs @ file:///home/conda/feedstock_root/build_artifacts/aiohappyeyeballs_1733331917983/work
aiohttp @ file:///home/conda/feedstock_root/build_artifacts/aiohttp_1734596887867/work
aiosignal @ file:///home/conda/feedstock_root/build_artifacts/aiosignal_1734342155601/work
alembic @ file:///home/conda/feedstock_root/build_artifacts/alembic_1733728406412/work
AmberUtils==21.0
astunparse @ file:///home/conda/feedstock_root/build_artifacts/astunparse_1728923142236/work
async-timeout @ file:///home/conda/feedstock_root/build_artifacts/async-timeout_1733235340728/work
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1734348785146/work
biopython==1.79
blinker @ file:///home/conda/feedstock_root/build_artifacts/blinker_1731096409132/work
boto3==1.35.82
botocore==1.35.82
Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1725267488082/work
cached-property @ file:///home/conda/feedstock_root/build_artifacts/cached_property_1615209429212/work
cachetools @ file:///home/conda/feedstock_root/build_artifacts/cachetools_1733624221485/work
catboost @ https://pypi.org/packages/cp39/c/catboost/catboost-1.2.7-cp39-cp39-manylinux2014_x86_64.whl#sha256=e58cf8966e33931acebffbc744cf640e8abd08d0fdfb0e503c107552cea6c643
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1739515848642/work/certifi
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1725571112467/work
chardet @ file:///home/conda/feedstock_root/build_artifacts/chardet_1724954807150/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1733218092148/work
click @ file:///home/conda/feedstock_root/build_artifacts/click_1733221831880/work
cloudpickle @ file:///home/conda/feedstock_root/build_artifacts/cloudpickle_1736947526808/work
cmake==3.31.2
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work
colorlog @ file:///home/conda/feedstock_root/build_artifacts/colorlog_1733258404285/work
contourpy @ file:///home/conda/feedstock_root/build_artifacts/contourpy_1727293517607/work
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography-split_1732745941597/work
cycler @ file:///home/conda/feedstock_root/build_artifacts/cycler_1733332471406/work
edgembar==3.0
filelock==3.16.1
Flask @ file:///home/conda/feedstock_root/build_artifacts/flask_1741793020411/work
flatbuffers @ file:///home/conda/feedstock_root/build_artifacts/python-flatbuffers_1733838640534/work
fonttools @ file:///home/conda/feedstock_root/build_artifacts/fonttools_1733908950378/work
freetype-py @ file:///home/conda/feedstock_root/build_artifacts/freetype-py_1650983368720/work
frozenlist @ file:///home/conda/feedstock_root/build_artifacts/frozenlist_1729699456009/work
fsspec==2024.10.0
gast @ file:///home/conda/feedstock_root/build_artifacts/gast_1596839682936/work
google-auth @ file:///home/conda/feedstock_root/build_artifacts/google-auth_1734176111475/work
google-auth-oauthlib==1.0.0
google-pasta @ file:///home/conda/feedstock_root/build_artifacts/google-pasta_1733852499742/work
graphviz @ file:///home/conda/feedstock_root/build_artifacts/python-graphviz_1733791968395/work
greenlet @ file:///home/conda/feedstock_root/build_artifacts/greenlet_1734532792566/work
grpcio==1.68.1
gym @ file:///home/conda/feedstock_root/build_artifacts/gym_1673317490724/work
gym-notices @ file:///home/conda/feedstock_root/build_artifacts/gym-notices_1734709658541/work
h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1733298745555/work
h5py @ file:///home/conda/feedstock_root/build_artifacts/h5py_1734544937668/work
hpack @ file:///home/conda/feedstock_root/build_artifacts/hpack_1733299205993/work
huggingface-hub==0.26.5
hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1733298771451/work
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1733211830134/work
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1733223117029/work
importlib_resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1733231327252/work
itsdangerous @ file:///home/conda/feedstock_root/build_artifacts/itsdangerous_1733308265247/work
Jinja2==3.1.4
jmespath==1.0.1
joblib @ file:///home/conda/feedstock_root/build_artifacts/joblib_1733736026804/work
keras @ file:///home/conda/feedstock_root/build_artifacts/keras_1669020828024/work/keras-2.11.0-py2.py3-none-any.whl
Keras-Preprocessing @ file:///home/conda/feedstock_root/build_artifacts/keras-preprocessing_1610713559828/work
kiwisolver @ file:///home/conda/feedstock_root/build_artifacts/kiwisolver_1725459266648/work
lightgbm @ file:///home/conda/feedstock_root/build_artifacts/liblightgbm_1728547676427/work
lit==18.1.8
llvmlite==0.43.0
lmdb==1.5.1
Mako @ file:///home/conda/feedstock_root/build_artifacts/mako_1733628147227/work
Markdown==3.7
markdown-it-py @ file:///home/conda/feedstock_root/build_artifacts/markdown-it-py_1733250460757/work
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1733219680183/work
matplotlib==3.9.4
mdurl @ file:///home/conda/feedstock_root/build_artifacts/mdurl_1733255585584/work
ml-dtypes @ file:///home/conda/feedstock_root/build_artifacts/ml_dtypes_1725475048616/work
MMPBSA.py==16.0
mordred==1.2.0
mpmath==1.3.0
multidict @ file:///home/conda/feedstock_root/build_artifacts/multidict_1733913043842/work
namex @ file:///home/conda/feedstock_root/build_artifacts/namex_1733858070416/work
ndfes==3.0
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1698504735452/work
numba @ file:///home/conda/feedstock_root/build_artifacts/numba_1718888028049/work
numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1682210190296/work
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
oauthlib @ file:///home/conda/feedstock_root/build_artifacts/oauthlib_1733752848439/work
opt_einsum @ file:///home/conda/feedstock_root/build_artifacts/opt_einsum_1733687912731/work
optree @ file:///home/conda/feedstock_root/build_artifacts/optree_1731510511426/work
optuna @ file:///home/conda/feedstock_root/build_artifacts/optuna_1734927470522/work
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1733203243479/work
packmol_memgen==2025.1.29
pandas @ file:///home/conda/feedstock_root/build_artifacts/pandas_1726878406085/work
patsy @ file:///home/conda/feedstock_root/build_artifacts/patsy_1733792384640/work
pdb4amber==22.0
peft==0.4.0
peptides @ file:///opt/conda/conda-bld/peptides_1726747256415/work
pfeature==1.4
Pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1688255839723/work
plotly @ file:///home/conda/feedstock_root/build_artifacts/plotly_1733733072265/work
propcache @ file:///home/conda/feedstock_root/build_artifacts/propcache_1733391807885/work
protobuf==5.29.1
psutil==6.1.0
pyasn1 @ file:///home/conda/feedstock_root/build_artifacts/pyasn1_1733217608156/work
pyasn1_modules @ file:///home/conda/feedstock_root/build_artifacts/pyasn1-modules_1733324602540/work
pycairo==1.25.0
pycparser @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycparser_1733195786/work
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1733221634316/work
PyJWT @ file:///home/conda/feedstock_root/build_artifacts/pyjwt_1732782409051/work
pyMSMT==22.0
pynndescent @ file:///home/conda/feedstock_root/build_artifacts/pynndescent_1734193530377/work
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1733777368795/work
pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1733222594562/work
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1733217236728/work
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work
pytraj==2.0.6
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1706886791323/work
pyu2f @ file:///home/conda/feedstock_root/build_artifacts/pyu2f_1733738580568/work
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1725456176299/work
regex==2024.11.6
reportlab @ file:///home/conda/feedstock_root/build_artifacts/reportlab_1727846838921/work
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1733217035951/work
requests-oauthlib @ file:///home/conda/feedstock_root/build_artifacts/requests-oauthlib_1733772243268/work
rich @ file:///home/conda/feedstock_root/build_artifacts/rich_1733342254348/work/dist
rlPyCairo @ file:///home/conda/feedstock_root/build_artifacts/rlpycairo_1687519531733/work
rsa @ file:///home/conda/feedstock_root/build_artifacts/rsa_1733662684165/work
s3transfer==0.10.4
safetensors==0.4.5
sander==22.0
scikit-learn @ file:///home/conda/feedstock_root/build_artifacts/scikit-learn_1733760110140/work/dist/scikit_learn-1.6.0-cp39-cp39-linux_x86_64.whl#sha256=f4205b3762b8033782da7a40727c959075158f152dcb3f50e33f60ecfc8cf6f4
scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy-split_1716470218293/work/dist/scipy-1.13.1-cp39-cp39-linux_x86_64.whl#sha256=e6696cb8683d94467891b7648e068a3970f6bc0a1b3c1aa7f9bc89458eafd2f0
seaborn @ file:///home/conda/feedstock_root/build_artifacts/seaborn-split_1733730015268/work
sentencepiece @ file:///home/conda/feedstock_root/build_artifacts/sentencepiece-split_1674200849009/work/python
shap @ file:///home/conda/feedstock_root/build_artifacts/shap_1732083651745/work
six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work
slicer @ file:///home/conda/feedstock_root/build_artifacts/slicer_1710029110974/work
SQLAlchemy @ file:///home/conda/feedstock_root/build_artifacts/sqlalchemy_1729066347171/work
statsmodels @ file:///home/conda/feedstock_root/build_artifacts/statsmodels_1727986706423/work
sympy==1.13.3
tape-proteins==0.5
tenacity @ file:///home/conda/feedstock_root/build_artifacts/tenacity_1733649050774/work
tensorboard==2.13.0
tensorboard-data-server==0.7.2
tensorboard-plugin-wit @ file:///home/conda/feedstock_root/build_artifacts/tensorboard-plugin-wit_1641458951060/work/tensorboard_plugin_wit-1.8.1-py3-none-any.whl
tensorboardX==2.6.2.2
tensorflow @ file:///home/conda/feedstock_root/build_artifacts/tensorflow-split_1679857771045/work/tensorflow_pkg/tensorflow-2.11.1-cp39-cp39-linux_x86_64.whl
tensorflow-estimator @ file:///home/conda/feedstock_root/build_artifacts/tensorflow-split_1679857771045/work/tensorflow-estimator/wheel_dir/tensorflow_estimator-2.11.0-py2.py3-none-any.whl
termcolor @ file:///home/conda/feedstock_root/build_artifacts/termcolor_1733754631889/work
threadpoolctl @ file:///home/conda/feedstock_root/build_artifacts/threadpoolctl_1714400101435/work
tokenizers==0.13.3
torch==2.0.1
tqdm==4.65.0
transformers==4.30.0
triton==2.0.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1733188668063/work
tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1733235305708/work
umap-learn @ file:///home/conda/feedstock_root/build_artifacts/umap-learn_1730221913988/work
unicodedata2 @ file:///home/conda/feedstock_root/build_artifacts/unicodedata2_1729704563364/work
urllib3==1.26.20
Werkzeug @ file:///home/conda/feedstock_root/build_artifacts/werkzeug_1733160440960/work
wrapt @ file:///home/conda/feedstock_root/build_artifacts/wrapt_1732523603052/work
xgboost @ file:///home/conda/feedstock_root/build_artifacts/xgboost-split_1733179637554/work/python-package
yarl @ file:///home/conda/feedstock_root/build_artifacts/yarl_1733428811664/work
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work
zstandard==0.23.0