Spaces:
Running
on
T4
Running
on
T4
# Copyright 2021 DeepMind Technologies Limited | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Transformations for 3D coordinates. | |
This Module contains objects for representing Vectors (Vecs), Rotation Matrices | |
(Rots) and proper Rigid transformation (Rigids). These are represented as | |
named tuples with arrays for each entry, for example a set of | |
[N, M] points would be represented as a Vecs object with arrays of shape [N, M] | |
for x, y and z. | |
This is being done to improve readability by making it very clear what objects | |
are geometric objects rather than relying on comments and array shapes. | |
Another reason for this is to avoid using matrix | |
multiplication primitives like matmul or einsum, on modern accelerator hardware | |
these can end up on specialized cores such as tensor cores on GPU or the MXU on | |
cloud TPUs, this often involves lower computational precision which can be | |
problematic for coordinate geometry. Also these cores are typically optimized | |
for larger matrices than 3 dimensional, this code is written to avoid any | |
unintended use of these cores on both GPUs and TPUs. | |
""" | |
import collections | |
from typing import List | |
from alphafold.model import quat_affine | |
import jax.numpy as jnp | |
import tree | |
# Batch of 3-component vectors. Each component (x, y, z) is kept in its own
# array instead of in a trailing axis of size 3.
Vecs = collections.namedtuple('Vecs', 'x y z')

# Batch of 3x3 rotation matrices, one array per matrix entry. In a field name
# the first letter is the row and the second the column (e.g. `xy` is row x,
# column y), which is how `rots_mul_vecs` applies them.
Rots = collections.namedtuple('Rots', 'xx xy xz yx yy yz zx zy zz')

# Batch of rigid 3D transformations: a Rots `rot` and a Vecs translation
# `trans`. Applying one to a point v gives R v + t (see `rigids_mul_vecs`).
Rigids = collections.namedtuple('Rigids', 'rot trans')
def squared_difference(x, y):
  """Returns the element-wise squared difference `(x - y) ** 2`.

  Args:
    x: first operand (array or scalar).
    y: second operand; must support `x - y`.

  Returns:
    `jnp.square(x - y)`.
  """
  delta = x - y
  return jnp.square(delta)
def invert_rigids(r: Rigids) -> Rigids:
  """Returns the group inverse of the rigid transformations `r`.

  For the transform v -> R v + t the inverse is v -> R^-1 v - R^-1 t. The
  inverse rotation comes from `invert_rots`, and the inverse translation is
  the back-rotated, negated translation.

  Args:
    r: rigid transformations to invert.

  Returns:
    Rigids object holding the inverse of each transformation in `r`.
  """
  rot_inv = invert_rots(r.rot)
  rotated_trans = rots_mul_vecs(rot_inv, r.trans)
  trans_inv = Vecs(*(-component for component in rotated_trans))
  return Rigids(rot_inv, trans_inv)
def invert_rots(m: Rots) -> Rots:
  """Returns the inverse of the rotations `m`.

  Rotation matrices are orthogonal, so their inverse is their transpose. This
  function returns the transpose (each off-diagonal entry swapped with its
  mirror across the diagonal); it does not check that `m` is orthogonal.

  Args:
    m: rotations to invert.

  Returns:
    Transposed rotations.
  """
  return Rots(
      xx=m.xx, xy=m.yx, xz=m.zx,
      yx=m.xy, yy=m.yy, yz=m.zy,
      zx=m.xz, zy=m.yz, zz=m.zz)
def rigids_from_3_points(
    point_on_neg_x_axis: Vecs,  # shape (...)
    origin: Vecs,  # shape (...)
    point_on_xy_plane: Vecs,  # shape (...)
) -> Rigids:  # shape (...)
  """Builds rigid frames from 3 points by Gram-Schmidt orthogonalization.

  Jumper et al. (2021) Suppl. Alg. 21 "rigidFrom3Points"

  The local x axis points from `point_on_neg_x_axis` towards `origin`. The
  local y axis is derived from the direction `origin -> point_on_xy_plane`,
  orthogonalized against x by `rots_from_two_vecs`. The translation is
  `origin` itself.

  Args:
    point_on_neg_x_axis: Vecs corresponding to points on the negative x axis.
    origin: Origin of resulting rigid transformations.
    point_on_xy_plane: Vecs corresponding to points in the xy plane.

  Returns:
    Rigid transformations from global frame to local frames derived from
    the input points.
  """
  x_direction = vecs_sub(origin, point_on_neg_x_axis)
  xy_plane_direction = vecs_sub(point_on_xy_plane, origin)
  frame_rotation = rots_from_two_vecs(
      e0_unnormalized=x_direction,
      e1_unnormalized=xy_plane_direction)
  return Rigids(rot=frame_rotation, trans=origin)
def rigids_from_list(l: List[jnp.ndarray]) -> Rigids:
  """Builds Rigids from a flat list of 12 arrays (inverse of `rigids_to_list`).

  The first nine entries are the rotation entries in Rots field order
  (xx, xy, xz, yx, ..., zz). The last three are the translation components
  (x, y, z).

  Args:
    l: sequence of exactly 12 arrays.

  Returns:
    The corresponding Rigids object.

  Raises:
    AssertionError: if `l` does not have 12 entries (only when assertions are
      enabled).
  """
  assert len(l) == 12
  rotation_entries = l[:9]
  translation_entries = l[9:]
  return Rigids(Rots(*rotation_entries), Vecs(*translation_entries))
def rigids_from_quataffine(a: quat_affine.QuatAffine) -> Rigids:
  """Converts a QuatAffine object to the corresponding Rigids object.

  The leaves of `a.rotation` (as flattened by `tree.flatten`) become the nine
  Rots entries in order, and `a.translation` is unpacked into Vecs.
  NOTE(review): presumably `a.rotation` is a row-major 3x3 nested list, as
  built by `rigids_to_quataffine` — confirm against quat_affine.
  """
  rotation_entries = tree.flatten(a.rotation)
  translation_entries = a.translation
  return Rigids(Rots(*rotation_entries), Vecs(*translation_entries))
def rigids_from_tensor4x4(
    m: jnp.ndarray  # shape (..., 4, 4)
    ) -> Rigids:  # shape (...)
  """Builds Rigids from transformations in homogeneous 4x4 form.

  The upper-left 3x3 block of `m` becomes the rotation, and the first three
  entries of the last column become the translation. The bottom row of `m`
  is not read.

  Args:
    m: array of shape (..., 4, 4) representing transformations in homogeneous
      coordinates.

  Returns:
    Rigids object corresponding to transformations m.
  """
  assert m.shape[-1] == 4
  assert m.shape[-2] == 4
  # Row-major walk over the 3x3 block matches the Rots field order.
  rotation = Rots(*(m[..., i, j] for i in range(3) for j in range(3)))
  translation = Vecs(*(m[..., i, 3] for i in range(3)))
  return Rigids(rotation, translation)
def rigids_from_tensor_flat9(
    m: jnp.ndarray  # shape (..., 9)
    ) -> Rigids:  # shape (...)
  """Decodes the Flat9 encoding: first two rotation columns + translation.

  Along the last axis, entries 0-2 hold the first rotation column, entries
  3-5 the second column, and entries 6-8 the translation. The columns are
  passed to `rots_from_two_vecs`, which normalizes and orthogonalizes them
  and derives the third column.
  """
  assert m.shape[-1] == 9

  def take3(start):
    # Three consecutive entries of the last axis as a Vecs.
    return Vecs(m[..., start], m[..., start + 1], m[..., start + 2])

  first_column, second_column, translation = take3(0), take3(3), take3(6)
  return Rigids(rot=rots_from_two_vecs(first_column, second_column),
                trans=translation)
def rigids_from_tensor_flat12(
    m: jnp.ndarray  # shape (..., 12)
    ) -> Rigids:  # shape (...)
  """Decodes Flat12: rotation entries (9, in Rots order) + translation (3)."""
  assert m.shape[-1] == 12
  # Move the 12 entries to the front axis so they can be unpacked one by one.
  components = jnp.moveaxis(m, -1, 0)
  rotation = Rots(*components[:9])
  translation = Vecs(*components[9:])
  return Rigids(rotation, translation)
def rigids_mul_rigids(a: Rigids, b: Rigids) -> Rigids:
  """Group composition of Rigids `a` and `b`.

  Applying the result to a point equals applying `b` first and then `a`:
  rotation Ra Rb, translation Ra tb + ta.
  """
  rotation = rots_mul_rots(a.rot, b.rot)
  rotated_b_trans = rots_mul_vecs(a.rot, b.trans)
  translation = vecs_add(a.trans, rotated_b_trans)
  return Rigids(rotation, translation)
def rigids_mul_rots(r: Rigids, m: Rots) -> Rigids:
  """Composes rigid transformations `r` with rotations `m` (applied first).

  The rotation becomes `r.rot` times `m`; the translation of `r` is kept, so
  the result maps v to r(m v).
  """
  composed = rots_mul_rots(r.rot, m)
  return Rigids(rot=composed, trans=r.trans)
def rigids_mul_vecs(r: Rigids, v: Vecs) -> Vecs:
  """Applies rigid transforms `r` to points `v`: rotate, then translate."""
  rotated = rots_mul_vecs(r.rot, v)
  return vecs_add(rotated, r.trans)
def rigids_to_list(r: Rigids) -> List[jnp.ndarray]:
  """Flattens Rigids into a list of 12 arrays; inverse of `rigids_from_list`.

  The nine rotation entries (Rots field order) come first, followed by the
  translation components x, y, z.
  """
  return [*r.rot, *r.trans]
def rigids_to_quataffine(r: Rigids) -> quat_affine.QuatAffine:
  """Converts Rigids `r` into QuatAffine; inverse of `rigids_from_quataffine`.

  The rotation is passed as a row-major nested list of its nine entries,
  the translation as an [x, y, z] list, and no quaternion is supplied.
  """
  rot, trans = r.rot, r.trans
  rotation_rows = [
      [rot.xx, rot.xy, rot.xz],
      [rot.yx, rot.yy, rot.yz],
      [rot.zx, rot.zy, rot.zz],
  ]
  translation = [trans.x, trans.y, trans.z]
  return quat_affine.QuatAffine(
      quaternion=None, rotation=rotation_rows, translation=translation)
def rigids_to_tensor_flat9(
    r: Rigids  # shape (...)
    ) -> jnp.ndarray:  # shape (..., 9)
  """Flat9 encoding: first two columns of rotation matrix + translation.

  The last axis holds column 0 (xx, yx, zx), column 1 (xy, yy, zy), and then
  the translation (x, y, z). The third column is dropped.
  """
  rot = r.rot
  first_column = [rot.xx, rot.yx, rot.zx]
  second_column = [rot.xy, rot.yy, rot.zy]
  return jnp.stack(first_column + second_column + list(r.trans), axis=-1)
def rigids_to_tensor_flat12(
    r: Rigids  # shape (...)
    ) -> jnp.ndarray:  # shape (..., 12)
  """Flat12 encoding: the 9 rotation entries (Rots order), then x, y, z."""
  return jnp.stack([*r.rot, *r.trans], axis=-1)
def rots_from_tensor3x3(
    m: jnp.ndarray,  # shape (..., 3, 3)
    ) -> Rots:  # shape (...)
  """Converts rotations stored as a (..., 3, 3) array into Rots.

  `m[..., i, j]` becomes the entry in row i, column j, taken in row-major
  order to match the Rots field order.
  """
  assert m.shape[-1] == 3
  assert m.shape[-2] == 3
  entries = [m[..., row, col] for row in range(3) for col in range(3)]
  return Rots(*entries)
def rots_from_two_vecs(e0_unnormalized: Vecs, e1_unnormalized: Vecs) -> Rots:
  """Builds rotation matrices from unnormalized x-axis and xy-plane vectors.

  Uses Gram-Schmidt orthogonalization. The x axis is `e0_unnormalized`,
  normalized. The y axis is `e1_unnormalized` with its component along x
  removed, then normalized. The z axis is their cross product. The three
  axes form the columns of the returned rotation.

  Normalization goes through `vecs_robust_normalize`, so with its default
  positive epsilon zero-length or parallel inputs do not divide by zero,
  although the resulting axes are then not meaningful.

  Args:
    e0_unnormalized: vectors lying along the x-axis of the resulting rotation.
    e1_unnormalized: vectors lying in the xy-plane of the resulting rotation.

  Returns:
    Rotations resulting from the Gram-Schmidt procedure.
  """
  x_axis = vecs_robust_normalize(e0_unnormalized)
  # Remove the part of e1 that lies along the x axis.
  overlap = vecs_dot_vecs(e1_unnormalized, x_axis)
  y_unnormalized = Vecs(
      x=e1_unnormalized.x - overlap * x_axis.x,
      y=e1_unnormalized.y - overlap * x_axis.y,
      z=e1_unnormalized.z - overlap * x_axis.z)
  y_axis = vecs_robust_normalize(y_unnormalized)
  z_axis = vecs_cross_vecs(x_axis, y_axis)
  # The axes are the columns of the rotation matrix.
  return Rots(
      xx=x_axis.x, xy=y_axis.x, xz=z_axis.x,
      yx=x_axis.y, yy=y_axis.y, yz=z_axis.y,
      zx=x_axis.z, zy=y_axis.z, zz=z_axis.z)
def rots_mul_rots(a: Rots, b: Rots) -> Rots:
  """Composition of rotations `a` and `b` (matrix product a b).

  Computed column by column: column j of the product is `a` applied to
  column j of `b`.
  """
  b_columns = (Vecs(b.xx, b.yx, b.zx),
               Vecs(b.xy, b.yy, b.zy),
               Vecs(b.xz, b.yz, b.zz))
  col0, col1, col2 = (rots_mul_vecs(a, column) for column in b_columns)
  return Rots(col0.x, col1.x, col2.x,
              col0.y, col1.y, col2.y,
              col0.z, col1.z, col2.z)
def rots_mul_vecs(m: Rots, v: Vecs) -> Vecs:
  """Applies rotations `m` to vectors `v` with explicit multiply-adds.

  Each output component is the dot product of one row of `m` with `v`. No
  matmul/einsum is used (see the module docstring).
  """
  def apply_row(r_x, r_y, r_z):
    return r_x * v.x + r_y * v.y + r_z * v.z

  return Vecs(apply_row(m.xx, m.xy, m.xz),
              apply_row(m.yx, m.yy, m.yz),
              apply_row(m.zx, m.zy, m.zz))
def vecs_add(v1: Vecs, v2: Vecs) -> Vecs:
  """Returns the component-wise sum `v1 + v2`."""
  sum_x = v1.x + v2.x
  sum_y = v1.y + v2.y
  sum_z = v1.z + v2.z
  return Vecs(x=sum_x, y=sum_y, z=sum_z)
def vecs_dot_vecs(v1: Vecs, v2: Vecs) -> jnp.ndarray:
  """Returns the dot product of vectors `v1` and `v2`, element-wise over the batch."""
  x_term = v1.x * v2.x
  y_term = v1.y * v2.y
  z_term = v1.z * v2.z
  return x_term + y_term + z_term
def vecs_cross_vecs(v1: Vecs, v2: Vecs) -> Vecs:
  """Returns the cross product `v1 x v2`, element-wise over the batch."""
  ax, ay, az = v1.x, v1.y, v1.z
  bx, by, bz = v2.x, v2.y, v2.z
  return Vecs(ay * bz - az * by,
              az * bx - ax * bz,
              ax * by - ay * bx)
def vecs_from_tensor(x: jnp.ndarray  # shape (..., 3)
                    ) -> Vecs:  # shape (...)
  """Converts an array of shape (..., 3) to Vecs.

  The last axis is split into the x, y and z components. All leading axes
  are kept, so each component has shape (...). This is the inverse of
  `vecs_to_tensor`.

  Args:
    x: array whose last axis has size 3.

  Returns:
    Vecs with components x[..., 0], x[..., 1] and x[..., 2].

  Raises:
    AssertionError: if the last axis of `x` does not have size 3 (only when
      assertions are enabled).
  """
  num_components = x.shape[-1]
  assert num_components == 3, (
      'Expected last axis of size 3, got shape %s.' % (x.shape,))
  return Vecs(x[..., 0], x[..., 1], x[..., 2])
def vecs_robust_normalize(v: Vecs, epsilon: float = 1e-8) -> Vecs:
  """Scales vectors `v` to (approximately) unit length.

  Every component is divided by `vecs_robust_norm(v, epsilon)`, i.e. by
  sqrt(|v|^2 + epsilon). With a positive epsilon this never divides by zero,
  at the cost of results slightly shorter than unit length.

  Args:
    v: vectors to be normalized.
    epsilon: small regularizer added to the squared norm before the square
      root.

  Returns:
    Normalized vectors.
  """
  norm = vecs_robust_norm(v, epsilon)
  return Vecs(*(component / norm for component in (v.x, v.y, v.z)))
def vecs_robust_norm(v: Vecs, epsilon: float = 1e-8) -> jnp.ndarray:
  """Computes the regularized Euclidean norm of vectors `v`.

  Returns sqrt(x^2 + y^2 + z^2 + epsilon). For epsilon > 0 the result is
  strictly positive even for zero-length vectors, so it is safe to divide by
  and the square root is never evaluated at 0.

  Args:
    v: vectors whose norm is computed.
    epsilon: small regularizer added to the squared norm before taking the
      square root.

  Returns:
    Array of norms, one per vector in `v` (the broadcast shape of its
    components). Because of `epsilon`, the value is slightly larger than the
    exact norm.
  """
  squared_norm = jnp.square(v.x) + jnp.square(v.y) + jnp.square(v.z)
  return jnp.sqrt(squared_norm + epsilon)
def vecs_sub(v1: Vecs, v2: Vecs) -> Vecs:
  """Returns the component-wise difference `v1 - v2`."""
  dx = v1.x - v2.x
  dy = v1.y - v2.y
  dz = v1.z - v2.z
  return Vecs(x=dx, y=dy, z=dz)
def vecs_squared_distance(v1: Vecs, v2: Vecs) -> jnp.ndarray:
  """Returns the squared Euclidean distance between `v1` and `v2`.

  The per-component squared differences are summed; no square root is taken.
  """
  dx2 = squared_difference(v1.x, v2.x)
  dy2 = squared_difference(v1.y, v2.y)
  dz2 = squared_difference(v1.z, v2.z)
  return dx2 + dy2 + dz2
def vecs_to_tensor(v: Vecs  # shape (...)
                  ) -> jnp.ndarray:  # shape(..., 3)
  """Stacks `v` into an array with a trailing axis of size 3.

  Inverse of `vecs_from_tensor`: x, y and z become entries 0, 1 and 2 of the
  last axis.
  """
  components = (v.x, v.y, v.z)
  return jnp.stack(components, axis=-1)