# HaWoR: infiller/lib/model/preprocess.py
# Author: ThunderVVV; commit 5f028d6 ("update"); original file size 7.12 kB
import torch
def replace_constant(minibatch_pose_input, mask_start_frame):
    """Build a placeholder sequence that keeps only the keyframes.

    Every entry of the result starts as ``1 * 0.1``. The frames (dim 1) at
    index 0, at ``mask_start_frame`` and at ``seq_len - 1`` are then copied
    from ``minibatch_pose_input``. The input tensor itself is not modified.

    :param minibatch_pose_input: tensor indexed as ``[batch, frame, feature]``.
    :param mask_start_frame: index of the extra keyframe to keep.
    :return: new tensor shaped like the input.
    """
    last = minibatch_pose_input.size(1) - 1
    if mask_start_frame == 0 or mask_start_frame == last:
        keyframes = (0, last)
    else:
        keyframes = (0, mask_start_frame, last)

    placeholder = torch.ones_like(minibatch_pose_input) * 0.1
    for frame in keyframes:
        placeholder[:, frame, :] = minibatch_pose_input[:, frame, :]
    # Sanity check that every keyframe was copied over (fails only on NaNs).
    for frame in keyframes:
        assert torch.allclose(
            placeholder[:, frame, :], minibatch_pose_input[:, frame, :]
        )
    return placeholder
def slerp(x, y, a):
    """
    Performs spherical linear interpolation (SLERP) between quaternions x and y.

    :param x: quaternion tensor; the last dimension holds the 4 components and
        the leading dimensions are arbitrary.
    :param y: quaternion tensor with the same shape as ``x``.
    :param a: proportion of the interpolation (0 gives x, 1 gives y). It can be
        a Python number or a tensor broadcastable to ``x[..., 0]``.
    :return: tensor of interpolation results with the shape of ``x``.

    NOTE: wherever ``dot(x, y) < 0``, ``y`` is negated IN PLACE so the result
    follows the shorter arc. ``slerp_input_repr`` relies on this side effect:
    its endpoint checks compare against the flipped ``y``. Pass a clone if the
    original ``y`` must be preserved.
    """
    dot = torch.sum(x * y, dim=-1)
    neg = dot < 0.0
    dot[neg] = -dot[neg]
    y[neg] = -y[neg]
    # One proportion per quaternion, starting from x's dtype/device.
    a = torch.zeros_like(x[..., 0]) + a
    # Allocate the weights like `a` rather than with the global default dtype,
    # so float64 inputs are not silently truncated to float32 precision.
    amount0 = torch.zeros_like(a)
    amount1 = torch.zeros_like(a)
    # Nearly parallel quaternions have sin(omega) ~ 0; fall back to plain lerp.
    linear = (1.0 - dot) < 0.01
    omegas = torch.arccos(dot[~linear])
    sinoms = torch.sin(omegas)
    amount0[linear] = 1.0 - a[linear]
    amount0[~linear] = torch.sin((1.0 - a[~linear]) * omegas) / sinoms
    amount1[linear] = a[linear]
    amount1[~linear] = torch.sin(a[~linear] * omegas) / sinoms
    # unsqueeze(-1) (not a hard-coded axis 3) supports inputs of any rank.
    return amount0.unsqueeze(-1) * x + amount1.unsqueeze(-1) * y
def _slerp_segment(poses, out, first, last):
    """Fill ``out[:, first:last + 1]`` by slerping frame ``first`` to frame ``last`` of ``poses``.

    ``slerp`` negates its second argument in place where the dot product is
    negative, so the end quaternion inside ``poses`` may be sign-flipped. The
    endpoint checks below compare against those (possibly flipped) values.
    """
    start = poses[:, first : first + 1]
    end = poses[:, last : last + 1]
    steps = last - first
    # A zero-length segment (single-frame sequence) only contains its start.
    dt = 1 / steps if steps > 0 else 0.0
    for k in range(steps + 1):
        out[:, first + k : first + k + 1, :] = slerp(start, end, dt * k)
    assert torch.allclose(out[:, first : first + 1], start)
    assert torch.allclose(out[:, last : last + 1], end)


