lhallee committed on
Commit
0a52e5e
·
verified ·
1 Parent(s): 5aba20c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +136 -5
README.md CHANGED
@@ -56,11 +56,9 @@ import torch
56
  import torch.nn.functional as F
57
  from transformers import BertForSequenceClassification, BertTokenizer
58
 
59
- model = BertForSequenceClassification.from_pretrained('GleghornLab/SYNTERACT') # load model
 
60
  tokenizer = BertTokenizer.from_pretrained('GleghornLab/SYNTERACT') # load tokenizer
61
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') # gather device
62
- model.to(device) # move to device
63
- model.eval() # put in eval mode
64
 
65
  sequence_a = 'MEKSCSIGNGREQYGWGHGEQCGTQFLECVYRNASMYSVLGDLITYVVFLGATCYAILFGFRLLLSCVRIVLKVVIALFVIRLLLALGSVDITSVSYSG' # Uniprot A1Z8T3
66
  sequence_b = 'MRLTLLALIGVLCLACAYALDDSENNDQVVGLLDVADQGANHANDGAREARQLGGWGGGWGGRGGWGGRGGWGGRGGWGGRGGWGGGWGGRGGWGGRGGGWYGR' # Uniprot A1Z8H0
@@ -70,7 +68,7 @@ example = sequence_a + ' [SEP] ' + sequence_b # add SEP token
70
 
71
  example = tokenizer(example, return_tensors='pt', padding=False).to(device) # tokenize example
72
  with torch.no_grad():
73
- logits = model(**example).logits.cpu().detach() # get logits from model
74
 
75
  probability = F.softmax(logits, dim=-1) # use softmax to get "confidence" in the prediction
76
  prediction = probability.argmax(dim=-1) # 0 for no interaction, 1 for interaction
