chshan committed
Commit 7b2918a · verified · 1 Parent(s): 4500204

Upload 13 files

antioxidant_predictor_5.py ADDED
@@ -0,0 +1,100 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ class AntioxidantPredictor(nn.Module):
8
+ def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
9
+ super(AntioxidantPredictor, self).__init__()
10
+ self.prott5_dim = 1024
11
+ self.handcrafted_dim = input_dim - self.prott5_dim
12
+ self.seq_len = 16
13
+ self.prott5_feature_dim = 64 # 16 * 64 = 1024
14
+
15
+ encoder_layer = nn.TransformerEncoderLayer(
16
+ d_model=self.prott5_feature_dim,
17
+ nhead=transformer_heads,
18
+ dropout=transformer_dropout,
19
+ batch_first=True
20
+ )
21
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
22
+
23
+ fused_dim = self.prott5_feature_dim + self.handcrafted_dim
24
+ self.fusion_fc = nn.Sequential(
25
+ nn.Linear(fused_dim, 1024),
26
+ nn.ReLU(),
27
+ nn.Dropout(0.3),
28
+ nn.Linear(1024, 512),
29
+ nn.ReLU(),
30
+ nn.Dropout(0.3)
31
+ )
32
+
33
+ self.classifier = nn.Sequential(
34
+ nn.Linear(512, 256),
35
+ nn.ReLU(),
36
+ nn.Dropout(0.3),
37
+ nn.Linear(256, 1)
38
+ )
39
+
40
+ # Temperature-scaling parameter T
41
+ # Initialized to 1.0, i.e. the logits are left unchanged before calibration
42
+ # requires_grad=False because T is usually optimized separately after the model has been trained
43
+ self.temperature = nn.Parameter(torch.ones(1), requires_grad=False)
44
+
45
+ def forward(self, x, *args):
46
+ batch_size = x.size(0)
47
+ prot_t5_features = x[:, :self.prott5_dim]
48
+ handcrafted_features = x[:, self.prott5_dim:]
49
+
50
+ prot_t5_seq = prot_t5_features.view(batch_size, self.seq_len, self.prott5_feature_dim)
51
+ encoded_seq = self.transformer_encoder(prot_t5_seq)
52
+ refined_prott5 = encoded_seq.mean(dim=1)
53
+
54
+ fused_features = torch.cat([refined_prott5, handcrafted_features], dim=1)
55
+ fused_features = self.fusion_fc(fused_features)
56
+
57
+ logits = self.classifier(fused_features)
58
+
59
+ # Apply temperature scaling: logits / T
60
+ # Note: scaling is applied here, after the raw logits are obtained and before the external sigmoid.
61
+ # To output calibrated probabilities directly, one could divide by T here and then apply sigmoid,
62
+ # but the optimization and the application of T are usually separate steps.
63
+ # Applying it here means callers get calibrated logits (once T has been optimized) when invoking the model.
64
+ # If T has not been optimized (still 1), this operation has no effect.
65
+ logits_scaled = logits / self.temperature
66
+
67
+ return logits_scaled # return the calibrated logits (or the raw logits if T = 1)
68
+
69
+ def set_temperature(self, temp_value, device):
70
+ """Set the optimized temperature value."""
71
+ self.temperature = nn.Parameter(torch.tensor([temp_value], device=device), requires_grad=False)
72
+ print(f"Model temperature T set to: {self.temperature.item()}")
73
+
74
+ def get_temperature(self):
75
+ """Return the current temperature value."""
76
+ return self.temperature.item()
77
+
78
+ if __name__ == "__main__":
79
+ dummy_input = torch.randn(8, 1914)
80
+ model = AntioxidantPredictor(input_dim=1914)
81
+
82
+ print(f"Initial temperature: {model.get_temperature()}")
83
+ logits_output_initial = model(dummy_input)
84
+ print("Initial logits shape:", logits_output_initial.shape)
85
+ probs_initial = torch.sigmoid(logits_output_initial)
86
+ print("Initial probabilities (T=1.0):", probs_initial.detach().cpu().numpy()[:2])
87
+
88
+ # Simulate setting an optimized temperature
89
+ model.set_temperature(1.5, device='cpu') # assume calibration yielded T=1.5
90
+ print(f"Temperature after update: {model.get_temperature()}")
91
+ logits_output_scaled = model(dummy_input) # the model applies T internally
92
+ print("Scaled logits shape:", logits_output_scaled.shape)
93
+ probs_scaled = torch.sigmoid(logits_output_scaled) # sigmoid still has to be applied externally
94
+ print("Scaled probabilities (T=1.5):", probs_scaled.detach().cpu().numpy()[:2])
95
+
96
+ # Verify the effect of logits / T
97
+ # logits_manual_scale = logits_output_initial / 1.5
98
+ # probs_manual_scale = torch.sigmoid(logits_manual_scale)
99
+ # print("Manually scaled probabilities (T=1.5):", probs_manual_scale.detach().cpu().numpy()[:2])
100
+ # assert torch.allclose(probs_scaled, probs_manual_scale) # should be equal
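
The comments above note that T is normally optimized separately after training, but the calibration step itself is not part of this upload. Below is a minimal sketch of how T could be fitted on held-out validation logits with LBFGS; the helper name fit_temperature and its arguments are assumptions for illustration, not code from the repository.

import torch
import torch.nn as nn

def fit_temperature(model, val_logits, val_labels, max_iter=100):
    # val_logits: uncalibrated logits collected with torch.no_grad() while T was still 1, shape (N, 1)
    # val_labels: binary labels as float32, shape (N, 1), on the same device as val_logits
    log_t = torch.zeros(1, device=val_logits.device, requires_grad=True)  # optimize log(T) so T stays positive
    optimizer = torch.optim.LBFGS([log_t], lr=0.1, max_iter=max_iter)
    bce = nn.BCEWithLogitsLoss()

    def closure():
        optimizer.zero_grad()
        loss = bce(val_logits / log_t.exp(), val_labels)  # NLL of the temperature-scaled logits
        loss.backward()
        return loss

    optimizer.step(closure)
    t_opt = log_t.exp().item()
    model.set_temperature(t_opt, device=val_logits.device)  # uses the setter defined above
    return t_opt

Because only a single positive scalar is fitted, temperature scaling changes the confidence of the sigmoid outputs but never flips the 0.5-threshold classification.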
app.py ADDED
@@ -0,0 +1,364 @@
1
+ # app.py - RLAnOxPeptide Gradio Web Application
2
+ # This script integrates both the predictor and generator into a user-friendly web UI.
3
+
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ import pandas as pd
8
+ import joblib
9
+ import numpy as np
10
+ import gradio as gr
11
+ from sklearn.cluster import KMeans
12
+ from tqdm import tqdm
13
+ import transformers
14
+
15
+ # Suppress verbose logging from transformers
16
+ transformers.logging.set_verbosity_error()
17
+
18
+ # --------------------------------------------------------------------------
19
+ # SECTION 1: CORE CLASS AND FUNCTION DEFINITIONS
20
+ # To make this app self-contained, we copy necessary class definitions here.
21
+ # These should match the versions used during training.
22
+ # --------------------------------------------------------------------------
23
+
24
+ # --- Vocabulary Definition (from both scripts) ---
25
+ AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
26
+ token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
27
+ token2id["<PAD>"] = 0
28
+ token2id["<EOS>"] = 1
29
+ id2token = {i: t for t, i in token2id.items()}
30
+ VOCAB_SIZE = len(token2id)
31
+
32
+ # --- Predictor Model Architecture (from antioxidant_predictor_5.py) ---
33
+ class AntioxidantPredictor(nn.Module):
34
+ # This class definition should be an exact copy from your project
35
+ def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
36
+ super(AntioxidantPredictor, self).__init__()
37
+ self.input_dim = input_dim
38
+ self.t5_dim = 1024
39
+ self.hand_crafted_dim = self.input_dim - self.t5_dim
40
+
41
+ encoder_layer = nn.TransformerEncoderLayer(
42
+ d_model=self.t5_dim, nhead=transformer_heads,
43
+ dropout=transformer_dropout, batch_first=True
44
+ )
45
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
46
+
47
+ self.mlp = nn.Sequential(
48
+ nn.Linear(self.input_dim, 512),
49
+ nn.ReLU(),
50
+ nn.Dropout(0.5),
51
+ nn.Linear(512, 256),
52
+ nn.ReLU(),
53
+ nn.Dropout(0.5),
54
+ nn.Linear(256, 1)
55
+ )
56
+ self.temperature = nn.Parameter(torch.ones(1))
57
+
58
+ def forward(self, fused_features):
59
+ tr_features = fused_features[:, :self.t5_dim]
60
+ hand_features = fused_features[:, self.t5_dim:]
61
+ tr_features_unsqueezed = tr_features.unsqueeze(1)
62
+ transformer_output = self.transformer_encoder(tr_features_unsqueezed)
63
+ transformer_output_pooled = transformer_output.mean(dim=1)
64
+ combined_features = torch.cat((transformer_output_pooled, hand_features), dim=1)
65
+ logits = self.mlp(combined_features)
66
+ return logits / self.temperature
67
+
68
+ def get_temperature(self):
69
+ return self.temperature.item()
70
+
71
+ # --- Generator Model Architecture (from generator.py) ---
72
+ class ProtT5Generator(nn.Module):
73
+ # This class definition should be an exact copy from your project
74
+ def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
75
+ super(ProtT5Generator, self).__init__()
76
+ self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=token2id["<PAD>"])
77
+ encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True)
78
+ self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
79
+ self.lm_head = nn.Linear(embed_dim, vocab_size)
80
+ self.vocab_size = vocab_size
81
+ self.eos_token_id = token2id["<EOS>"]
82
+ self.pad_token_id = token2id["<PAD>"]
83
+
84
+ def forward(self, input_ids):
85
+ embeddings = self.embed_tokens(input_ids)
86
+ encoder_output = self.encoder(embeddings)
87
+ logits = self.lm_head(encoder_output)
88
+ return logits
89
+
90
+ def sample(self, batch_size, max_length=20, device="cpu", temperature=2.5, min_decoded_length=3):
91
+ start_token = torch.randint(2, self.vocab_size, (batch_size, 1), device=device)
92
+ generated = start_token
93
+ for _ in range(max_length - 1):
94
+ logits = self.forward(generated)
95
+ next_logits = logits[:, -1, :] / temperature
96
+ if generated.size(1) < min_decoded_length:
97
+ next_logits[:, self.eos_token_id] = -float("inf")
98
+
99
+ probs = torch.softmax(next_logits, dim=-1)
100
+ next_token = torch.multinomial(probs, num_samples=1)
101
+ generated = torch.cat((generated, next_token), dim=1)
102
+
103
+ if (next_token == self.eos_token_id).all():
104
+ break
105
+ return generated
106
+
107
+ def decode(self, token_ids_batch):
108
+ seqs = []
109
+ for ids_tensor in token_ids_batch:
110
+ seq = ""
111
+ # Start from index 1 to skip the initial random start token
112
+ for token_id in ids_tensor.tolist()[1:]:
113
+ if token_id == self.eos_token_id: break
114
+ if token_id == self.pad_token_id: continue
115
+ seq += id2token.get(token_id, "?")
116
+ seqs.append(seq)
117
+ return seqs
118
+
119
+ # --- Feature Extraction Logic (from feature_extract.py) ---
120
+ # Note: You need the actual ProtT5Model and extract_features here.
121
+ # Assuming they are in a file named `feature_extract.py` in the same directory.
122
+ try:
123
+ from feature_extract import ProtT5Model as FeatureProtT5Model, extract_features
124
+ except ImportError:
125
+ raise gr.Error("Failed to import feature_extract.py. Please ensure the file is in the same directory as app.py.")
126
+
127
+ # --- Clustering Logic (from generator.py) ---
128
+ def cluster_sequences(generator, sequences, num_clusters, device):
129
+ if not sequences or len(sequences) < num_clusters:
130
+ return sequences[:num_clusters]
131
+ with torch.no_grad():
132
+ token_ids_list = []
133
+ max_len = max(len(seq) for seq in sequences) + 2 # Start token + EOS
134
+ for seq in sequences:
135
+ # Recreate encoding to match how generator sees it (with start token)
136
+ ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
137
+ ids = [np.random.randint(2, VOCAB_SIZE)] + ids # Add a dummy start token
138
+ ids += [token2id["<PAD>"]] * (max_len - len(ids))
139
+ token_ids_list.append(ids)
140
+
141
+ input_ids = torch.tensor(token_ids_list, dtype=torch.long, device=device)
142
+ embeddings = generator.embed_tokens(input_ids)
143
+ mask = (input_ids != token2id["<PAD>"]).unsqueeze(-1).float()
144
+ embeddings = embeddings * mask
145
+ lengths = mask.sum(dim=1)
146
+ seq_embeds = embeddings.sum(dim=1) / (lengths + 1e-9)
147
+ seq_embeds_np = seq_embeds.cpu().numpy()
148
+
149
+ kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init='auto').fit(seq_embeds_np)
150
+ representatives = []
151
+ for i in range(int(num_clusters)):
152
+ indices = np.where(kmeans.labels_ == i)[0]
153
+ if len(indices) == 0: continue
154
+ cluster_center = kmeans.cluster_centers_[i]
155
+ cluster_embeddings = seq_embeds_np[indices]
156
+ distances = np.linalg.norm(cluster_embeddings - cluster_center, axis=1)
157
+ representative_index = indices[np.argmin(distances)]
158
+ representatives.append(sequences[representative_index])
159
+ return representatives
160
+
161
+
162
+ # --------------------------------------------------------------------------
163
+ # SECTION 2: GLOBAL MODEL LOADING
164
+ # Load all models and dependencies once when the app starts.
165
+ # --------------------------------------------------------------------------
166
+ print("Loading all models and dependencies. Please wait...")
167
+ DEVICE = "cpu" # Use CPU for compatibility with Hugging Face free tier
168
+
169
+ try:
170
+ # --- Define all required file paths here ---
171
+ # !! IMPORTANT: Ensure these are relative paths to the files in your Space !!
172
+ PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
173
+ SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
174
+ GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
175
+ PROTT5_BASE_MODEL_PATH = "prott5/model/"
176
+ FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
177
+
178
+ # --- Load Predictor Components ---
179
+ print("Loading Predictor Model...")
180
+ PREDICTOR_MODEL = AntioxidantPredictor(
181
+ input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1
182
+ )
183
+ PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
184
+ PREDICTOR_MODEL.to(DEVICE)
185
+ PREDICTOR_MODEL.eval()
186
+ print("✅ Predictor model loaded.")
187
+
188
+ print("Loading Scaler...")
189
+ SCALER = joblib.load(SCALER_PATH)
190
+ print("✅ Scaler loaded.")
191
+
192
+ print("Loading ProtT5 Feature Extractor...")
193
+ # This extractor must use the fine-tuned model for features, as per your training logic
194
+ PROTT5_EXTRACTOR = FeatureProtT5Model(
195
+ model_path=PROTT5_BASE_MODEL_PATH,
196
+ finetuned_model_file=FINETUNED_PROTT5_FOR_FEATURES_PATH
197
+ )
198
+ print("✅ ProtT5 Feature Extractor loaded.")
199
+
200
+ # --- Load Generator Model ---
201
+ print("Loading Generator Model...")
202
+ GENERATOR_MODEL = ProtT5Generator(
203
+ vocab_size=VOCAB_SIZE, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1
204
+ )
205
+ GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
206
+ GENERATOR_MODEL.to(DEVICE)
207
+ GENERATOR_MODEL.eval()
208
+ print("✅ Generator model loaded.")
209
+
210
+ print("\n--- All models loaded successfully! Gradio app is ready. ---\n")
211
+
212
+ except Exception as e:
213
+ print(f"💥 FATAL ERROR: Failed to load a model or dependency file: {e}")
214
+ raise gr.Error(f"Model or dependency loading failed! Check file paths and integrity. Error: {e}")
215
+
216
+ # --------------------------------------------------------------------------
217
+ # SECTION 3: WRAPPER FUNCTIONS FOR GRADIO
218
+ # These functions connect the UI to our model's logic.
219
+ # --------------------------------------------------------------------------
220
+
221
+ def predict_peptide_wrapper(sequence_str):
222
+ """Takes a peptide sequence string and returns its predicted probability and class."""
223
+ if not sequence_str or not isinstance(sequence_str, str) or any(c not in AMINO_ACIDS for c in sequence_str.upper()):
224
+ return "0.0000", "Error: Please enter a valid sequence with standard amino acids."
225
+
226
+ try:
227
+ # 1. Extract features using the same logic as training/prediction scripts
228
+ features = extract_features(sequence_str, PROTT5_EXTRACTOR)
229
+
230
+ # 2. Scale features
231
+ scaled_features = SCALER.transform(features.reshape(1, -1))
232
+
233
+ # 3. Predict with the model
234
+ with torch.no_grad():
235
+ features_tensor = torch.tensor(scaled_features, dtype=torch.float32).to(DEVICE)
236
+ logits = PREDICTOR_MODEL(features_tensor)
237
+ probability = torch.sigmoid(logits).squeeze().item()
238
+
239
+ classification = "Antioxidant" if probability >= 0.5 else "Non-Antioxidant"
240
+ return f"{probability:.4f}", classification
241
+
242
+ except Exception as e:
243
+ print(f"Prediction error for sequence '{sequence_str}': {e}")
244
+ return "N/A", f"An error occurred during processing: {e}"
245
+
246
+ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
247
+ """Generates, validates, and clusters sequences."""
248
+ num_to_generate = int(num_to_generate)
249
+ min_len = int(min_len)
250
+ max_len = int(max_len)
251
+
252
+ try:
253
+ # STEP 1: Generate an initial pool of unique sequences
254
+ target_pool_size = int(num_to_generate * diversity_factor)
255
+ unique_seqs = set()
256
+ progress(0, desc="Generating initial peptide pool...")
257
+
258
+ max_attempts = 10
259
+ attempts = 0
260
+ while len(unique_seqs) < target_pool_size and attempts < max_attempts:
261
+ batch_size = (target_pool_size - len(unique_seqs)) * 2 # Generate extra to account for duplicates/short ones
262
+ with torch.no_grad():
263
+ generated_tokens = GENERATOR_MODEL.sample(
264
+ batch_size=batch_size,
265
+ max_length=max_len,
266
+ device=DEVICE,
267
+ temperature=temperature,
268
+ min_decoded_length=min_len
269
+ )
270
+ decoded = GENERATOR_MODEL.decode(generated_tokens.cpu())
271
+ for seq in decoded:
272
+ if min_len <= len(seq) <= max_len:
273
+ unique_seqs.add(seq)
274
+ attempts += 1
275
+ progress(len(unique_seqs) / target_pool_size, desc=f"Generated {len(unique_seqs)} unique sequences...")
276
+
277
+ candidate_seqs = list(unique_seqs)
278
+ if not candidate_seqs:
279
+ return pd.DataFrame({"Sequence": ["Failed to generate valid sequences."], "Predicted Probability": ["N/A"]})
280
+
281
+ # STEP 2: Validate the generated sequences
282
+ validated_pool = {}
283
+ for seq in tqdm(candidate_seqs, desc="Validating generated sequences"):
284
+ prob_str, _ = predict_peptide_wrapper(seq)
285
+ try:
286
+ prob = float(prob_str)
287
+ if prob > 0.90: # Filter for high-quality peptides as in generator.py
288
+ validated_pool[seq] = prob
289
+ except (ValueError, TypeError):
290
+ continue
291
+
292
+ if not validated_pool:
293
+ return pd.DataFrame({"Sequence": ["No high-activity peptides (>0.9 prob) were generated."], "Predicted Probability": ["N/A"]})
294
+
295
+ high_quality_sequences = list(validated_pool.keys())
296
+
297
+ # STEP 3: Cluster to ensure diversity
298
+ progress(1.0, desc="Clustering for diversity...")
299
+ final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)
300
+
301
+ # STEP 4: Format final results
302
+ final_results = [(seq, f"{validated_pool[seq]:.4f}") for seq in final_diverse_seqs]
303
+ final_results.sort(key=lambda x: float(x[1]), reverse=True)
304
+
305
+ return pd.DataFrame(final_results, columns=["Sequence", "Predicted Probability"])
306
+
307
+ except Exception as e:
308
+ print(f"Generation error: {e}")
309
+ return pd.DataFrame({"Sequence": [f"An error occurred during generation: {e}"], "Predicted Probability": ["N/A"]})
310
+
311
+
312
+ # --------------------------------------------------------------------------
313
+ # SECTION 4: GRADIO UI CONSTRUCTION
314
+ # Building the web interface. All text is in English.
315
+ # --------------------------------------------------------------------------
316
+ with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
317
+ gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction Platform")
318
+ gr.Markdown("An integrated framework combining reinforcement learning and a Transformer model for the efficient prediction and innovative design of antioxidant peptides.")
319
+
320
+ with gr.Tabs():
321
+ with gr.TabItem("Peptide Activity Predictor"):
322
+ gr.Markdown("### Enter an amino acid sequence to predict its antioxidant activity.")
323
+ with gr.Row():
324
+ peptide_input = gr.Textbox(label="Peptide Sequence", placeholder="e.g., WHYHDYKY", scale=3)
325
+ predict_button = gr.Button("Predict", variant="primary", scale=1)
326
+ with gr.Row():
327
+ probability_output = gr.Textbox(label="Predicted Probability")
328
+ class_output = gr.Textbox(label="Predicted Class")
329
+
330
+ predict_button.click(
331
+ fn=predict_peptide_wrapper,
332
+ inputs=peptide_input,
333
+ outputs=[probability_output, class_output]
334
+ )
335
+ gr.Examples(
336
+ examples=[["WHYHDYKY"], ["YPGG"], ["LVLHEHGGN"], ["INVALIDSEQUENCE"]],
337
+ inputs=peptide_input,
338
+ outputs=[probability_output, class_output],
339
+ fn=predict_peptide_wrapper,
340
+ cache_examples=False,
341
+ )
342
+
343
+ with gr.TabItem("Novel Sequence Generator"):
344
+ gr.Markdown("### Set parameters to generate novel, high-activity antioxidant peptides.")
345
+ with gr.Column():
346
+ with gr.Row():
347
+ num_input = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Number of Final Peptides to Generate")
348
+ min_len_input = gr.Slider(minimum=2, maximum=10, value=3, step=1, label="Minimum Length")
349
+ max_len_input = gr.Slider(minimum=10, maximum=20, value=20, step=1, label="Maximum Length")
350
+ with gr.Row():
351
+ temp_input = gr.Slider(minimum=0.5, maximum=3.0, value=2.5, step=0.1, label="Temperature (Higher = More random)")
352
+ diversity_input = gr.Slider(minimum=1.0, maximum=3.0, value=1.2, step=0.1, label="Diversity Factor (Higher = Larger initial pool for clustering)")
353
+
354
+ generate_button = gr.Button("Generate Peptides", variant="primary")
355
+ results_output = gr.DataFrame(headers=["Sequence", "Predicted Probability"], label="Generated & Validated Peptides", wrap=True)
356
+
357
+ generate_button.click(
358
+ fn=generate_peptide_wrapper,
359
+ inputs=[num_input, min_len_input, max_len_input, temp_input, diversity_input],
360
+ outputs=results_output
361
+ )
362
+
363
+ if __name__ == "__main__":
364
+ demo.launch()
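
Once this app runs as a Hugging Face Space, the prediction tab can also be called programmatically with gradio_client. The Space id and api_name below are placeholders (check the Space's "Use via API" panel for the exact endpoint name); they are assumptions for illustration, not values defined in this repository.

from gradio_client import Client

client = Client("user-name/space-name")  # placeholder Space id
probability, label = client.predict(
    "WHYHDYKY",                           # peptide sequence input
    api_name="/predict_peptide_wrapper"   # assumed default endpoint name for predict_peptide_wrapper
)
print(probability, label)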
checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1cb2c7cf6a59a5d028af4678ac5aec83305d3638bf9f43fee6df14f162f4a4e6
3
+ size 9930081
checkpoints/scaler_FINETUNED_PROTT5.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50cfe7b0204a4e1dd5f22dc1440d6e8448f7ee323bb20957c70ce1058ca78475
3
+ size 31127
feature_extract.py ADDED
@@ -0,0 +1,341 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ import numpy as np
5
+ import random
6
+ import pandas as pd
7
+ from Bio.SeqUtils.ProtParam import ProteinAnalysis
8
+ from sklearn.model_selection import train_test_split
9
+ # from sklearn.preprocessing import StandardScaler # StandardScaler is no longer used
10
+ from sklearn.preprocessing import RobustScaler # use RobustScaler instead
11
+ import torch
12
+ from transformers import T5EncoderModel, T5Tokenizer
13
+
14
+ # ProtT5Model, load_fasta, load_fasta_with_labels,
15
+ # compute_amino_acid_composition, compute_reducing_aa_ratio,
16
+ # compute_physicochemical_properties, compute_electronic_features,
17
+ # compute_dimer_frequency, positional_encoding, perturb_sequence,
18
+ # generate_adversarial_samples and extract_features are identical to the previously provided version.
19
+ # Their code is omitted here for brevity; make sure the full implementations are present in your file.
20
+ # You can copy these functions from an earlier log or from your local file.
21
+ # Below is the revised prepare_features function, plus placeholders for the other functions.
22
+
23
+ class ProtT5Model:
24
+ """
25
+ Load the ProtT5 model from a local path. If finetuned_model_file is not None, load the fine-tuned weights on top (with strict=False).
26
+ """
27
+ def __init__(self, model_path, finetuned_model_file=None):
28
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ # Try to load local files first; if that fails, the transformers library may try to download from the Hub (depending on configuration)
30
+ try:
31
+ self.tokenizer = T5Tokenizer.from_pretrained(model_path, do_lower_case=False, local_files_only=True)
32
+ self.model = T5EncoderModel.from_pretrained(model_path, local_files_only=True)
33
+ except OSError: # raised when the local directory does not contain a valid T5 tokenizer/model
34
+ print(f"Warning: failed to load the ProtT5 model/tokenizer from local path {model_path}. Trying to download from the HuggingFace Hub (if the transformers configuration allows it).")
35
+ self.tokenizer = T5Tokenizer.from_pretrained(model_path.split('/')[-1] if '/' in model_path else model_path, do_lower_case=False) # try downloading by model name
36
+ self.model = T5EncoderModel.from_pretrained(model_path.split('/')[-1] if '/' in model_path else model_path)
37
+
38
+
39
+ if finetuned_model_file is not None and os.path.exists(finetuned_model_file):
40
+ try:
41
+ state_dict = torch.load(finetuned_model_file, map_location=self.device)
42
+ missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
43
+ print(f"Loaded fine-tuned weights {finetuned_model_file}: missing keys {missing_keys}, unexpected keys {unexpected_keys}")
44
+ except Exception as e:
45
+ print(f"Failed to load fine-tuned weights {finetuned_model_file}: {e}")
46
+
47
+ self.model.to(self.device)
48
+ self.model.eval()
49
+
50
+ def encode(self, sequence):
51
+ if not sequence or not isinstance(sequence, str): # guard against empty or non-string sequences
52
+ print(f"Warning: ProtT5Model.encode received an invalid sequence: {sequence}")
53
+ # Return a zero vector, or handle the error differently as needed.
54
+ # The ProtT5 output dimension is assumed to be 1024 (embedding.shape[1]).
55
+ # After mean pooling the caller would get a (1024,) vector,
56
+ # but encode returns (seq_len, hidden_dim), so a dummy one-row zero embedding is returned here.
57
+ return np.zeros((1, 1024), dtype=np.float32) # (1, hidden_dim)
58
+
59
+ seq_spaced = " ".join(list(sequence)) # renamed to avoid shadowing the outer seq
60
+ try:
61
+ encoded_input = self.tokenizer(seq_spaced, return_tensors='pt', padding=True, truncation=True, max_length=1022) # ProtT5 usually caps at 1024 positions; the tokenized sequence can be longer than the raw one
62
+ except Exception as e:
63
+ print(f"Tokenization failed for sequence '{sequence[:30]}...': {e}")
64
+ return np.zeros((1, 1024), dtype=np.float32)
65
+
66
+ encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
67
+ with torch.no_grad():
68
+ try:
69
+ embedding = self.model(**encoded_input).last_hidden_state # (batch_size, seq_len, hidden_dim)
70
+ except Exception as e:
71
+ print(f"ProtT5 inference failed for sequence '{sequence[:30]}...': {e}")
72
+ return np.zeros((1, 1024), dtype=np.float32)
73
+
74
+ emb = embedding.squeeze(0).cpu().numpy() # (seq_len, hidden_dim)
75
+ if emb.shape[0] == 0: # in case the sequence length ended up being 0 for some reason
76
+ return np.zeros((1, 1024), dtype=np.float32)
77
+ return emb
78
+
79
+ # --- (All the other feature-extraction helper functions from your previous version belong here) ---
80
+ # load_fasta, load_fasta_with_labels, compute_amino_acid_composition, ... extract_features
81
+ # For completeness, copy these functions here from your local feature_extract.py file.
82
+ # Below is a simplified placeholder for each of them; replace them with the real implementations.
83
+
84
+ def load_fasta(fasta_file):
85
+ # (your load_fasta implementation)
86
+ sequences = []
87
+ try:
88
+ with open(fasta_file, 'r') as f:
89
+ current_seq_lines = []
90
+ for line in f:
91
+ line = line.strip()
92
+ if not line: continue
93
+ if line.startswith(">"):
94
+ if current_seq_lines: sequences.append("".join(current_seq_lines))
95
+ current_seq_lines = []
96
+ else: current_seq_lines.append(line)
97
+ if current_seq_lines: sequences.append("".join(current_seq_lines))
98
+ except FileNotFoundError: print(f"File not found: {fasta_file}"); return []
99
+ return sequences
100
+
101
+ def load_fasta_with_labels(fasta_file):
102
+ # (your load_fasta_with_labels implementation)
103
+ sequences, labels = [], []
104
+ try:
105
+ with open(fasta_file, 'r') as f:
106
+ current_seq_lines, current_label = [], None
107
+ for line in f:
108
+ line = line.strip()
109
+ if not line: continue
110
+ if line.startswith(">"):
111
+ if current_seq_lines:
112
+ sequences.append("".join(current_seq_lines))
113
+ labels.append(current_label if current_label is not None else 0) # Default label 0
114
+ current_seq_lines = []
115
+ current_label = int(line[1]) if len(line) > 1 and line[1] in ['0', '1'] else 0
116
+ else: current_seq_lines.append(line)
117
+ if current_seq_lines:
118
+ sequences.append("".join(current_seq_lines))
119
+ labels.append(current_label if current_label is not None else 0)
120
+ except FileNotFoundError: print(f"File not found: {fasta_file}"); return [],[]
121
+ return sequences, labels
122
+
123
+
124
+ def compute_amino_acid_composition(seq):
125
+ if not seq: return {aa: 0.0 for aa in "ACDEFGHIKLMNPQRSTVWY"}
126
+ # (your compute_amino_acid_composition implementation)
127
+ amino_acids = "ACDEFGHIKLMNPQRSTVWY"
128
+ seq_len = len(seq)
129
+ return {aa: seq.upper().count(aa) / seq_len for aa in amino_acids}
130
+
131
+
132
+ def compute_reducing_aa_ratio(seq):
133
+ if not seq: return 0.0
134
+ # (your compute_reducing_aa_ratio implementation)
135
+ reducing = ['C', 'M', 'W']
136
+ return sum(seq.upper().count(aa) for aa in reducing) / len(seq) if len(seq) > 0 else 0.0
137
+
138
+ def compute_physicochemical_properties(seq):
139
+ if not seq or not all(c.upper() in "ACDEFGHIKLMNPQRSTVWYXUBZ" for c in seq): # ProteinAnalysis might fail on invalid chars
140
+ return 0.0, 0.0, 0.0 # Default values
141
+ try:
142
+ analysis = ProteinAnalysis(str(seq).upper().replace('X','A').replace('U','C').replace('B','N').replace('Z','Q')) # Replace non-standard with common ones for analysis
143
+ return analysis.gravy(), analysis.isoelectric_point(), analysis.molecular_weight()
144
+ except Exception: # Catch any error from ProteinAnalysis
145
+ return 0.0, 7.0, 110.0 * len(seq) # Rough defaults
146
+
147
+ def compute_electronic_features(seq):
148
+ if not seq: return 0.0, 0.0
149
+ # (your compute_electronic_features implementation)
150
+ electronegativity = {'A':1.8,'C':2.5,'D':3.0,'E':3.2,'F':2.8,'G':1.6,'H':2.4,'I':4.5,'K':3.0,'L':4.2,'M':4.5,'N':2.0,'P':3.5,'Q':3.5,'R':2.5,'S':1.8,'T':2.5,'V':4.0,'W':5.0,'Y':4.0}
151
+ values = [electronegativity.get(aa.upper(), 2.5) for aa in seq]
152
+ avg_val = sum(values) / len(values) if values else 2.5
153
+ return avg_val + 0.1, avg_val - 0.1
154
+
155
+
156
+ def compute_dimer_frequency(seq):
157
+ if len(seq) < 2: return np.zeros(400) # 20*20
158
+ # (your compute_dimer_frequency implementation)
159
+ amino_acids = "ACDEFGHIKLMNPQRSTVWY"
160
+ dimer_counts = {aa1+aa2: 0 for aa1 in amino_acids for aa2 in amino_acids}
161
+ for i in range(len(seq) - 1):
162
+ dimer = seq[i:i+2].upper()
163
+ if dimer in dimer_counts: dimer_counts[dimer] += 1
164
+ total = max(len(seq) - 1, 1)
165
+ for key in dimer_counts: dimer_counts[key] /= total
166
+ return np.array([dimer_counts[d] for d in sorted(dimer_counts.keys())])
167
+
168
+
169
+ def positional_encoding(seq_len_actual, L_fixed=29, d_model=16): # Pass actual sequence length or use L_fixed
170
+ # (your positional_encoding implementation)
171
+ # This PE is fixed length, not dependent on actual seq len if L_fixed is used.
172
+ # For random short sequences, this fixed PE might be an issue.
173
+ # A more dynamic PE or no PE for very short sequences might be better.
174
+ # However, to match current model input, we keep it.
175
+ pos_enc = np.zeros((L_fixed, d_model))
176
+ for pos in range(L_fixed):
177
+ for i in range(d_model):
178
+ angle = pos / (10000 ** (2 * (i // 2) / d_model))
179
+ pos_enc[pos, i] = np.sin(angle) if i % 2 == 0 else np.cos(angle)
180
+ return pos_enc.flatten()
181
+
182
+
183
+ def perturb_sequence(seq, perturb_rate=0.1, critical=['C', 'M', 'W']):
184
+ # (your perturb_sequence implementation)
185
+ if not seq: return ""
186
+ seq_list = list(seq)
187
+ amino_acids = "ACDEFGHIKLMNPQRSTVWY"
188
+ for i, aa in enumerate(seq_list):
189
+ if aa.upper() not in critical and random.random() < perturb_rate:
190
+ seq_list[i] = random.choice([x for x in amino_acids if x != aa.upper()])
191
+ return "".join(seq_list)
192
+
193
+
194
+ def extract_features(seq, prott5_model_instance, L_fixed=29, d_model_pe=16): # Renamed d_model to d_model_pe
195
+ if not seq or not isinstance(seq, str) or len(seq) == 0:
196
+ print(f"Warning: extract_features received an empty or invalid sequence. Returning zero features.")
197
+ # Return a zero vector whose length matches the expected feature dimension
198
+ # 1024 (protT5) + 20 (aac) + 1 (red_ratio) + 3 (phys) + 2 (elec) + 400 (dimer) + L_fixed*d_model_pe (pos_enc)
199
+ # Example: 1024 + 20 + 1 + 3 + 2 + 400 + 29*16 = 1024 + 20 + 1 + 3 + 2 + 400 + 464 = 1914
200
+ return np.zeros(1024 + 20 + 1 + 3 + 2 + 400 + (L_fixed * d_model_pe))
201
+
202
+
203
+ embedding = prott5_model_instance.encode(seq) # prott5_model is now an instance
204
+ prot_embed = np.mean(embedding, axis=0) if embedding.shape[0] > 0 else np.zeros(embedding.shape[1] if embedding.ndim > 1 else 1024) # Handle empty embedding
205
+ if prot_embed.shape[0] != 1024: # Ensure consistent ProtT5 embedding dim
206
+ # print(f"Warning: unexpected ProtT5 embedding dimension ({prot_embed.shape[0]}) for seq '{seq[:20]}...'. Using a zero vector.")
207
+ prot_embed = np.zeros(1024)
208
+
209
+
210
+ aa_comp = compute_amino_acid_composition(seq)
211
+ aa_comp_vector = np.array([aa_comp[aa] for aa in "ACDEFGHIKLMNPQRSTVWY"])
212
+ red_ratio = np.array([compute_reducing_aa_ratio(seq)])
213
+ gravy, pI, mol_weight = compute_physicochemical_properties(seq)
214
+ phys_props = np.array([gravy, pI, mol_weight])
215
+ HOMO, LUMO = compute_electronic_features(seq)
216
+ electronic = np.array([HOMO, LUMO])
217
+ dimer_vector = compute_dimer_frequency(seq)
218
+ pos_enc = positional_encoding(len(seq), L_fixed, d_model_pe) # Pass actual length, though current PE uses L_fixed
219
+
220
+ features = np.concatenate([prot_embed, aa_comp_vector, red_ratio, phys_props, electronic, dimer_vector, pos_enc])
221
+ return features
222
+
223
+ ##############################################
224
+ # Main entry point: prepare_features
225
+ ##############################################
226
+ def prepare_features(neg_fasta, pos_fasta, prott5_model_path, additional_params=None):
227
+ neg_seqs = load_fasta(neg_fasta)
228
+ pos_seqs = load_fasta(pos_fasta)
229
+
230
+ if not neg_seqs and not pos_seqs:
231
+ raise ValueError("No sequences could be loaded from the FASTA files. Check the file paths and contents.")
232
+
233
+ neg_labels = [0] * len(neg_seqs)
234
+ pos_labels = [1] * len(pos_seqs)
235
+ sequences = neg_seqs + pos_seqs
236
+ labels = neg_labels + pos_labels
237
+
238
+ combined = list(zip(sequences, labels))
239
+ random.shuffle(combined)
240
+ sequences, labels = zip(*combined)
241
+ sequences = list(sequences)
242
+ labels = list(labels)
243
+
244
+ train_seqs, val_seqs, train_labels, val_labels = train_test_split(
245
+ sequences, labels, test_size=0.1, random_state=42, stratify=labels if len(np.unique(labels)) > 1 else None
246
+ )
247
+ print("Original number of training samples:", len(train_seqs))
248
+ print("Original number of validation samples:", len(val_seqs))
249
+
250
+ if additional_params is not None and additional_params.get("augment", False):
251
+ # (数据增强逻辑 - 如果启用)
252
+ augmented_seqs, augmented_labels = [], []
253
+ perturb_rate = additional_params.get("perturb_rate", 0.1)
254
+ for seq, label in zip(train_seqs, train_labels):
255
+ aug_seq = perturb_sequence(seq, perturb_rate=perturb_rate)
256
+ augmented_seqs.append(aug_seq)
257
+ augmented_labels.append(label)
258
+ train_seqs.extend(augmented_seqs)
259
+ train_labels.extend(augmented_labels)
260
+ print("Number of training samples after augmentation:", len(train_seqs))
261
+
262
+
263
+ finetuned_model_file = additional_params.get("finetuned_model_file") if additional_params else None
264
+ # Create the ProtT5Model instance
265
+ prott5_model_instance = ProtT5Model(prott5_model_path, finetuned_model_file=finetuned_model_file)
266
+
267
+ def process_data(seqs_list): # Renamed seqs to seqs_list
268
+ feature_list = []
269
+ for s_item in seqs_list: # Renamed s to s_item
270
+ # Pass the ProtT5Model instance to extract_features
271
+ features = extract_features(s_item, prott5_model_instance)
272
+ feature_list.append(features)
273
+ return np.array(feature_list)
274
+
275
+ X_train = process_data(train_seqs)
276
+ X_val = process_data(val_seqs)
277
+
278
+ if X_train.shape[0] == 0 or X_val.shape[0] == 0:
279
+ raise ValueError("The training or validation set is empty after feature extraction. Check the sequence data and the feature-extraction step.")
280
+
281
+
282
+ # --- Key change: use RobustScaler ---
283
+ # scaler = StandardScaler() # the previous StandardScaler
284
+ scaler = RobustScaler()
285
+ print("Using RobustScaler for feature normalization.")
286
+
287
+ X_train_scaled = scaler.fit_transform(X_train)
288
+ X_val_scaled = scaler.transform(X_val)
289
+
290
+ return X_train_scaled, X_val_scaled, np.array(train_labels), np.array(val_labels), scaler
291
+
292
+ if __name__ == "__main__":
293
+ # Make sure the paths used for testing are valid, or create dummy files
294
+ neg_fasta_test = "dummy_data/test_neg.fasta"
295
+ pos_fasta_test = "dummy_data/test_pos.fasta"
296
+ prott5_path_test = "dummy_prott5_model/" # needs a directory structure containing config.json, pytorch_model.bin, etc.
297
+
298
+ os.makedirs("dummy_data", exist_ok=True)
299
+ os.makedirs(prott5_path_test, exist_ok=True) # create the dummy model directory
300
+
301
+ if not os.path.exists(neg_fasta_test):
302
+ with open(neg_fasta_test, "w") as f: f.write(">neg1\nKALKALKALK\n>neg2\nPEPTPEPT\n")
303
+ if not os.path.exists(pos_fasta_test):
304
+ with open(pos_fasta_test, "w") as f: f.write(">pos1\nAOPPEPTIDE\n>pos2\nTRYTRYTRY\n")
305
+
306
+ # For ProtT5Model to load, a minimal transformers model directory structure has to be mocked.
307
+ # It usually needs at least config.json, pytorch_model.bin (or tf_model.h5), tokenizer_config.json and spiece.model.
308
+ # Here we only create the directory, so the actual load may fail unless the transformers library can download the model by name,
309
+ # or unless you provide a real local ProtT5 model path.
310
+ if not os.listdir(prott5_path_test): # if the directory is empty
311
+ print(f"Warning: {prott5_path_test} is empty. ProtT5Model may try to download the model from the HuggingFace Hub.")
312
+ print(f"Make sure you have downloaded Rostlab/ProstT5-XL-UniRef50 (or a similar model) to this path, or use its HuggingFace name.")
313
+ # For this demo we assume the user provides a valid path or that transformers can handle it.
314
+ # To run fully locally without downloading, the directory has to be populated.
315
+
316
+ additional_params_test = {
317
+ "augment": False,
318
+ "perturb_rate": 0.1,
319
+ "finetuned_model_file": None # do not use the fine-tuned model for this test
320
+ }
321
+
322
+ print("Starting the prepare_features test (with RobustScaler)...")
323
+ try:
324
+ X_train_t, X_val_t, y_train_t, y_val_t, scaler_t = prepare_features(
325
+ neg_fasta_test,
326
+ pos_fasta_test,
327
+ "Rostlab/ProstT5-XL-UniRef50", # use the HuggingFace model name if the local path is not valid
328
+ additional_params_test
329
+ )
330
+ print("prepare_features test finished.")
331
+ print("Number of training samples:", X_train_t.shape[0])
332
+ print("Number of validation samples:", X_val_t.shape[0])
333
+ if X_train_t.shape[0] > 0:
334
+ print("Training feature dimension:", X_train_t.shape[1])
335
+ print("One scaled training sample (first 5 features):", X_train_t[0, :5])
336
+ if scaler_t:
337
+ print("Scaler type:", type(scaler_t))
338
+ except Exception as e:
339
+ print(f"prepare_features test failed: {e}")
340
+ print("This is likely caused by a ProtT5 model that could not be loaded or by a FASTA-file issue. Check the paths and file contents.")
341
+
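
The comment inside extract_features spells out how the 1914-dimensional predictor input is assembled. The illustrative check below only restates that arithmetic; the block names are descriptive labels, not identifiers from the code.

# Feature layout behind input_dim=1914 (dimensions taken from the extract_features comments)
feature_blocks = {
    "ProtT5 embedding (mean-pooled)": 1024,
    "amino-acid composition": 20,
    "reducing-residue ratio": 1,
    "GRAVY / pI / molecular weight": 3,
    "electronic features (HOMO/LUMO proxies)": 2,
    "dipeptide frequencies": 400,
    "positional encoding (29 x 16)": 29 * 16,
}
total_dim = sum(feature_blocks.values())
assert total_dim == 1914  # matches AntioxidantPredictor(input_dim=1914) and the fitted scaler's expected width
print(total_dim)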
generator_checkpoints_v3.6/final_generator_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d878934b93514142df36cb04c812af25ad53a79bdfa0a3edd9f65c22ba4e1d48
3
+ size 75775354
prott5/model/config.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "architectures": [
3
+ "T5ForConditionalGeneration"
4
+ ],
5
+ "d_ff": 16384,
6
+ "d_kv": 128,
7
+ "d_model": 1024,
8
+ "decoder_start_token_id": 0,
9
+ "dropout_rate": 0.1,
10
+ "eos_token_id": 1,
11
+ "feed_forward_proj": "relu",
12
+ "initializer_factor": 1.0,
13
+ "is_encoder_decoder": true,
14
+ "layer_norm_epsilon": 1e-06,
15
+ "model_type": "t5",
16
+ "n_positions": 512,
17
+ "num_decoder_layers": 24,
18
+ "num_heads": 32,
19
+ "num_layers": 24,
20
+ "output_past": true,
21
+ "pad_token_id": 0,
22
+ "relative_attention_num_buckets": 32,
23
+ "use_cache": true,
24
+ "vocab_size": 128
25
+ }
prott5/model/finetuned_prott5.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3dbe43382d7c4aa0ba6de0ec02a9680155486a3e8984c82199ae8868223eba9e
3
+ size 75774969
prott5/model/gitattributes ADDED
@@ -0,0 +1,16 @@
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.h5 filter=lfs diff=lfs merge=lfs -text
5
+ *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.arrow filter=lfs diff=lfs merge=lfs -text
10
+ *.ftz filter=lfs diff=lfs merge=lfs -text
11
+ *.joblib filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.pb filter=lfs diff=lfs merge=lfs -text
15
+ *.pt filter=lfs diff=lfs merge=lfs -text
16
+ *.pth filter=lfs diff=lfs merge=lfs -text
prott5/model/special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"]}
prott5/model/spiece.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74da7b4afcde53faa570114b530c726135bdfcdb813dec3abfb27f9d44db7324
3
+ size 237990
prott5/model/tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"do_lower_case": false}
requirements.txt ADDED
@@ -0,0 +1,172 @@
1
+ absl-py @ file:///home/conda/feedstock_root/build_artifacts/absl-py_1733730548347/work
2
+ accelerate==1.2.1
3
+ aiohappyeyeballs @ file:///home/conda/feedstock_root/build_artifacts/aiohappyeyeballs_1733331917983/work
4
+ aiohttp @ file:///home/conda/feedstock_root/build_artifacts/aiohttp_1734596887867/work
5
+ aiosignal @ file:///home/conda/feedstock_root/build_artifacts/aiosignal_1734342155601/work
6
+ alembic @ file:///home/conda/feedstock_root/build_artifacts/alembic_1733728406412/work
7
+ AmberUtils==21.0
8
+ astunparse @ file:///home/conda/feedstock_root/build_artifacts/astunparse_1728923142236/work
9
+ async-timeout @ file:///home/conda/feedstock_root/build_artifacts/async-timeout_1733235340728/work
10
+ attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1734348785146/work
11
+ biopython==1.79
12
+ blinker @ file:///home/conda/feedstock_root/build_artifacts/blinker_1731096409132/work
13
+ boto3==1.35.82
14
+ botocore==1.35.82
15
+ Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1725267488082/work
16
+ cached-property @ file:///home/conda/feedstock_root/build_artifacts/cached_property_1615209429212/work
17
+ cachetools @ file:///home/conda/feedstock_root/build_artifacts/cachetools_1733624221485/work
18
+ catboost @ https://pypi.org/packages/cp39/c/catboost/catboost-1.2.7-cp39-cp39-manylinux2014_x86_64.whl#sha256=e58cf8966e33931acebffbc744cf640e8abd08d0fdfb0e503c107552cea6c643
19
+ certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1739515848642/work/certifi
20
+ cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1725571112467/work
21
+ chardet @ file:///home/conda/feedstock_root/build_artifacts/chardet_1724954807150/work
22
+ charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1733218092148/work
23
+ click @ file:///home/conda/feedstock_root/build_artifacts/click_1733221831880/work
24
+ cloudpickle @ file:///home/conda/feedstock_root/build_artifacts/cloudpickle_1736947526808/work
25
+ cmake==3.31.2
26
+ colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work
27
+ colorlog @ file:///home/conda/feedstock_root/build_artifacts/colorlog_1733258404285/work
28
+ contourpy @ file:///home/conda/feedstock_root/build_artifacts/contourpy_1727293517607/work
29
+ cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography-split_1732745941597/work
30
+ cycler @ file:///home/conda/feedstock_root/build_artifacts/cycler_1733332471406/work
31
+ edgembar==3.0
32
+ filelock==3.16.1
33
+ Flask @ file:///home/conda/feedstock_root/build_artifacts/flask_1741793020411/work
34
+ flatbuffers @ file:///home/conda/feedstock_root/build_artifacts/python-flatbuffers_1733838640534/work
35
+ fonttools @ file:///home/conda/feedstock_root/build_artifacts/fonttools_1733908950378/work
36
+ freetype-py @ file:///home/conda/feedstock_root/build_artifacts/freetype-py_1650983368720/work
37
+ frozenlist @ file:///home/conda/feedstock_root/build_artifacts/frozenlist_1729699456009/work
38
+ fsspec==2024.10.0
39
+ gast @ file:///home/conda/feedstock_root/build_artifacts/gast_1596839682936/work
40
+ google-auth @ file:///home/conda/feedstock_root/build_artifacts/google-auth_1734176111475/work
41
+ google-auth-oauthlib==1.0.0
42
+ google-pasta @ file:///home/conda/feedstock_root/build_artifacts/google-pasta_1733852499742/work
43
+ graphviz @ file:///home/conda/feedstock_root/build_artifacts/python-graphviz_1733791968395/work
44
+ greenlet @ file:///home/conda/feedstock_root/build_artifacts/greenlet_1734532792566/work
45
+ grpcio==1.68.1
46
+ gym @ file:///home/conda/feedstock_root/build_artifacts/gym_1673317490724/work
47
+ gym-notices @ file:///home/conda/feedstock_root/build_artifacts/gym-notices_1734709658541/work
48
+ h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1733298745555/work
49
+ h5py @ file:///home/conda/feedstock_root/build_artifacts/h5py_1734544937668/work
50
+ hpack @ file:///home/conda/feedstock_root/build_artifacts/hpack_1733299205993/work
51
+ huggingface-hub==0.26.5
52
+ hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1733298771451/work
53
+ idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1733211830134/work
54
+ importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1733223117029/work
55
+ importlib_resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1733231327252/work
56
+ itsdangerous @ file:///home/conda/feedstock_root/build_artifacts/itsdangerous_1733308265247/work
57
+ Jinja2==3.1.4
58
+ jmespath==1.0.1
59
+ joblib @ file:///home/conda/feedstock_root/build_artifacts/joblib_1733736026804/work
60
+ keras @ file:///home/conda/feedstock_root/build_artifacts/keras_1669020828024/work/keras-2.11.0-py2.py3-none-any.whl
61
+ Keras-Preprocessing @ file:///home/conda/feedstock_root/build_artifacts/keras-preprocessing_1610713559828/work
62
+ kiwisolver @ file:///home/conda/feedstock_root/build_artifacts/kiwisolver_1725459266648/work
63
+ lightgbm @ file:///home/conda/feedstock_root/build_artifacts/liblightgbm_1728547676427/work
64
+ lit==18.1.8
65
+ llvmlite==0.43.0
66
+ lmdb==1.5.1
67
+ Mako @ file:///home/conda/feedstock_root/build_artifacts/mako_1733628147227/work
68
+ Markdown==3.7
69
+ markdown-it-py @ file:///home/conda/feedstock_root/build_artifacts/markdown-it-py_1733250460757/work
70
+ MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1733219680183/work
71
+ matplotlib==3.9.4
72
+ mdurl @ file:///home/conda/feedstock_root/build_artifacts/mdurl_1733255585584/work
73
+ ml-dtypes @ file:///home/conda/feedstock_root/build_artifacts/ml_dtypes_1725475048616/work
74
+ MMPBSA.py==16.0
75
+ mordred==1.2.0
76
+ mpmath==1.3.0
77
+ multidict @ file:///home/conda/feedstock_root/build_artifacts/multidict_1733913043842/work
78
+ namex @ file:///home/conda/feedstock_root/build_artifacts/namex_1733858070416/work
79
+ ndfes==3.0
80
+ networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1698504735452/work
81
+ numba @ file:///home/conda/feedstock_root/build_artifacts/numba_1718888028049/work
82
+ numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1682210190296/work
83
+ nvidia-cublas-cu11==11.10.3.66
84
+ nvidia-cuda-cupti-cu11==11.7.101
85
+ nvidia-cuda-nvrtc-cu11==11.7.99
86
+ nvidia-cuda-runtime-cu11==11.7.99
87
+ nvidia-cudnn-cu11==8.5.0.96
88
+ nvidia-cufft-cu11==10.9.0.58
89
+ nvidia-curand-cu11==10.2.10.91
90
+ nvidia-cusolver-cu11==11.4.0.1
91
+ nvidia-cusparse-cu11==11.7.4.91
92
+ nvidia-nccl-cu11==2.14.3
93
+ nvidia-nvtx-cu11==11.7.91
94
+ oauthlib @ file:///home/conda/feedstock_root/build_artifacts/oauthlib_1733752848439/work
95
+ opt_einsum @ file:///home/conda/feedstock_root/build_artifacts/opt_einsum_1733687912731/work
96
+ optree @ file:///home/conda/feedstock_root/build_artifacts/optree_1731510511426/work
97
+ optuna @ file:///home/conda/feedstock_root/build_artifacts/optuna_1734927470522/work
98
+ packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1733203243479/work
99
+ packmol_memgen==2025.1.29
100
+ pandas @ file:///home/conda/feedstock_root/build_artifacts/pandas_1726878406085/work
101
+ patsy @ file:///home/conda/feedstock_root/build_artifacts/patsy_1733792384640/work
102
+ pdb4amber==22.0
103
+ peft==0.4.0
104
+ peptides @ file:///opt/conda/conda-bld/peptides_1726747256415/work
105
+ pfeature==1.4
106
+ Pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1688255839723/work
107
+ plotly @ file:///home/conda/feedstock_root/build_artifacts/plotly_1733733072265/work
108
+ propcache @ file:///home/conda/feedstock_root/build_artifacts/propcache_1733391807885/work
109
+ protobuf==5.29.1
110
+ psutil==6.1.0
111
+ pyasn1 @ file:///home/conda/feedstock_root/build_artifacts/pyasn1_1733217608156/work
112
+ pyasn1_modules @ file:///home/conda/feedstock_root/build_artifacts/pyasn1-modules_1733324602540/work
113
+ pycairo==1.25.0
114
+ pycparser @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycparser_1733195786/work
115
+ Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1733221634316/work
116
+ PyJWT @ file:///home/conda/feedstock_root/build_artifacts/pyjwt_1732782409051/work
117
+ pyMSMT==22.0
118
+ pynndescent @ file:///home/conda/feedstock_root/build_artifacts/pynndescent_1734193530377/work
119
+ pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1733777368795/work
120
+ pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1733222594562/work
121
+ PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1733217236728/work
122
+ python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work
123
+ pytraj==2.0.6
124
+ pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1706886791323/work
125
+ pyu2f @ file:///home/conda/feedstock_root/build_artifacts/pyu2f_1733738580568/work
126
+ PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1725456176299/work
127
+ regex==2024.11.6
128
+ reportlab @ file:///home/conda/feedstock_root/build_artifacts/reportlab_1727846838921/work
129
+ requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1733217035951/work
130
+ requests-oauthlib @ file:///home/conda/feedstock_root/build_artifacts/requests-oauthlib_1733772243268/work
131
+ rich @ file:///home/conda/feedstock_root/build_artifacts/rich_1733342254348/work/dist
132
+ rlPyCairo @ file:///home/conda/feedstock_root/build_artifacts/rlpycairo_1687519531733/work
133
+ rsa @ file:///home/conda/feedstock_root/build_artifacts/rsa_1733662684165/work
134
+ s3transfer==0.10.4
135
+ safetensors==0.4.5
136
+ sander==22.0
137
+ scikit-learn @ file:///home/conda/feedstock_root/build_artifacts/scikit-learn_1733760110140/work/dist/scikit_learn-1.6.0-cp39-cp39-linux_x86_64.whl#sha256=f4205b3762b8033782da7a40727c959075158f152dcb3f50e33f60ecfc8cf6f4
138
+ scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy-split_1716470218293/work/dist/scipy-1.13.1-cp39-cp39-linux_x86_64.whl#sha256=e6696cb8683d94467891b7648e068a3970f6bc0a1b3c1aa7f9bc89458eafd2f0
139
+ seaborn @ file:///home/conda/feedstock_root/build_artifacts/seaborn-split_1733730015268/work
140
+ sentencepiece @ file:///home/conda/feedstock_root/build_artifacts/sentencepiece-split_1674200849009/work/python
141
+ shap @ file:///home/conda/feedstock_root/build_artifacts/shap_1732083651745/work
142
+ six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work
143
+ slicer @ file:///home/conda/feedstock_root/build_artifacts/slicer_1710029110974/work
144
+ SQLAlchemy @ file:///home/conda/feedstock_root/build_artifacts/sqlalchemy_1729066347171/work
145
+ statsmodels @ file:///home/conda/feedstock_root/build_artifacts/statsmodels_1727986706423/work
146
+ sympy==1.13.3
147
+ tape-proteins==0.5
148
+ tenacity @ file:///home/conda/feedstock_root/build_artifacts/tenacity_1733649050774/work
149
+ tensorboard==2.13.0
150
+ tensorboard-data-server==0.7.2
151
+ tensorboard-plugin-wit @ file:///home/conda/feedstock_root/build_artifacts/tensorboard-plugin-wit_1641458951060/work/tensorboard_plugin_wit-1.8.1-py3-none-any.whl
152
+ tensorboardX==2.6.2.2
153
+ tensorflow @ file:///home/conda/feedstock_root/build_artifacts/tensorflow-split_1679857771045/work/tensorflow_pkg/tensorflow-2.11.1-cp39-cp39-linux_x86_64.whl
154
+ tensorflow-estimator @ file:///home/conda/feedstock_root/build_artifacts/tensorflow-split_1679857771045/work/tensorflow-estimator/wheel_dir/tensorflow_estimator-2.11.0-py2.py3-none-any.whl
155
+ termcolor @ file:///home/conda/feedstock_root/build_artifacts/termcolor_1733754631889/work
156
+ threadpoolctl @ file:///home/conda/feedstock_root/build_artifacts/threadpoolctl_1714400101435/work
157
+ tokenizers==0.13.3
158
+ torch==2.0.1
159
+ tqdm==4.65.0
160
+ transformers==4.30.0
161
+ triton==2.0.0
162
+ typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1733188668063/work
163
+ tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1733235305708/work
164
+ umap-learn @ file:///home/conda/feedstock_root/build_artifacts/umap-learn_1730221913988/work
165
+ unicodedata2 @ file:///home/conda/feedstock_root/build_artifacts/unicodedata2_1729704563364/work
166
+ urllib3==1.26.20
167
+ Werkzeug @ file:///home/conda/feedstock_root/build_artifacts/werkzeug_1733160440960/work
168
+ wrapt @ file:///home/conda/feedstock_root/build_artifacts/wrapt_1732523603052/work
169
+ xgboost @ file:///home/conda/feedstock_root/build_artifacts/xgboost-split_1733179637554/work/python-package
170
+ yarl @ file:///home/conda/feedstock_root/build_artifacts/yarl_1733428811664/work
171
+ zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work
172
+ zstandard==0.23.0