import torch


def _keyframe_segments(seq_len, mask_start_frame):
    """Return the ``(start, end)`` frame-index pairs bounding each interpolated span.

    If ``mask_start_frame`` is the first or last frame, the whole sequence is a
    single span between its two endpoints. Otherwise the sequence is split into
    two spans that share ``mask_start_frame`` as a keyframe.
    """
    if mask_start_frame == 0 or mask_start_frame == (seq_len - 1):
        return [(0, seq_len - 1)]
    return [(0, mask_start_frame), (mask_start_frame, seq_len - 1)]


def _frame_step(start, end):
    """Return the per-frame increment of the interpolation weight over ``[start, end]``.

    A zero-length span (single-frame sequence) yields 0 instead of dividing by
    zero, so its only frame simply receives its own keyframe value.
    """
    span = end - start
    return 1 / span if span else 0.0


def replace_constant(minibatch_pose_input, mask_start_frame):
    """Keep the keyframes of each sequence and fill every other frame with 0.1.

    Keyframes are frame 0, the last frame and ``mask_start_frame``.

    :param minibatch_pose_input: pose tensor whose dimension 1 is the frame axis
        (assumed ``(batch, frame, feature)`` -- TODO confirm against callers).
    :param mask_start_frame: index of an additional keyframe to keep.
    :return: new tensor shaped like the input; the input is not modified.
    """
    seq_len = minibatch_pose_input.size(1)
    interpolated = torch.ones_like(minibatch_pose_input) * 0.1
    keyframes = [
        frame
        for segment in _keyframe_segments(seq_len, mask_start_frame)
        for frame in segment
    ]
    for frame in keyframes:
        interpolated[:, frame, :] = minibatch_pose_input[:, frame, :]
    for frame in keyframes:
        assert torch.allclose(
            interpolated[:, frame, :], minibatch_pose_input[:, frame, :]
        )
    return interpolated


def slerp(x, y, a):
    """Perform spherical linear interpolation (SLERP) between quaternions x and y.

    :param x: quaternion tensor; the last dimension holds the 4 components.
    :param y: quaternion tensor with the same shape as ``x``.
        NOTE: wherever ``dot(x, y) < 0`` this function negates ``y`` IN PLACE so
        that the shorter arc is taken. Pass a copy if that must not be observed.
    :param a: interpolation amount (0 gives ``x``, 1 gives ``y``); a number or a
        tensor broadcastable to ``x.shape[:-1]``.
    :return: interpolated quaternions shaped like ``x`` (not re-normalized).
    """
    cos_omega = torch.sum(x * y, dim=-1)
    neg = cos_omega < 0.0
    cos_omega[neg] = -cos_omega[neg]
    y[neg] = -y[neg]

    a = torch.zeros_like(x[..., 0]) + a
    # Match the computation dtype of the inputs; the previous default-dtype
    # allocation truncated float64 weights to float32.
    amount0 = torch.zeros_like(a, dtype=cos_omega.dtype)
    amount1 = torch.zeros_like(a, dtype=cos_omega.dtype)

    # Nearly parallel quaternions: sin(omega) -> 0, so fall back to plain lerp.
    linear = (1.0 - cos_omega) < 0.01
    omegas = torch.arccos(cos_omega[~linear])
    sinoms = torch.sin(omegas)
    amount0[linear] = 1.0 - a[linear]
    amount0[~linear] = torch.sin((1.0 - a[~linear]) * omegas) / sinoms
    amount1[linear] = a[linear]
    amount1[~linear] = torch.sin(a[~linear] * omegas) / sinoms

    return amount0.unsqueeze(-1) * x + amount1.unsqueeze(-1) * y


def slerp_input_repr(minibatch_pose_input, mask_start_frame):
    """Build an input sequence by SLERP-interpolating quaternion keyframes.

    The features of each frame are regrouped into quaternions of 4 components.
    Every frame between consecutive keyframes (frame 0, ``mask_start_frame`` and
    the last frame) is replaced by the SLERP of its bounding keyframes, and the
    result is L2-normalized per quaternion.

    :param minibatch_pose_input: tensor with dimension 0 = batch, dimension 1 =
        frame, and a per-frame feature size divisible by 4.
    :param mask_start_frame: index of the middle keyframe.
    :return: tensor shaped ``(batch, seq_len, -1)``. The input is not modified.
    """
    batch_size = minibatch_pose_input.size(0)
    seq_len = minibatch_pose_input.size(1)
    # Work on a private copy. ``slerp`` flips its ``y`` argument in place, and the
    # keyframe slices below are views of ``quats``. Without the clone those flips
    # leaked into the caller's tensor whenever ``reshape`` returned a view.
    quats = minibatch_pose_input.reshape(batch_size, seq_len, -1, 4).clone()
    interpolated = torch.zeros_like(quats)

    for start, end in _keyframe_segments(seq_len, mask_start_frame):
        start_quat = quats[:, start : start + 1]
        end_quat = quats[:, end : end + 1]
        dt = _frame_step(start, end)
        for i in range(start, end + 1):
            interpolated[:, i : i + 1] = slerp(start_quat, end_quat, dt * (i - start))
        # Because the views alias ``quats``, a flip of the shared middle keyframe
        # made while interpolating the first span is also seen by the second one,
        # keeping the interpolated quaternions sign-consistent across the split.
        assert torch.allclose(interpolated[:, start : start + 1], start_quat)
        assert torch.allclose(interpolated[:, end : end + 1], end_quat)

    interpolated = torch.nn.functional.normalize(interpolated, p=2.0, dim=3)
    return interpolated.reshape(batch_size, seq_len, -1)


def lerp_input_repr(minibatch_pose_input, mask_start_frame):
    """Build an input sequence by linearly interpolating keyframes.

    Every frame between consecutive keyframes (frame 0, ``mask_start_frame`` and
    the last frame) is replaced by ``torch.lerp`` of its bounding keyframes.

    :param minibatch_pose_input: pose tensor whose dimension 1 is the frame axis
        (assumed ``(batch, frame, feature)`` -- TODO confirm against callers).
    :param mask_start_frame: index of the middle keyframe.
    :return: new tensor shaped like the input; the input is not modified.
    """
    seq_len = minibatch_pose_input.size(1)
    interpolated = torch.zeros_like(minibatch_pose_input)

    for start, end in _keyframe_segments(seq_len, mask_start_frame):
        start_pose = minibatch_pose_input[:, start, :]
        end_pose = minibatch_pose_input[:, end, :]
        dt = _frame_step(start, end)
        for i in range(start, end + 1):
            interpolated[:, i, :] = torch.lerp(start_pose, end_pose, dt * (i - start))
        assert torch.allclose(interpolated[:, start, :], start_pose)
        assert torch.allclose(interpolated[:, end, :], end_pose)

    return interpolated


def vectorize_representation(global_position, global_rotation):
    """Flatten per-frame positions and rotations and concatenate them.

    :param global_position: tensor with dimension 0 = batch and dimension 1 = frame.
    :param global_rotation: tensor with the same batch and frame dimensions.
    :return: tensor shaped ``(batch, seq_len, pos_features + rot_features)``, with
        the flattened positions first and the flattened rotations after them.
    """
    batch_size = global_position.shape[0]
    seq_len = global_position.shape[1]
    global_pos_vec = global_position.reshape(batch_size, seq_len, -1).contiguous()
    global_rot_vec = global_rotation.reshape(batch_size, seq_len, -1).contiguous()
    global_pose_vec_gt = torch.cat([global_pos_vec, global_rot_vec], dim=2)
    return global_pose_vec_gt