@@ -92,4 +90,137 @@ The [Gleghorn lab](https://www.gleghornlab.com/) is an interdisciplinary researc
92
  publisher = {Cold Spring Harbor Laboratory},
93
  journal = {bioRxiv}
94
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  ```
 
56
  import torch.nn.functional as F
57
  from transformers import BertForSequenceClassification, BertTokenizer
58
 
59
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # gather device
60
model = BertForSequenceClassification.from_pretrained('GleghornLab/SYNTERACT', attn_implementation='sdpa').to(device).eval() # load model, move it to the device, put it in eval mode
61
  tokenizer = BertTokenizer.from_pretrained('GleghornLab/SYNTERACT') # load tokenizer
 
 
 
62
 
63
  sequence_a = 'MEKSCSIGNGREQYGWGHGEQCGTQFLECVYRNASMYSVLGDLITYVVFLGATCYAILFGFRLLLSCVRIVLKVVIALFVIRLLLALGSVDITSVSYSG' # Uniprot A1Z8T3
64
  sequence_b = 'MRLTLLALIGVLCLACAYALDDSENNDQVVGLLDVADQGANHANDGAREARQLGGWGGGWGGRGGWGGRGGWGGRGGWGGRGGWGGGWGGRGGWGGRGGGWYGR' # Uniprot A1Z8H0
 
68
 
69
  example = tokenizer(example, return_tensors='pt', padding=False).to(device) # tokenize example
70
  with torch.no_grad():
71
+ logits = model(**example).logits.detach().cpu() # get logits from model
72
 
73
  probability = F.softmax(logits, dim=-1) # use softmax to get "confidence" in the prediction
74
  prediction = probability.argmax(dim=-1) # 0 for no interaction, 1 for interaction
 
90
  publisher = {Cold Spring Harbor Laboratory},
91
  journal = {bioRxiv}
92
  }
93
+ ```
94
+
95
+
96
+ ## A simple inference script
97
+
98
+ ```python
99
+ import torch
100
+ import re
101
+ import argparse
102
+ import pandas as pd
103
+ from transformers import BertForSequenceClassification, BertTokenizer
104
+ from torch.utils.data import Dataset, DataLoader
105
+ from typing import List, Tuple, Dict
106
+ from tqdm.auto import tqdm
107
+
108
+
109
class PairDataset(Dataset):
    """Dataset of sequence pairs; item ``i`` is ``(sequences_a[i], sequences_b[i])``.

    Args:
        sequences_a: First sequence of every pair.
        sequences_b: Second sequence of every pair, index-aligned with
            ``sequences_a``.

    Raises:
        ValueError: If the two lists do not have the same length.
    """

    def __init__(self, sequences_a: List[str], sequences_b: List[str]):
        # The length is taken from sequences_a only, so a mismatch would either
        # silently drop pairs or raise IndexError in the middle of inference.
        if len(sequences_a) != len(sequences_b):
            raise ValueError(
                f"sequences_a and sequences_b must have the same length, "
                f"got {len(sequences_a)} and {len(sequences_b)}"
            )
        self.sequences_a = sequences_a
        self.sequences_b = sequences_b

    def __len__(self) -> int:
        """Return the number of pairs."""
        return len(self.sequences_a)

    def __getitem__(self, idx: int) -> Tuple[str, str]:
        """Return the pair stored at position ``idx``."""
        return self.sequences_a[idx], self.sequences_b[idx]
119
+
120
+
121
class PairCollator:
    """Collate ``(sequence_a, sequence_b)`` pairs into one tokenized batch.

    Each pair is sanitized, joined as ``"<a> [SEP] <b>"`` and passed to the
    tokenizer with padding to the longest item and truncation at
    ``max_length``.

    Args:
        tokenizer: Callable tokenizer; it is invoked with a list of strings and
            the keyword arguments ``padding``, ``truncation``, ``max_length``
            and ``return_tensors``, and must return a mapping containing
            ``input_ids`` and ``attention_mask``.
        max_length: Maximum tokenized length handed to the tokenizer.
    """

    # Residue letters U, Z, O and B are replaced by X before tokenization.
    _RARE_RESIDUES = re.compile(r'[UZOB]')

    def __init__(self, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def sanitize_seq(self, seq: str) -> str:
        """Replace U/Z/O/B with X and put a single space between characters."""
        # NOTE(review): the space separation presumably matches how the
        # tokenizer splits residues - confirm against the model card.
        return ' '.join(self._RARE_RESIDUES.sub('X', seq))

    def __call__(self, batch: List[Tuple[str, str]]) -> Dict[str, torch.Tensor]:
        """Tokenize a batch of pairs.

        Args:
            batch: List of ``(sequence_a, sequence_b)`` tuples.

        Returns:
            Dict with the tokenizer's ``input_ids`` and ``attention_mask``.
        """
        # Iterate the pairs directly; transposing the batch and zipping it
        # back together is unnecessary.
        texts = [
            self.sanitize_seq(seq_a) + ' [SEP] ' + self.sanitize_seq(seq_b)
            for seq_a, seq_b in batch
        ]
        encoded = self.tokenizer(
            texts,
            padding='longest',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # Only these two entries are forwarded to the model.
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
        }
141
+
142
+
143
def main(args):
    """Run batched inference over sequence pairs and save the results as CSV.

    Args:
        args: Parsed command-line namespace providing ``model_path``,
            ``save_path``, ``batch_size`` and ``num_workers``.

    The CSV written to ``args.save_path`` has the columns ``sequence_a``,
    ``sequence_b``, ``probabilities`` (softmax probability of class 1, rounded
    to 5 decimals) and ``prediction`` (argmax over the logits).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Loading model from {args.model_path}")
    # When using PyTorch >= 2.5.1 on a linux machine, sdpa attention will greatly speed up inference
    model = BertForSequenceClassification.from_pretrained(args.model_path, attn_implementation="sdpa").eval().to(device)
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    print("Tokenizer loaded")

    # Load your data into two lists of sequences, where you want the PPI for
    # each pair sequences_a[i], sequences_b[i].
    # We recommend trimming or filtering out sequence pairs whose combined
    # length exceeds 1022 (for the 1024 max length limit of SYNTERACT).
    # We also recommend sorting the sequences by length in descending order,
    # as this will speed up inference by reducing padding.
    #
    # Example:
    #     from datasets import load_dataset
    #     data = load_dataset('Synthyra/NEGATOME', split='combined')
    #     # Filter out examples where the total length exceeds 1022
    #     data = data.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= 1022)
    #     # Add a new column 'total_length' that is the sum of lengths of SeqA and SeqB
    #     data = data.map(lambda x: {"total_length": len(x['SeqA']) + len(x['SeqB'])})
    #     # Sort the dataset by 'total_length' in descending order (longest sequences first)
    #     data = data.sort("total_length", reverse=True)
    #     # Now retrieve the sorted sequences
    #     sequences_a = data['SeqA']
    #     sequences_b = data['SeqB']
    print("Loading data...")
    sequences_a = []
    sequences_b = []

    print("Creating torch dataset...")
    pair_dataset = PairDataset(sequences_a, sequences_b)
    pair_collator = PairCollator(tokenizer, max_length=1024)
    data_loader = DataLoader(pair_dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=pair_collator)

    all_probs = []
    all_preds = []

    print("Starting inference...")
    with torch.no_grad():
        # Because sequences are sorted, the initial estimate for time will be much longer than the actual time it will take
        for batch in tqdm(data_loader, total=len(data_loader), desc="Batches processed"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits.cpu()

            prob_of_interaction = torch.softmax(logits, dim=1)[:, 1]  # can do 1 - this for no interaction prob
            pred = torch.argmax(logits, dim=1)

            all_probs.extend(prob_of_interaction.tolist())
            all_preds.extend(pred.tolist())

    # round to 5 decimal places
    all_probs = [round(prob, 5) for prob in all_probs]

    # The loader walks the dataset in order (no shuffling), so the first
    # len(all_preds) input pairs line up one-to-one with the predictions.
    num_scored = len(all_preds)

    # Create dataframe and save to CSV
    results_df = pd.DataFrame({
        'sequence_a': list(sequences_a[:num_scored]),
        'sequence_b': list(sequences_b[:num_scored]),
        'probabilities': all_probs,
        'prediction': all_preds,
    })
    print(f"Saving results to {args.save_path}")
    results_df.to_csv(args.save_path, index=False)
215
+
216
+
217
if __name__ == '__main__':
    # Command-line options as (flag, type, default).
    cli_options = (
        ('--model_path', str, 'GleghornLab/SYNTERACT'),
        ('--save_path', str, 'ppi_predictions.csv'),
        ('--batch_size', int, 2),
        # can increase to use multiprocessing for dataloader, 4 is a good value usually
        ('--num_workers', int, 0),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    main(parser.parse_args())
226
  ```