Spaces:
Running
Running
import torch | |
def replace_constant(minibatch_pose_input, mask_start_frame):
    """
    Keep only the keyframes of the input and fill every other frame with 0.1.

    The keyframes are the first frame, the last frame and, unless it
    coincides with one of those two, the frame at ``mask_start_frame``.
    Frames are indexed along dimension 1.

    :param minibatch_pose_input: pose tensor indexed as [batch, frame, feature]
    :param mask_start_frame: index of the additional keyframe to keep
    :return: new tensor shaped like the input, holding the keyframe values
        at the keyframe positions and the constant 0.1 everywhere else
    """
    last_frame = minibatch_pose_input.size(1) - 1
    filled = torch.ones_like(minibatch_pose_input) * 0.1

    if mask_start_frame == 0 or mask_start_frame == last_frame:
        keyframes = (0, last_frame)
    else:
        keyframes = (0, mask_start_frame, last_frame)

    # Copy every keyframe first, then sanity-check the copies.
    for frame in keyframes:
        filled[:, frame, :] = minibatch_pose_input[:, frame, :]
    for frame in keyframes:
        assert torch.allclose(filled[:, frame, :], minibatch_pose_input[:, frame, :])

    return filled
def slerp(x, y, a):
    """
    Performs spherical linear interpolation (SLERP) between x and y, with proportion a.

    Quaternion components are read from the last dimension; every leading
    dimension is treated as a batch dimension, so inputs of any rank >= 1
    are accepted.

    NOTE: wherever the dot product of x and y is negative, ``y`` is negated
    IN PLACE so that both quaternions lie in the same hemisphere. Pass a copy
    of ``y`` if the caller's tensor must stay untouched.

    :param x: quaternion tensor
    :param y: quaternion tensor, expected to have the same shape as x
        (may be modified in place, see note above)
    :param a: indicator (between 0 and 1) of completion of the interpolation.
    :return: tensor of interpolation results
    """
    device = x.device

    # Cosine of the angle between x and y (named to avoid shadowing ``len``).
    cos_omega = torch.sum(x * y, dim=-1)
    flipped = cos_omega < 0.0
    cos_omega[flipped] = -cos_omega[flipped]
    y[flipped] = -y[flipped]

    # Broadcast the proportion to one value per quaternion.
    a = torch.zeros_like(x[..., 0]) + a

    amount0 = torch.zeros(a.shape, device=device)
    amount1 = torch.zeros(a.shape, device=device)

    # Nearly parallel quaternions: sin(omega) is close to zero, so fall back
    # to plain linear weights instead of dividing by it.
    linear = (1.0 - cos_omega) < 0.01
    omegas = torch.arccos(cos_omega[~linear])
    sinoms = torch.sin(omegas)

    amount0[linear] = 1.0 - a[linear]
    amount0[~linear] = torch.sin((1.0 - a[~linear]) * omegas) / sinoms

    amount1[linear] = a[linear]
    amount1[~linear] = torch.sin(a[~linear] * omegas) / sinoms

    # Append a trailing axis so the per-quaternion weights broadcast over the
    # component dimension regardless of how many batch dimensions there are.
    res = amount0.unsqueeze(-1) * x + amount1.unsqueeze(-1) * y
    return res
def slerp_input_repr(minibatch_pose_input, mask_start_frame):
    """
    Fill a pose sequence by spherically interpolating between its keyframes.

    The feature dimension is regrouped into 4-component quaternions. The
    keyframes are the first frame, the last frame and, unless it coincides
    with one of those two, the frame at ``mask_start_frame``. Every frame is
    recomputed with ``slerp`` between the keyframes that enclose it, the
    result is L2-normalised per quaternion and flattened back to three
    dimensions.

    The input tensor is never modified: ``slerp`` negates its end quaternion
    in place, so the interpolation runs on a private copy of the input.

    :param minibatch_pose_input: tensor indexed as [batch, frame, feature];
        the feature size must be divisible by 4
    :param mask_start_frame: index of the intermediate keyframe
    :return: interpolated tensor of shape [batch, frame, feature]
    """
    batch_size = minibatch_pose_input.size(0)
    seq_len = minibatch_pose_input.size(1)

    # reshape() may return a view of the caller's tensor; clone so that the
    # in-place sign flips performed by slerp() cannot leak into the input.
    quats = minibatch_pose_input.reshape(batch_size, seq_len, -1, 4).clone()
    interpolated = torch.zeros_like(quats)

    last_frame = seq_len - 1
    if mask_start_frame == 0 or mask_start_frame == last_frame:
        segments = [(0, last_frame)]
    else:
        segments = [(0, mask_start_frame), (mask_start_frame, last_frame)]

    for first, final in segments:
        # Slices (not integer indices) keep the frame axis so the endpoints
        # stay 4-D; they are views into ``quats``, so a sign flip applied by
        # slerp() to the shared middle keyframe carries over to the next
        # segment, exactly as the endpoint asserts below expect.
        segment_start = quats[:, first : first + 1]
        segment_end = quats[:, final : final + 1]
        dt = 1 / (final - first)
        for i in range(first, final + 1):
            interpolated[:, i : i + 1, :] = slerp(
                segment_start, segment_end, dt * (i - first)
            )
        assert torch.allclose(interpolated[:, first : first + 1], segment_start)
        assert torch.allclose(interpolated[:, final : final + 1], segment_end)

    interpolated = torch.nn.functional.normalize(interpolated, p=2.0, dim=3)
    return interpolated.reshape(batch_size, seq_len, -1)
def lerp_input_repr(minibatch_pose_input, mask_start_frame):
    """
    Fill a pose sequence by linearly interpolating between its keyframes.

    The keyframes are the first frame, the last frame and, unless it
    coincides with one of those two, the frame at ``mask_start_frame``.
    Every frame of the output is computed with ``torch.lerp`` between the
    two keyframes that enclose it; non-keyframe input values are ignored.

    :param minibatch_pose_input: tensor indexed as [batch, frame, feature]
    :param mask_start_frame: index of the intermediate keyframe
    :return: new tensor shaped like the input holding the interpolated frames
    """
    last_frame = minibatch_pose_input.size(1) - 1
    result = torch.zeros_like(minibatch_pose_input)

    if mask_start_frame == 0 or mask_start_frame == last_frame:
        segments = [(0, last_frame)]
    else:
        segments = [(0, mask_start_frame), (mask_start_frame, last_frame)]

    for first, final in segments:
        segment_start = minibatch_pose_input[:, first, :]
        segment_end = minibatch_pose_input[:, final, :]
        step = 1 / (final - first)
        for frame in range(first, final + 1):
            weight = step * (frame - first)
            result[:, frame, :] = torch.lerp(segment_start, segment_end, weight)
        # Both ends of the segment must reproduce their keyframes.
        assert torch.allclose(result[:, first, :], segment_start)
        assert torch.allclose(result[:, final, :], segment_end)

    return result
def vectorize_representation(global_position, global_rotation):
    """
    Flatten positions and rotations per frame and join them into one vector.

    Both inputs are reshaped to [batch, frame, -1] using the batch and frame
    sizes of ``global_position``, then concatenated along the last dimension
    with the position features first.

    :param global_position: tensor whose first two dimensions are batch and frame
    :param global_rotation: tensor with the same batch and frame sizes
    :return: tensor of shape [batch, frame, position_features + rotation_features]
    """
    batch_size = global_position.shape[0]
    seq_len = global_position.shape[1]
    flattened = [
        tensor.reshape(batch_size, seq_len, -1).contiguous()
        for tensor in (global_position, global_rotation)
    ]
    return torch.cat(flattened, dim=2)