Clustering Protein Complexes using Persistent Homology and Finetuning ESM-2 for PPI Network Prediction
Clustering protein complexes is at present something of an open problem. In this article we will have a look at a new method for clustering protein sequences and protein complexes which is based on protein language model embeddings and persistent homology. We will then show how to finetune ESM-2 to predict protein-protein interactions using a train/test split based on this new clustering method.
Protein Complexes
Proteins, the workhorses of the cell, often do not act in isolation. Many vital biological functions are carried out by protein complexes, structures formed by the association of two or more protein molecules. The formation of these complexes is a highly regulated and specific process, essential for many cellular activities, including signaling pathways, metabolic reactions, DNA replication, and more. Understanding protein complexes is critical for grasping cellular mechanisms at a molecular level. Here we will focus on protein complexes composed of two proteins, but this method can be extended to complexes with more than two proteins.
Sequence Similarity
Measuring the similarity between two protein sequences, and protein homology modeling more generally, can be done in various ways. One of the oldest approaches is based on the edit distance, but newer methods use the embeddings of protein language models to determine how similar two proteins are. For example, in the paper Protein Language Model Performs Efficient Homology Detection, the authors devise a way to do similarity search using embeddings from the somewhat older ESM-1b model. We draw inspiration from these methods here and use the newer protein language model ESM-2. We also use a technique from topological data analysis to obtain a topological summary of the embeddings associated to proteins and protein complexes. While this method can be applied to individual proteins, we are more concerned with how it applies to protein complexes, since standard sequence similarity measures fall short or become overly complicated when comparing two protein complexes. In fact, most methods for homology modeling of protein-protein complexes are structure-based. We will attempt to remedy this by applying a method known as persistent homology to the embeddings of a protein complex obtained from a protein language model, ESM-2.
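To make the embedding-based notion of similarity concrete, here is a minimal sketch (not the persistent homology method developed below) that mean-pools ESM-2 hidden states and compares two sequences with cosine similarity. The checkpoint, the pooling choice, and the toy sequences are our own assumptions, chosen only for illustration.
import torch
from transformers import AutoTokenizer, EsmModel

# Small ESM-2 checkpoint chosen only to keep the example light
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t12_35M_UR50D")

def embed(sequence):
    # Mean-pool the final hidden states over residues to get one vector per sequence
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state[0]
    return hidden.mean(dim=0)

# Toy sequences purely for illustration
seq_a = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
seq_b = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEK"
similarity = torch.nn.functional.cosine_similarity(embed(seq_a), embed(seq_b), dim=0)
print(f"Cosine similarity: {similarity.item():.3f}")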
Clustering Protein Complexes
Clustering protein complexes is a difficult problem that we will approach in a novel way using persistent homology. We first concatenate pairs of interacting proteins obtained from the UniProt database, focusing on human proteins for now. Next, we compute the embeddings associated to the concatenated pair using the pLM ESM-2. Once we have our embeddings, we compute a persistence diagram using persistent homology. Each protein complex is given such a diagram, and the pairwise distances between these diagrams are computed using the Wasserstein distance metric on persistence diagrams. We then use the resulting distance matrix to compute a second-level persistence diagram that summarizes the Wasserstein distances between the persistence diagrams associated to each protein-protein complex. Finally, we run DBSCAN based on this second-level persistence diagram, choosing an epsilon threshold such that 80% of the points in the (second-level) zero-dimensional persistence diagram fall below epsilon. This returns clusters of protein-protein complexes which we can then use to create a train/test split for training a model to predict protein-protein interactions or for generating binders for a target protein.
Clustering Script
Below, we provide the script for clustering the protein-protein complexes based on persistent homology. To use it, you can download the dataset of interacting proteins from HuggingFace here.
import pandas as pd
import numpy as np
from transformers import EsmModel, AutoTokenizer
import torch
from scipy.spatial.distance import pdist, squareform
from gudhi import RipsComplex
from gudhi.hera import wasserstein_distance
from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score
from tqdm import tqdm
# Define a helper function for hidden states
def get_hidden_states(sequence, tokenizer, model, layer):
    model.config.output_hidden_states = True
    encoded_input = tokenizer([sequence], return_tensors='pt', padding=True, truncation=True, max_length=1024)
    with torch.no_grad():
        model_output = model(**encoded_input)
    hidden_states = model_output.hidden_states
    specific_hidden_states = hidden_states[layer][0]
    return specific_hidden_states.numpy()
# Define a helper function for Euclidean distance matrix
def compute_euclidean_distance_matrix(hidden_states):
    euclidean_distances = pdist(hidden_states, metric='euclidean')
    euclidean_distance_matrix = squareform(euclidean_distances)
    return euclidean_distance_matrix
# Define a helper function for persistent homology
def compute_persistent_homology(distance_matrix, max_dimension=0):
    max_edge_length = np.max(distance_matrix)
    rips_complex = RipsComplex(distance_matrix=distance_matrix, max_edge_length=max_edge_length)
    st = rips_complex.create_simplex_tree(max_dimension=max_dimension)
    persistence = st.persistence()
    return st, persistence
# Define a helper function for Wasserstein distances
def compute_wasserstein_distances(persistence_diagrams, dimension):
    n_diagrams = len(persistence_diagrams)
    distances = np.zeros((n_diagrams, n_diagrams))
    for i in tqdm(range(n_diagrams), desc="Computing Wasserstein Distances"):
        for j in range(i+1, n_diagrams):
            X = np.array([p[1] for p in persistence_diagrams[i] if p[0] == dimension])
            Y = np.array([p[1] for p in persistence_diagrams[j] if p[0] == dimension])
            distance = wasserstein_distance(X, Y)
            distances[i][j] = distance
            distances[j][i] = distance
    return distances
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
# Define layer to be used
layer = model.config.num_hidden_layers - 1
# Load the TSV file
file_path = 'pepmlm/scripts/data/filtered_protein_interaction_pairs.tsv'
protein_pairs_df = pd.read_csv(file_path, sep='\t')
# Only process the first 1000 protein pairs
protein_pairs_df = protein_pairs_df.head(1000)
# Extract concatenated sequences
concatenated_sequences = protein_pairs_df['Protein1'] + protein_pairs_df['Protein2']
# Initialize list to store persistent diagrams
persistent_diagrams = []
# Loop over concatenated sequences to compute their persistent diagrams
for sequence in tqdm(concatenated_sequences, desc="Computing Persistence Diagrams"):
    hidden_states_matrix = get_hidden_states(sequence, tokenizer, model, layer)
    distance_matrix = compute_euclidean_distance_matrix(hidden_states_matrix)
    _, persistence_diagram = compute_persistent_homology(distance_matrix)
    persistent_diagrams.append(persistence_diagram)
# Compute the Wasserstein distances
wasserstein_distances = compute_wasserstein_distances(persistent_diagrams, 0)
# Compute the second-level persistent homology
with tqdm(total=1, desc="Computing Second-Level Persistent Homology") as pbar:
    st_2, persistence_2 = compute_persistent_homology(wasserstein_distances)
    pbar.update(1)
# Function to calculate the epsilon for DBSCAN
def calculate_epsilon(persistence_diagrams, threshold_percentage):
    lifetimes = [p[1][1] - p[1][0] for p in persistence_diagrams if p[0] == 0]
    lifetimes.sort()
    threshold_index = int(threshold_percentage * len(lifetimes))
    return lifetimes[threshold_index]
# Calculate epsilon
threshold_percentage = 0.8 # 80%
epsilon = calculate_epsilon(persistence_2, threshold_percentage)
# Perform DBSCAN clustering
with tqdm(total=1, desc="Performing DBSCAN Clustering") as pbar:
    dbscan = DBSCAN(metric="precomputed", eps=epsilon, min_samples=1)
    dbscan.fit(wasserstein_distances)
    labels = dbscan.labels_
    pbar.update(1)
# Add the cluster labels to the DataFrame
protein_pairs_df['Cluster'] = labels
# Save the DataFrame with cluster information
output_file_path = 'clustered_protein_interaction_pairs.tsv'
protein_pairs_df.to_csv(output_file_path, sep='\t', index=False)
print(f"Clustered data saved to: {output_file_path}")
In this script, we have the following interpretation and use of persistent homology to cluster the protein-protein complexes.
Persistence Diagrams and Lifetimes: In the script, persistence diagrams are generated for each concatenated protein sequence using topological data analysis (TDA) techniques. These diagrams capture the "birth" and "death" of topological features (like clusters or loops) at various scales. The "lifetime" of a feature is the difference between its death and birth values.
Second-Level Persistence Diagram: After computing Wasserstein distances between all pairs of persistence diagrams, a second-level persistence diagram is generated. This diagram represents the persistence of features (clusters, in this case) in the space of Wasserstein distances.
Calculating Lifetimes: For each point in the zero-dimensional part of the second-level persistence diagram, the lifetime is calculated. These lifetimes represent the persistence of clusters in the Wasserstein distance space.
Determining the 80% Threshold: The lifetimes are sorted, and the 80th percentile value is identified. This means that 80% of the points in the zero-dimensional persistence diagram have a lifetime less than or equal to this value.
Setting ε in DBSCAN: The 80% threshold value is used as the epsilon parameter in DBSCAN. In DBSCAN, epsilon determines the maximum distance between two samples for one to be considered as in the neighborhood of the other. By setting epsilon to this 80% threshold, the clustering algorithm is tailored to capture the majority of the natural clustering structure as revealed by the topological analysis.
Interpretation: Using the 80% threshold to set epsilon means that the DBSCAN clustering is sensitive to the most persistent and significant features in the data, as captured by the second-level persistence diagram. This approach aims to identify clusters that are robust and significant in the context of the data's topological structure.
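To make the threshold calculation concrete, here is a toy version of the epsilon computation mirroring calculate_epsilon above; the (birth, death) values are invented purely for illustration and are not from the dataset.
# Toy illustration of the epsilon calculation; the (birth, death) values are made up
toy_diagram = [(0, (0.0, d)) for d in [0.5, 0.8, 1.1, 1.4, 2.0, 2.3, 2.9, 3.5, 4.2, float("inf")]]
lifetimes = sorted(death - birth for _, (birth, death) in toy_diagram)
threshold_index = int(0.8 * len(lifetimes))   # index 8 out of 10 points
epsilon = lifetimes[threshold_index]          # 4.2 here, so roughly 80% of lifetimes fall below it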
Due to how computationally prohibitive it is to cluster protein-protein complexes in this way, we only cluster the first 1000 protein pairs here. Finding ways to optimize this clustering method would allow for clustering many more protein-protein complexes. As an alternative, we might follow the methodology mentioned in Protein Language Model Performs Efficient Homology Detection. This might also allow us to improve the methodology of vcMSA, mentioned in this blog post. We might also use GPU computation of persistent homology to speed things up.
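One optimization we have not benchmarked, but which would be straightforward to try, is replacing the gudhi Rips construction in compute_persistent_homology with ripser.py (GPU implementations of the same computation also exist). The following is an untested sketch of that substitution; the ripser dependency and the placeholder returned in place of the simplex tree are our own assumptions.
import numpy as np
from ripser import ripser  # assumed dependency: pip install ripser

# Hypothetical drop-in alternative to compute_persistent_homology above
def compute_persistent_homology_ripser(distance_matrix, max_dimension=0):
    result = ripser(distance_matrix, distance_matrix=True, maxdim=max_dimension)
    # result["dgms"][d] is an (n, 2) array of (birth, death) pairs in dimension d;
    # convert to the (dimension, (birth, death)) format used by the gudhi-based code
    diagram = [(dim, tuple(point)) for dim, dgm in enumerate(result["dgms"]) for point in dgm]
    # No simplex tree is available from ripser, so return None in its place
    return None, diagram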
Next, we create a train/test split based on these clusters, which plays a role similar to clustering sequences based on sequence similarity, or to creating a train/test split based on protein families in the UniProt database. However, this form of clustering is based on objects obtained from the embedding vectors, called filtered simplicial complexes, which act as a kind of geometric fingerprint for each protein-protein complex. They are a topological summary of the semantic information encoded in the embedding vectors (the hidden states of ESM-2).
Finetuning ESM-2 on Pairs of Interacting Proteins
Next, we can finetune ESM-2 on pairs of interacting proteins from UniProt to improve its capabilities for modeling protein complexes, predicting protein-protein interactions, or generating binders. We do so on the clustered human protein pairs obtained in the previous section by running the following script. This script will finetune ESM-2 to predict PPI networks based on the MLM loss, as mentioned in this blog post. We will cover how to finetune ESM-2 to generate binders in a later post, which will be similar to the methods used for finetuning ESM-2 to get PepMLM, as mentioned in this blog post.
from transformers import Trainer, TrainingArguments, AutoTokenizer, EsmForMaskedLM, TrainerCallback, get_scheduler
from torch.utils.data import Dataset
import pandas as pd
import torch
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
import random
class ProteinDataset(Dataset):
    def __init__(self, proteins, peptides, tokenizer):
        self.tokenizer = tokenizer
        self.proteins = proteins
        self.peptides = peptides

    def __len__(self):
        return len(self.proteins)

    def mask_sequence(self, sequence, mask_percentage):
        mask_indices = random.sample(range(len(sequence)), int(len(sequence) * mask_percentage))
        return ''.join([self.tokenizer.mask_token if i in mask_indices else char for i, char in enumerate(sequence)])

    def __getitem__(self, idx):
        protein_seq = self.proteins[idx]
        peptide_seq = self.peptides[idx]
        masked_protein = self.mask_sequence(protein_seq, 0.55)
        masked_peptide = self.mask_sequence(peptide_seq, 0.55)
        complex_seq = masked_protein + masked_peptide
        # Tokenize and pad the complex sequence
        complex_input = self.tokenizer(
            complex_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True,
            add_special_tokens=False
        )
        input_ids = complex_input["input_ids"].squeeze()
        attention_mask = complex_input["attention_mask"].squeeze()
        # Create labels
        label_seq = protein_seq + peptide_seq
        labels = self.tokenizer(
            label_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True,
            add_special_tokens=False
        )["input_ids"].squeeze()
        # Set non-masked positions in the labels tensor to -100
        labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
# Loading the dataset
file_path = "clustered_protein_interaction_pairs.tsv"
data = pd.read_csv(file_path, delimiter='\t')
# Splitting the data based on clusters
cluster_sizes = data['Cluster'].value_counts(normalize=True)
test_clusters = []
test_size = 0
for cluster, size in cluster_sizes.items():
    test_clusters.append(cluster)
    test_size += size
    if test_size >= 0.20:
        break
test_data = data[data['Cluster'].isin(test_clusters)]
train_data = data[~data['Cluster'].isin(test_clusters)]
proteins_train = train_data["Protein1"].tolist()
peptides_train = train_data["Protein2"].tolist()
proteins_test = test_data["Protein1"].tolist()
peptides_test = test_data["Protein2"].tolist()
# Load tokenizer and model
model_name = "esm2_t30_150M_UR50D"
tokenizer = AutoTokenizer.from_pretrained("facebook/" + model_name)
model = EsmForMaskedLM.from_pretrained("facebook/" + model_name)
# Training arguments
training_args = TrainingArguments(
    output_dir='./interact_output/',
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=4,
    warmup_steps=50,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    save_strategy='epoch',
    metric_for_best_model='eval_loss',
    save_total_limit=3,
    gradient_accumulation_steps=2,
    lr_scheduler_type='cosine'  # Set learning rate scheduler to cosine
)
# Optimizer
optimizer = AdamW(model.parameters(), lr=0.0007984276816171436)
# Scheduler (the total number of training steps is computed from the training data,
# since max_steps is not set in the TrainingArguments above)
steps_per_epoch = len(train_data) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)
num_training_steps = int(steps_per_epoch * training_args.num_train_epochs)
scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=training_args.warmup_steps,
    num_training_steps=num_training_steps
)
# Instantiate the ProteinDataset for training and testing
train_dataset = ProteinDataset(proteins_train, peptides_train, tokenizer)
test_dataset = ProteinDataset(proteins_test, peptides_test, tokenizer)
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    optimizers=(optimizer, scheduler),
)
# Start training
trainer.train()
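Once training finishes, a quick sanity check (this is our own addition, not part of the script above) is to score a held-out pair by masking part of one chain and reading off the masked-token loss; lower loss suggests the model finds the pairing more plausible. The masking fraction and the use of the test split here are illustrative assumptions.
# Hedged sketch: score a protein-peptide pair with the finetuned model via MLM loss
def score_pair(protein_seq, peptide_seq, model, tokenizer, mask_fraction=0.55):
    # Mask a fraction of the peptide chain, mirroring the masking used in ProteinDataset
    masked_peptide = ''.join(
        tokenizer.mask_token if random.random() < mask_fraction else aa for aa in peptide_seq
    )
    inputs = tokenizer(protein_seq + masked_peptide, return_tensors="pt",
                       truncation=True, max_length=1024, add_special_tokens=False)
    labels = tokenizer(protein_seq + peptide_seq, return_tensors="pt",
                       truncation=True, max_length=1024, add_special_tokens=False)["input_ids"]
    # Compute the loss only on the masked positions
    labels = torch.where(inputs["input_ids"] == tokenizer.mask_token_id, labels, -100)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        loss = model(**inputs, labels=labels.to(model.device)).loss
    return loss.item()

# Lower values indicate higher model likelihood for the masked residues of this pair
model.eval()
print(score_pair(proteins_test[0], peptides_test[0], model, tokenizer))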
Conclusion
Now that you have your finetuned ESM-2 model for predicting protein-protein interactions, you can follow along with this blog post to predict PPI networks. You can also try adjusting the training script to match the training script of PepMLM, so that the binder is fully masked and the target sequence is left unmasked, in order to finetune ESM-2 to generate binders for target proteins as explained in this post (a rough sketch of that change is given below), or you can just wait for the next blog post, which will likely cover this! While this method is novel and interesting from a mathematical perspective, clustering with persistent homology on a CPU appears to be computationally prohibitive. Perhaps as a next project you can work on implementing the persistent homology computations on a GPU using this to speed up the computations!
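For those who want to experiment before that post, here is a rough, untested sketch of how the __getitem__ method of ProteinDataset might change for a PepMLM-style objective (fully masked binder, unmasked target); treat it as our own guess at a starting point rather than the actual PepMLM recipe.
# Sketch only: fully mask the binder, leave the target protein unmasked,
# and compute the loss on the binder positions alone
def __getitem__(self, idx):
    protein_seq = self.proteins[idx]   # target, left unmasked
    peptide_seq = self.peptides[idx]   # binder, fully masked
    masked_peptide = self.tokenizer.mask_token * len(peptide_seq)
    complex_input = self.tokenizer(
        protein_seq + masked_peptide, return_tensors="pt", padding="max_length",
        max_length=1024, truncation=True, add_special_tokens=False
    )
    input_ids = complex_input["input_ids"].squeeze()
    attention_mask = complex_input["attention_mask"].squeeze()
    labels = self.tokenizer(
        protein_seq + peptide_seq, return_tensors="pt", padding="max_length",
        max_length=1024, truncation=True, add_special_tokens=False
    )["input_ids"].squeeze()
    # Only the binder (masked) positions contribute to the MLM loss
    labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}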