"Open

# AlphaFold - single-sequence input
- **WARNING**: for demo and educational purposes only.
- For natural proteins you often need more than a single sequence to predict the structure accurately. To predict a structure from a multiple sequence alignment (MSA), use the [ColabFold](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb) notebook instead. That said, this notebook may still be useful for evaluating *de novo* designed proteins.


In [None]:
#@title Setup
from IPython.utils import io
import os,sys,re
import tensorflow as tf
import jax
import jax.numpy as jnp
import numpy as np

# Fetch code dependencies and AlphaFold weights, skipping each step once its
# directory exists; capture_output() keeps the shell noise out of the notebook.
# NOTE(review): the exported source lost its indentation; the nesting below
# (pip/wget under the af_backprop check) is reconstructed — verify.
with io.capture_output() as captured:
    if not os.path.isdir("af_backprop"):
        # af_backprop is added to sys.path below; presumably it supplies the
        # `utils` and `alphafold` modules imported after this cell's setup.
        %shell git clone -b beta https://github.com/sokrypton/af_backprop.git
        %shell pip -q install biopython dm-haiku ml-collections py3Dmol
        # colabfold.py is downloaded into the working dir and imported as `cf`.
        %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py
    if not os.path.isdir("params"):
        # AlphaFold parameter release 2021-07-14, unpacked into ./params.
        %shell mkdir params
        %shell curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params

# Select the accelerator: try a Colab TPU first; if that fails, use whatever
# device JAX reports (CPU or GPU). Sets the module-level DEVICE string.
try:
    # check if TPU is available
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are not
    # mistaken for "no TPU".
    if jax.local_devices()[0].platform == 'cpu':
        print("WARNING: no GPU detected, will be using CPU")
        DEVICE = "cpu"
    else:
        print('Running on GPU')
        DEVICE = "gpu"
        # disable GPU on tensorflow (presumably so TF does not reserve GPU
        # memory that JAX needs)
        tf.config.set_visible_devices([], 'GPU')
else:
    # Only reached when the TPU setup above succeeded.
    print('Running on TPU')
    DEVICE = "tpu"

sys.path.append('/content/af_backprop')
# import libraries
from utils import update_seq, update_aatype, get_plddt, get_pae
import colabfold as cf
from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.model import data, config, model
from alphafold.common import residue_constants

def clear_mem():
    """Free device memory by deleting every live buffer on the default JAX backend."""
    live_buffers = jax.lib.xla_bridge.get_backend().live_buffers()
    for buffer in live_buffers:
        buffer.delete()

