# PSHuman / lib / net /
# fffiloni's picture
# Migrated from GitHub
# 2252f3d verified
# history blame
# 4.93 kB
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
# Contact: ps-license@tuebingen.mpg.de
import torch
def index(feat, uv):
    """Sample image (or volume) features at the given coordinates.

    Thin wrapper around ``torch.nn.functional.grid_sample`` called with
    ``align_corners=True`` and otherwise default interpolation/padding.

    :param feat: [B, C, H, W] image features; [B, C, D, H, W] volume features
        when ``uv`` carries three coordinates per point
    :param uv: [B, 2, N] (or [B, 3, N]) sampling coordinates. They are handed
        to ``grid_sample`` unchanged, which expects coordinates normalized to
        [-1, 1] (not [0, 1]).
    :return: [B, C, N] features sampled at the coordinates
    """
    uv = uv.transpose(1, 2)  # [B, N, 2] or [B, N, 3]

    (B, N, _) = uv.shape
    C = feat.shape[1]

    # grid_sample wants one trailing coordinate axis preceded by as many
    # spatial axes as the input has, so pad the grid with singleton axes.
    # The two cases are mutually exclusive: applying both unsqueezes to a
    # 3-coordinate grid would produce a rank that grid_sample rejects.
    if uv.shape[-1] == 3:
        uv = uv.unsqueeze(2).unsqueeze(3)  # [B, N, 1, 1, 3]
    else:
        uv = uv.unsqueeze(2)  # [B, N, 1, 2]

    # NOTE: for newer PyTorch, it seems that training results are degraded due
    # to an implementation difference in F.grid_sample; for old versions that
    # do not know the argument, simply remove ``align_corners``.
    samples = torch.nn.functional.grid_sample(
        feat, uv, align_corners=True)  # [B, C, N, 1] or [B, C, N, 1, 1]
    # Alternative for the 2D case: samples = grid_sample(feat, uv)
    return samples.view(B, C, N)  # [B, C, N]
def grid_sample(image, optical):
    """Bilinearly sample ``image`` at the normalized locations in ``optical``.

    Hand-written bilinear interpolation built from ``floor`` and ``gather``.
    Coordinates are mapped to pixels as ``(c + 1) / 2 * (size - 1)``, i.e.
    -1 addresses the first pixel and +1 the last one along each axis.

    For locations outside [-1, 1] the neighbour indices are clamped into the
    image while the interpolation weights are still computed from the
    unclamped position, so such samples are extrapolated rather than
    zero-padded.

    :param image: [N, C, IH, IW] tensor to sample from
    :param optical: [N, H, W, 2] sampling grid, last axis ordered (x, y)
    :return: [N, C, H, W] sampled values
    """
    N, C, IH, IW = image.shape
    _, H, W, _ = optical.shape

    # Normalized coordinates -> continuous pixel coordinates.
    ix = ((optical[..., 0] + 1) / 2) * (IW - 1)
    iy = ((optical[..., 1] + 1) / 2) * (IH - 1)

    # Integer corner positions carry no gradient; gradients reach the grid
    # only through the interpolation weights below.
    with torch.no_grad():
        x0 = torch.floor(ix)
        y0 = torch.floor(iy)
        x1 = x0 + 1
        y1 = y0 + 1

    # Each corner is weighted by the area of the opposite sub-rectangle.
    nw = (x1 - ix) * (y1 - iy)
    ne = (ix - x0) * (y1 - iy)
    sw = (x1 - ix) * (iy - y0)
    se = (ix - x0) * (iy - y0)

    # Clamp after the weights are computed, then switch to int64 *before*
    # forming the flat index: doing ``y * IW + x`` in a low-precision float
    # dtype (fp16/bf16, or fp32 on very large images) rounds the index.
    with torch.no_grad():
        x0_idx = x0.clamp(0, IW - 1).long()
        x1_idx = x1.clamp(0, IW - 1).long()
        y0_idx = y0.clamp(0, IH - 1).long()
        y1_idx = y1.clamp(0, IH - 1).long()

    # reshape (not view) so non-contiguous feature maps are accepted.
    flat = image.reshape(N, C, IH * IW)

    def _gather(y_idx, x_idx):
        # One index per output location, broadcast over channels without
        # materialising C copies.
        flat_idx = (y_idx * IW + x_idx).reshape(N, 1, H * W).expand(-1, C, -1)
        return torch.gather(flat, 2, flat_idx).view(N, C, H, W)

    out_val = (_gather(y0_idx, x0_idx) * nw.unsqueeze(1) +
               _gather(y0_idx, x1_idx) * ne.unsqueeze(1) +
               _gather(y1_idx, x0_idx) * sw.unsqueeze(1) +
               _gather(y1_idx, x1_idx) * se.unsqueeze(1))

    return out_val
def orthogonal(points, calibrations, transforms=None):
    """Compute the orthogonal projection of 3D points into the image plane.

    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 3, 4] Tensor of projection matrices ``[R | t]``;
        only the first three rows and four columns are read
    :param transforms: optional [B, 2, 3] Tensor of affine image transforms
        ``[S | d]``, applied to the projected x/y coordinates only
    :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
    """
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    pts = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    if transforms is None:
        return pts

    # Slice the matrix axes, not the batch axis: ``transforms[:2, :2]`` would
    # cut the batch down and leave the shift column empty.
    scale = transforms[:, :2, :2]  # [B, 2, 2]
    shift = transforms[:, :2, 2:3]  # [B, 2, 1]
    xy = torch.baddbmm(shift, scale, pts[:, :2, :])  # [B, 2, N]
    # Rebuild the result instead of writing into ``pts`` in place, so the
    # tensor consumed by baddbmm is not modified afterwards (autograd-safe).
    return torch.cat([xy, pts[:, 2:3, :]], 1)
def perspective(points, calibrations, transforms=None):
    """Compute the perspective projection of 3D points into the image plane.

    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 3, 4] Tensor of projection matrices ``[R | t]``;
        only the first three rows and four columns are read
    :param transforms: optional [B, 2, 3] Tensor of affine image transforms
        ``[S | d]``, applied to the projected x/y coordinates only
    :return: xyz: [B, 3, N] Tensor holding the projected xy coordinates with
        the pre-division depth z appended as the third channel
    """
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    homo = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    depth = homo[:, 2:3, :]  # [B, 1, N]
    # Perspective divide; points with depth 0 yield inf/nan as in plain
    # tensor division (no special handling here).
    xy = homo[:, :2, :] / depth
    if transforms is not None:
        # Slice the matrix axes, not the batch axis: ``transforms[:2, :2]``
        # would cut the batch down and leave the shift column empty.
        scale = transforms[:, :2, :2]  # [B, 2, 2]
        shift = transforms[:, :2, 2:3]  # [B, 2, 1]
        xy = torch.baddbmm(shift, scale, xy)

    xyz = torch.cat([xy, depth], 1)
    return xyz