def slerp_input_repr(minibatch_pose_input, mask_start_frame):
    """Fill a quaternion sequence by SLERP between keyframes.

    The trailing features are regrouped as 4-component quaternions with
    ``reshape(batch, seq_len, -1, 4)``, so the per-frame feature count
    presumably must be a multiple of 4. If ``mask_start_frame`` is the first
    or the last frame, the whole sequence is slerped from frame 0 to frame
    ``seq_len - 1``. Otherwise two segments (``0 -> mask_start_frame`` and
    ``mask_start_frame -> seq_len - 1``) are slerped, which keeps
    ``mask_start_frame`` as a keyframe. Each quaternion of the result is then
    L2-normalized. The caller's tensor is not modified.

    :param minibatch_pose_input: tensor with batch on dim 0 and frames on dim 1.
    :param mask_start_frame: index of the keyframe that splits the sequence.
    :return: tensor of shape ``(batch, seq_len, -1)``.
    """
    batch_size = minibatch_pose_input.size(0)
    seq_len = minibatch_pose_input.size(1)
    last = seq_len - 1
    # reshape() returns a view when possible, and slerp flips quaternions in
    # place; clone so the caller's input tensor is never modified.
    poses = minibatch_pose_input.reshape(batch_size, seq_len, -1, 4).clone()
    interpolated = torch.zeros_like(poses)
    if mask_start_frame == 0 or mask_start_frame == last:
        _slerp_segment(poses, interpolated, 0, last)
    else:
        # The second segment reads frame `mask_start_frame` after the first
        # segment may have flipped it, exactly as the end of segment one.
        _slerp_segment(poses, interpolated, 0, mask_start_frame)
        _slerp_segment(poses, interpolated, mask_start_frame, last)
    interpolated = torch.nn.functional.normalize(interpolated, p=2.0, dim=3)
    return interpolated.reshape(batch_size, seq_len, -1)
def _lerp_segment(poses, out, first, last):
    """Fill ``out[:, first..last]`` by linearly interpolating frame ``first`` to frame ``last`` of ``poses``."""
    start = poses[:, first, :]
    end = poses[:, last, :]
    steps = last - first
    # A zero-length segment (single-frame sequence) only contains its start.
    dt = 1 / steps if steps > 0 else 0.0
    for k in range(steps + 1):
        out[:, first + k, :] = torch.lerp(start, end, dt * k)
    assert torch.allclose(out[:, first, :], start)
    assert torch.allclose(out[:, last, :], end)


def lerp_input_repr(minibatch_pose_input, mask_start_frame):
    """Fill a sequence by linear interpolation between keyframes.

    If ``mask_start_frame`` is the first or the last frame, frames are
    interpolated from frame 0 to frame ``seq_len - 1``. Otherwise two segments
    (``0 -> mask_start_frame`` and ``mask_start_frame -> seq_len - 1``) are
    interpolated, which keeps ``mask_start_frame`` as a keyframe. A
    single-frame sequence is returned as a copy of its only frame.

    :param minibatch_pose_input: tensor indexed as ``[batch, frame, feature]``.
    :param mask_start_frame: index of the keyframe that splits the sequence.
    :return: new tensor shaped like the input.
    """
    last = minibatch_pose_input.size(1) - 1
    interpolated = torch.zeros_like(minibatch_pose_input)
    if mask_start_frame == 0 or mask_start_frame == last:
        _lerp_segment(minibatch_pose_input, interpolated, 0, last)
    else:
        _lerp_segment(minibatch_pose_input, interpolated, 0, mask_start_frame)
        _lerp_segment(minibatch_pose_input, interpolated, mask_start_frame, last)
    return interpolated
def vectorize_representation(global_position, global_rotation):
    """Concatenate per-frame position and rotation features into one vector.

    Both tensors are flattened to ``(batch, seq, -1)`` using the first two
    dimensions of ``global_position``. They are then joined along the feature
    axis, with positions first.

    :param global_position: tensor whose dims 0 and 1 are batch and sequence.
    :param global_rotation: tensor with the same batch and sequence sizes.
    :return: tensor of shape ``(batch, seq, pos_features + rot_features)``.
    """
    n_batch, n_frames = global_position.shape[0], global_position.shape[1]
    pieces = [
        tensor.reshape(n_batch, n_frames, -1).contiguous()
        for tensor in (global_position, global_rotation)
    ]
    return torch.cat(pieces, dim=2)