def setup_model(max_len, model_name="model_2_ptm"):
    """Build a jit-compiled single-sequence AlphaFold runner for length `max_len`.

    Args:
      max_len: number of residue positions the compiled runner works on.
        Shorter sequences are expected to be padded with all-zero one-hot rows,
        which the runner masks out.
      model_name: parameter set loaded from ./params via
        `data.get_model_haiku_params`.

    Returns:
      (runner, opt): `runner` is `jax.jit(runner)` taking `(seq, opt)`.
      `opt` is a dict with the processed template `"inputs"` and the model
      `"params"`. Callers must add an `opt["prev"]` entry (the recycling
      state) before calling `runner`; it is merged into the inputs.
    """

    # Release device buffers from any previously built model before rebuilding.
    clear_mem()

    # setup model
    # NOTE(review): the config is always "model_5_ptm" while the weights come
    # from `model_name`; presumably deliberate (so template-trained weights can
    # run without template features) — confirm before changing.
    cfg = config.model_config("model_5_ptm")
    # Recycling is driven by the caller through opt["prev"], so each call to
    # the runner performs a single pass.
    cfg.model.num_recycle = 0
    cfg.data.common.num_recycle = 0
    # Single-sequence mode: one MSA cluster, one extra-MSA row, no MSA masking.
    cfg.data.eval.max_msa_clusters = 1
    cfg.data.common.max_extra_msa = 1
    cfg.data.eval.masked_msa_replace_fraction = 0
    # NOTE(review): presumably disables sub-batching inside the model — confirm.
    cfg.model.global_config.subbatch_size = None
    model_params = data.get_model_haiku_params(model_name=model_name, data_dir=".")
    model_runner = model.RunModel(cfg, model_params, is_training=False)

    # Dummy poly-alanine features of length max_len: they fix the array shapes
    # for compilation; the real sequence is written in by the runner.
    seq = "A" * max_len
    length = len(seq)
    feature_dict = {
        **pipeline.make_sequence_features(sequence=seq, description="none", num_res=length),
        **pipeline.make_msa_features(msas=[[seq]], deletion_matrices=[[[0]*length]])
    }
    inputs = model_runner.process_features(feature_dict,random_seed=0)

    def runner(seq, opt):
        """Run one forward pass of the model on the one-hot sequence `seq`.

        `seq` is assumed to be (max_len, 20) one-hot, with all-zero rows for
        padding — TODO confirm against callers. Returns a dict with final atom
        positions/mask, pLDDT, PAE, the inputs used, and "prev" (the
        representations to feed into the next recycle).
        """
        # update sequence
        inputs = opt["inputs"]
        # Merge the recycling state (init_* features) from the previous pass.
        inputs.update(opt["prev"])
        update_seq(seq, inputs)
        # NOTE(review): aatype is derived from target_feat minus its first
        # channel; presumably that channel is not a residue type — confirm.
        update_aatype(inputs["target_feat"][...,1:], inputs)

        # mask prediction
        # Per-position mask: 1 for real residues, 0 for all-zero padding rows.
        mask = seq.sum(-1)
        inputs["seq_mask"] = inputs["seq_mask"].at[:].set(mask)
        inputs["msa_mask"] = inputs["msa_mask"].at[:].set(mask)
        inputs["residue_index"] = jnp.where(mask==1,inputs["residue_index"],0)

        # get prediction
        # Fixed PRNG key, so repeated calls with the same inputs use the same randomness.
        key = jax.random.PRNGKey(0)
        outputs = model_runner.apply(opt["params"], key, inputs)

        # Representations fed back as init_* features on the next recycle.
        prev = {"init_msa_first_row":outputs['representations']['msa_first_row'][None],
                "init_pair":outputs['representations']['pair'][None],
                "init_pos":outputs['structure_module']['final_atom_positions'][None]}

        aux = {"final_atom_positions":outputs["structure_module"]["final_atom_positions"],
               "final_atom_mask":outputs["structure_module"]["final_atom_mask"],
               "plddt":get_plddt(outputs),"pae":get_pae(outputs),
               "inputs":inputs, "prev":prev}
        return aux

    return jax.jit(runner), {"inputs":inputs,"params":model_params}

# Compile once up front for sequences of up to 50 residues; the prediction
# cell recompiles with a larger length when a longer sequence is entered.
MAX_LEN = 50
RUNNER, OPT = setup_model(max_len=MAX_LEN)

In [None]:
%%time
#@title Enter the amino acid sequence to fold ⬇️

sequence = 'GGGGGGGGGGGGGGGGGGGG' #@param {type:"string"}
recycles = 0 #@param ["0", "1", "2", "3", "6", "12", "24"] {type:"raw"}
# Keep only letters A-Z (after upper-casing); spaces, digits, etc. are dropped.
SEQ = re.sub("[^A-Z]", "", sequence.upper())
LEN = len(SEQ)
if LEN == 0:
    # Otherwise the model would silently run on pure padding.
    raise ValueError("sequence must contain at least one amino-acid letter")
# Letters without an entry in restype_order are encoded as index 0 below;
# warn instead of substituting silently.
unknown_aa = sorted(set(SEQ) - set(residue_constants.restype_order))
if unknown_aa:
    print(f"WARNING: unrecognized residue(s) {''.join(unknown_aa)} will be encoded as restype index 0")

