ESM-2 for RNA Binding Site Prediction
A small RNA binding site predictor, fine-tuned from facebook/esm2_t6_8M_UR50D on dataset "S1" from "Data of protein-RNA binding sites". The dataset is also available on Hugging Face.
The model reaches a validation loss of 0.12738210861297214 (about 0.127).
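For context, EsmForTokenClassification computes a cross-entropy loss over per-token labels when labels are supplied. The sketch below assumes the reported validation loss is that mean token-level cross-entropy, with special tokens labeled -100 (PyTorch's default ignore index) so they do not count. The logits and labels are made-up toy values, not outputs of this model.

```python
import torch
import torch.nn.functional as F

# Toy logits for one sequence of 4 tokens: <cls>, two residues, <eos>.
# Shape: (batch, tokens, num_labels) with num_labels = 2.
logits = torch.zeros(1, 4, 2)
logits[0, 1] = torch.tensor([2.0, 0.0])  # residue 1: favours 'Not Binding Site'
logits[0, 2] = torch.tensor([0.0, 2.0])  # residue 2: favours 'Binding Site'

# Assumed labelling scheme: -100 on <cls>/<eos> so those positions are ignored.
labels = torch.tensor([[-100, 0, 1, -100]])

# Mean cross-entropy over the non-ignored positions only.
loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1), ignore_index=-100)
print(f"{loss.item():.4f}")  # 0.1269
```

Each residue here contributes log(1 + e^-2) ≈ 0.1269. More generally, a mean loss L corresponds to a geometric-mean probability of exp(-L) on the correct label, so 0.127 works out to roughly 0.88.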
To use the model, run:

```python
import torch
from transformers import AutoTokenizer, EsmForTokenClassification

# Define the class mapping
class_mapping = {
    0: 'Not Binding Site',
    1: 'Binding Site',
}

# Load the trained model and tokenizer
model = EsmForTokenClassification.from_pretrained("AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Define the new sequences
new_sequences = [
    'VLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTK',
    'SQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWF',
    # ... add more sequences here ...
]

# Iterate over the new sequences
for seq in new_sequences:
    # Tokenize the sequence; padding is unnecessary for one sequence at a time,
    # and keeping the full encoding lets us pass the attention mask to the model
    inputs = tokenizer(seq, truncation=True, max_length=1290, return_tensors="pt")
    # Apply the model to get the logits
    with torch.no_grad():
        outputs = model(**inputs)
    # Get the predictions by picking the label (class) with the highest logit
    predictions = torch.argmax(outputs.logits, dim=-1)
    # Drop the <cls> token (first) and <eos> token (last) so that the
    # remaining predictions line up one-to-one with the residues
    prediction_list = predictions[0, 1:-1].tolist()
    # Convert the predicted class indices to class names
    predicted_labels = [class_mapping[pred] for pred in prediction_list]
    # Pair each amino acid in the sequence with its predicted class label
    residue_to_label = list(zip(seq, predicted_labels))
    # Print out the list
    for i, (residue, predicted_label) in enumerate(residue_to_label):
        print(f"Position {i+1} - {residue}: {predicted_label}")
```
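The ESM-2 tokenizer wraps each sequence in a <cls> token at the start and an <eos> token at the end, so the prediction at token index i + 1 belongs to residue i. Below is a minimal sketch of a helper that makes this alignment explicit. The name align_to_residues is hypothetical (it is not part of transformers), and the example runs on a synthetic prediction list rather than real model output.

```python
from typing import Dict, List, Tuple

CLASS_MAPPING: Dict[int, str] = {0: 'Not Binding Site', 1: 'Binding Site'}


def align_to_residues(seq: str, token_preds: List[int]) -> List[Tuple[int, str, str]]:
    """Pair each residue with its predicted label (hypothetical helper).

    Assumes token_preds is laid out as [<cls>, residue_1, ..., residue_n, <eos>, ...],
    i.e. the order produced by the ESM-2 tokenizer. Index 0 (<cls>) is skipped.
    """
    residue_preds = token_preds[1:len(seq) + 1]
    if len(residue_preds) < len(seq):
        raise ValueError("fewer predictions than residues (was the sequence truncated?)")
    return [
        (i + 1, aa, CLASS_MAPPING[p])
        for i, (aa, p) in enumerate(zip(seq, residue_preds))
    ]


# Synthetic predictions for the tokens <cls>, M, K, T, <eos>, <pad>
fake_preds = [0, 0, 1, 0, 0, 0]
for pos, aa, label in align_to_residues("MKT", fake_preds):
    print(f"Position {pos} - {aa}: {label}")
```

With real output, `predictions[0].tolist()` from the loop above can be passed as token_preds. Skipping index 0 avoids shifting every label by one residue.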