{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "AlphaFold_single.ipynb", "provenance": [], "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "source": [ "#AlphaFold - single sequence input\n", "- WARNING - For DEMO and educational purposes only. \n", "- For natural proteins you often need more than a single sequence to accurately predict the structure. See [ColabFold](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb) notebook if you want to predict the protein structure from a multiple-sequence-alignment. That being said, this notebook could potentially be useful for evaluating *de novo* designed proteins.\n" ], "metadata": { "id": "VpfCw7IzVHXv" } }, { "cell_type": "code", "source": [ "#@title Setup\n", "from IPython.utils import io\n", "import os,sys,re\n", "import tensorflow as tf\n", "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "\n", "with io.capture_output() as captured:\n", " if not os.path.isdir(\"af_backprop\"):\n", " %shell git clone -b beta https://github.com/sokrypton/af_backprop.git\n", " %shell pip -q install biopython dm-haiku ml-collections py3Dmol\n", " %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py\n", " if not os.path.isdir(\"params\"):\n", " %shell mkdir params\n", " %shell curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", "\n", "try:\n", " # check if TPU is available\n", " import jax.tools.colab_tpu\n", " jax.tools.colab_tpu.setup_tpu()\n", " print('Running on TPU')\n", " DEVICE = \"tpu\"\n", "except:\n", " if jax.local_devices()[0].platform == 'cpu':\n", " print(\"WARNING: no GPU detected, will be using CPU\")\n", " DEVICE = \"cpu\"\n", " else:\n", " print('Running on GPU')\n", " DEVICE = \"gpu\"\n", " # disable GPU on tensorflow\n", " tf.config.set_visible_devices([], 'GPU')\n", "\n", "sys.path.append('/content/af_backprop')\n", "# import libraries\n", "from utils import update_seq, update_aatype, get_plddt, get_pae\n", "import colabfold as cf\n", "from alphafold.common import protein\n", "from alphafold.data import pipeline\n", "from alphafold.model import data, config, model\n", "from alphafold.common import residue_constants\n", "\n", "def clear_mem():\n", " backend = jax.lib.xla_bridge.get_backend()\n", " for buf in backend.live_buffers(): buf.delete()\n", "\n", "def setup_model(max_len, model_name=\"model_2_ptm\"):\n", "\n", " clear_mem()\n", "\n", " # setup model\n", " cfg = config.model_config(\"model_5_ptm\")\n", " cfg.model.num_recycle = 0\n", " cfg.data.common.num_recycle = 0\n", " cfg.data.eval.max_msa_clusters = 1\n", " cfg.data.common.max_extra_msa = 1\n", " cfg.data.eval.masked_msa_replace_fraction = 0\n", " cfg.model.global_config.subbatch_size = None\n", " model_params = data.get_model_haiku_params(model_name=model_name, data_dir=\".\")\n", " model_runner = model.RunModel(cfg, model_params, is_training=False)\n", "\n", " seq = \"A\" * max_len\n", " length = len(seq)\n", " feature_dict = {\n", " **pipeline.make_sequence_features(sequence=seq, description=\"none\", num_res=length),\n", " **pipeline.make_msa_features(msas=[[seq]], deletion_matrices=[[[0]*length]])\n", " }\n", " inputs = model_runner.process_features(feature_dict,random_seed=0)\n", "\n", " def runner(seq, opt):\n", " # update sequence\n", " inputs = opt[\"inputs\"]\n", " inputs.update(opt[\"prev\"])\n", " update_seq(seq, inputs)\n", " update_aatype(inputs[\"target_feat\"][...,1:], inputs)\n", "\n", " # mask prediction\n", " mask = seq.sum(-1)\n", " inputs[\"seq_mask\"] = inputs[\"seq_mask\"].at[:].set(mask)\n", " inputs[\"msa_mask\"] = inputs[\"msa_mask\"].at[:].set(mask)\n", " inputs[\"residue_index\"] = jnp.where(mask==1,inputs[\"residue_index\"],0)\n", "\n", " # get prediction\n", " key = jax.random.PRNGKey(0)\n", " outputs = model_runner.apply(opt[\"params\"], key, inputs)\n", "\n", " prev = {\"init_msa_first_row\":outputs['representations']['msa_first_row'][None],\n", " \"init_pair\":outputs['representations']['pair'][None],\n", " \"init_pos\":outputs['structure_module']['final_atom_positions'][None]}\n", " \n", " aux = {\"final_atom_positions\":outputs[\"structure_module\"][\"final_atom_positions\"],\n", " \"final_atom_mask\":outputs[\"structure_module\"][\"final_atom_mask\"],\n", " \"plddt\":get_plddt(outputs),\"pae\":get_pae(outputs),\n", " \"inputs\":inputs, \"prev\":prev}\n", " return aux\n", "\n", " return jax.jit(runner), {\"inputs\":inputs,\"params\":model_params}\n", "\n", "MAX_LEN = 50\n", "RUNNER, OPT = setup_model(MAX_LEN)" ], "metadata": { "cellView": "form", "id": "24ybo88aBiSU" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "%%time\n", "#@title Enter the amino acid sequence to fold ⬇️\n", "\n", "sequence = 'GGGGGGGGGGGGGGGGGGGG' #@param {type:\"string\"}\n", "recycles = 0 #@param [\"0\", \"1\", \"2\", \"3\", \"6\", \"12\", \"24\"] {type:\"raw\"}\n", "SEQ = re.sub(\"[^A-Z]\", \"\", sequence.upper())\n", "LEN = len(SEQ)\n", "if LEN > MAX_LEN:\n", " print(\"recompiling...\")\n", " MAX_LEN = LEN\n", " RUNNER, OPT = setup_model(MAX_LEN)\n", "\n", "x = np.array([residue_constants.restype_order.get(aa,0) for aa in SEQ])\n", "x = np.pad(x,[0,MAX_LEN-LEN],constant_values=-1)\n", "x = jax.nn.one_hot(x,20)\n", "\n", "OPT[\"prev\"] = {'init_msa_first_row': np.zeros([1, MAX_LEN, 256]),\n", " 'init_pair': np.zeros([1, MAX_LEN, MAX_LEN, 128]),\n", " 'init_pos': np.zeros([1, MAX_LEN, 37, 3])}\n", "\n", "positions = []\n", "plddts = []\n", "for r in range(recycles+1):\n", " outs = RUNNER(x, OPT)\n", " outs = jax.tree_map(lambda x:np.asarray(x), outs)\n", " positions.append(outs[\"prev\"][\"init_pos\"][0,:LEN])\n", " plddts.append(outs[\"plddt\"][:LEN])\n", " OPT[\"prev\"] = outs[\"prev\"]\n", " if recycles > 0:\n", " print(r, plddts[-1].mean())" ], "metadata": { "cellView": "form", "id": "cAoC4ar8G7ZH" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title Display 3D structure {run: \"auto\"}\n", "color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n", "show_sidechains = True #@param {type:\"boolean\"}\n", "show_mainchains = False #@param {type:\"boolean\"}\n", "#@markdown - TIP - hold mouse over aminoacid to get name and position number\n", "\n", "def save_pdb(outs, filename):\n", " '''save pdb coordinates'''\n", " p = {\"residue_index\":outs[\"inputs\"][\"residue_index\"][0][:LEN] + 1,\n", " \"aatype\":outs[\"inputs\"][\"aatype\"].argmax(-1)[0][:LEN],\n", " \"atom_positions\":outs[\"final_atom_positions\"][:LEN],\n", " \"atom_mask\":outs[\"final_atom_mask\"][:LEN]}\n", " b_factors = 100.0 * outs[\"plddt\"][:LEN,None] * p[\"atom_mask\"]\n", " p = protein.Protein(**p,b_factors=b_factors)\n", " pdb_lines = protein.to_pdb(p)\n", " with open(filename, 'w') as f:\n", " f.write(pdb_lines)\n", "\n", "save_pdb(outs,\"out.pdb\")\n", "num_res = int(outs[\"inputs\"][\"aatype\"][0].sum())\n", "\n", "v = cf.show_pdb(\"out.pdb\", show_sidechains, show_mainchains, color,\n", " color_HP=True, size=(800,480)) \n", "v.setHoverable({},\n", " True,\n", " '''function(atom,viewer,event,container){if(!atom.label){atom.label=viewer.addLabel(\" \"+atom.resn+\":\"+atom.resi,{position:atom,backgroundColor:'mintcream',fontColor:'black'});}}''',\n", " '''function(atom,viewer){if(atom.label){viewer.removeLabel(atom.label);delete atom.label;}}''')\n", "v.show() \n", "\n", "if color == \"lDDT\":\n", " cf.plot_plddt_legend().show() \n", "if \"pae\" in outs:\n", " cf.plot_confidence(outs[\"plddt\"][:LEN]*100, outs[\"pae\"][:LEN,:LEN]).show()\n", "else:\n", " cf.plot_confidence(outs[\"plddt\"][:LEN]*100).show()" ], "metadata": { "cellView": "form", "id": "-KbUGG4ZOp0J" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title Animate\n", "#@markdown - Animate trajectory if more than 0 recycle(s)\n", "import matplotlib\n", "from matplotlib import animation\n", "import matplotlib.pyplot as plt\n", "from IPython.display import HTML\n", "\n", "def make_animation(positions, plddts=None, line_w=2.0):\n", "\n", " def ca_align_to_last(positions):\n", " def align(P, Q):\n", " p = P - P.mean(0,keepdims=True)\n", " q = Q - Q.mean(0,keepdims=True)\n", " return p @ cf.kabsch(p,q)\n", " \n", " pos = positions[-1,:,1,:] - positions[-1,:,1,:].mean(0,keepdims=True)\n", " best_2D_view = pos @ cf.kabsch(pos,pos,return_v=True)\n", "\n", " new_positions = []\n", " for i in range(len(positions)):\n", " new_positions.append(align(positions[i,:,1,:],best_2D_view))\n", " return np.asarray(new_positions)\n", "\n", " # align all to last recycle\n", " pos = ca_align_to_last(positions)\n", "\n", " fig, (ax1, ax2, ax3) = plt.subplots(1,3)\n", " fig.subplots_adjust(top = 0.90, bottom = 0.10, right = 1, left = 0, hspace = 0, wspace = 0)\n", " fig.set_figwidth(13)\n", " fig.set_figheight(5)\n", " fig.set_dpi(100)\n", "\n", " xy_min = pos[...,:2].min() - 1\n", " xy_max = pos[...,:2].max() + 1\n", "\n", " for ax in [ax1,ax3]:\n", " ax.set_xlim(xy_min, xy_max)\n", " ax.set_ylim(xy_min, xy_max)\n", " ax.axis(False)\n", "\n", " ims=[]\n", " for k,(xyz,plddt) in enumerate(zip(pos,plddts)):\n", " ims.append([])\n", " im2 = ax2.plot(plddt, animated=True, color=\"black\")\n", " tt1 = cf.add_text(\"colored by N->C\", ax1)\n", " tt2 = cf.add_text(f\"recycle={k}\", ax2)\n", " tt3 = cf.add_text(f\"pLDDT={plddt.mean():.3f}\", ax3)\n", " ax2.set_xlabel(\"positions\")\n", " ax2.set_ylabel(\"pLDDT\")\n", " ax2.set_ylim(0,100)\n", " ims[-1] += [cf.plot_pseudo_3D(xyz, ax=ax1, line_w=line_w)]\n", " ims[-1] += [im2[0],tt1,tt2,tt3]\n", " ims[-1] += [cf.plot_pseudo_3D(xyz, c=plddt, cmin=50, cmax=90, ax=ax3, line_w=line_w)]\n", " \n", " ani = animation.ArtistAnimation(fig, ims, blit=True, interval=120)\n", " plt.close()\n", " return ani.to_html5_video()\n", "\n", "HTML(make_animation(np.asarray(positions),\n", " np.asarray(plddts) * 100.0))" ], "metadata": { "cellView": "form", "id": "tdjdC0KFPjWw" }, "execution_count": null, "outputs": [] } ] }