# Grow (and recompile) the runner when the sequence exceeds the compiled length.
if LEN > MAX_LEN:
    print("recompiling...")
    MAX_LEN = LEN
    RUNNER, OPT = setup_model(MAX_LEN)

# One-hot encode; padding gets index -1, which one_hot turns into an all-zero
# row (the runner builds its residue mask from these rows).
x = np.array([residue_constants.restype_order.get(aa,0) for aa in SEQ])
x = np.pad(x,[0,MAX_LEN-LEN],constant_values=-1)
x = jax.nn.one_hot(x,20)

# Zero recycling state for the first pass (channel sizes presumably match the
# model config's MSA/pair representations).
OPT["prev"] = {'init_msa_first_row': np.zeros([1, MAX_LEN, 256]),
               'init_pair': np.zeros([1, MAX_LEN, MAX_LEN, 128]),
               'init_pos': np.zeros([1, MAX_LEN, 37, 3])}

positions = []
plddts = []
for r in range(recycles+1):
    outs = RUNNER(x, OPT)
    # jax.tree_map is deprecated (removed in recent JAX); tree_util.tree_map is
    # the stable spelling.
    outs = jax.tree_util.tree_map(np.asarray, outs)
    positions.append(outs["prev"]["init_pos"][0,:LEN])
    plddts.append(outs["plddt"][:LEN])
    # Feed this pass's representations into the next recycle.
    OPT["prev"] = outs["prev"]
    if recycles > 0:
        print(r, plddts[-1].mean())

In [None]:
#@title Display 3D structure {run: "auto"}
# Viewer options exposed as Colab form fields: colour scheme and which atoms
# (side chains / main chain) the 3D view draws.
color = "lDDT" #@param ["chain", "lDDT", "rainbow"]
show_sidechains = True #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
#@markdown - TIP: hover the mouse over an amino acid to show its name and residue number.

def save_pdb(outs, filename):
    """Write the prediction held in `outs` to `filename` as PDB text.

    Only the first LEN residues (module-level global) are written. Per-atom
    B-factors are 100 * pLDDT, zeroed wherever the atom mask is 0.
    """
    n = LEN
    feats = outs["inputs"]
    atom_mask = outs["final_atom_mask"][:n]
    prot = protein.Protein(
        # +1 shifts the model's residue indices to start numbering at 1.
        residue_index=feats["residue_index"][0][:n] + 1,
        aatype=feats["aatype"].argmax(-1)[0][:n],
        atom_positions=outs["final_atom_positions"][:n],
        atom_mask=atom_mask,
        b_factors=100.0 * outs["plddt"][:n, None] * atom_mask,
    )
    # Render before opening the file so a failure leaves no truncated output.
    pdb_text = protein.to_pdb(prot)
    with open(filename, 'w') as handle:
        handle.write(pdb_text)

# Write the prediction to disk and render it in the interactive 3D viewer.
save_pdb(outs,"out.pdb")
# NOTE(review): num_res is not used in this cell; kept in case other cells read it.
num_res = int(outs["inputs"]["aatype"][0].sum())

v = cf.show_pdb("out.pdb", show_sidechains, show_mainchains, color,
                color_HP=True, size=(800,480))
# Add a "RESN:RESI" label when hovering over an atom and remove it on leave.
v.setHoverable({},
               True,
               '''function(atom,viewer,event,container){if(!atom.label){atom.label=viewer.addLabel(" "+atom.resn+":"+atom.resi,{position:atom,backgroundColor:'mintcream',fontColor:'black'});}}''',
               '''function(atom,viewer){if(atom.label){viewer.removeLabel(atom.label);delete atom.label;}}''')
v.show()

if color == "lDDT":
    cf.plot_plddt_legend().show()
# The runner always sets a "pae" key, so `"pae" in outs` was always True and the
# fallback was unreachable; test the value instead (presumably None when the
# model provides no PAE — confirm get_pae).
pae = outs.get("pae")
if pae is not None:
    cf.plot_confidence(outs["plddt"][:LEN]*100, pae[:LEN,:LEN]).show()
else:
    cf.plot_confidence(outs["plddt"][:LEN]*100).show()

In [None]:
#@title Animate
#@markdown - Animates the predicted structure across recycles (most informative when recycles > 0).
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
from IPython.display import HTML

def make_animation(positions, plddts=None, line_w=2.0):
    """Render a per-recycle structure trajectory as an HTML5 video.

    Args:
      positions: per-recycle atom positions indexable as
        [recycle, residue, atom, xyz] (an array or a list of equal-shape
        arrays). Atom index 1 is used as the trace atom (presumably C-alpha).
      plddts: per-recycle, per-residue pLDDT on a 0-100 scale (the pLDDT panel
        is fixed to 0-100). Required despite the None default.
      line_w: line width passed to cf.plot_pseudo_3D.

    Returns:
      The HTML string produced by ArtistAnimation.to_html5_video().

    Raises:
      TypeError: if plddts is None.
    """
    if plddts is None:
        # zip() below would otherwise fail with an opaque TypeError.
        raise TypeError("make_animation() requires per-recycle plddts")
    # Accept a list of per-recycle arrays; the indexing below needs an ndarray.
    positions = np.asarray(positions)

    def ca_align_to_last(positions):
        """Align every recycle's trace atoms onto the last recycle's 2D view."""
        def align(P, Q):
            # Center both point sets, then rotate P using cf.kabsch(p, q).
            p = P - P.mean(0, keepdims=True)
            q = Q - Q.mean(0, keepdims=True)
            return p @ cf.kabsch(p, q)

        pos = positions[-1, :, 1, :] - positions[-1, :, 1, :].mean(0, keepdims=True)
        # NOTE(review): return_v=True presumably gives axes that orient the last
        # structure for its most informative 2D projection — confirm in colabfold.
        best_2D_view = pos @ cf.kabsch(pos, pos, return_v=True)
        return np.asarray([align(frame[:, 1, :], best_2D_view) for frame in positions])

    # align all to last recycle
    pos = ca_align_to_last(positions)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
    fig.subplots_adjust(top=0.90, bottom=0.10, right=1, left=0, hspace=0, wspace=0)
    fig.set_figwidth(13)
    fig.set_figheight(5)
    fig.set_dpi(100)

    # Shared square extent (x/y over all frames plus a 1-unit margin) for both
    # structure panels.
    xy_min = pos[..., :2].min() - 1
    xy_max = pos[..., :2].max() + 1
    for ax in (ax1, ax3):
        ax.set_xlim(xy_min, xy_max)
        ax.set_ylim(xy_min, xy_max)
        ax.axis(False)

    # The pLDDT panel's labels and y-range are identical for every frame.
    ax2.set_xlabel("positions")
    ax2.set_ylabel("pLDDT")
    ax2.set_ylim(0, 100)

    frames = []
    for k, (xyz, plddt) in enumerate(zip(pos, plddts)):
        # ArtistAnimation shows one list of artists per frame, so each frame
        # gets its own freshly created artists.
        (line,) = ax2.plot(plddt, animated=True, color="black")
        tt1 = cf.add_text("colored by N->C", ax1)
        tt2 = cf.add_text(f"recycle={k}", ax2)
        tt3 = cf.add_text(f"pLDDT={plddt.mean():.3f}", ax3)
        frames.append([
            cf.plot_pseudo_3D(xyz, ax=ax1, line_w=line_w),
            line, tt1, tt2, tt3,
            cf.plot_pseudo_3D(xyz, c=plddt, cmin=50, cmax=90, ax=ax3, line_w=line_w),
        ])

    ani = animation.ArtistAnimation(fig, frames, blit=True, interval=120)
    # Close this specific figure so the notebook does not also show it as a
    # static image.
    plt.close(fig)
    return ani.to_html5_video()

# Render the recycle trajectory; pLDDT is rescaled from 0-1 to the 0-100 range
# the animation's pLDDT panel uses.
HTML(make_animation(np.asarray(positions),
                    plddts=100.0 * np.asarray(plddts)))