|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Misc functions. |
|
|
|
Mostly copied from torchvision references or other public repos such as DETR:
|
https://github.com/facebookresearch/detr/blob/master/util/misc.py |
|
""" |
|
import torch |
|
import math |
|
import warnings |
|
|
|
|
|
def get_1d_sincos_pos_embed(embed_dim, pos, gsd=1, ref_gsd=1):
    """Build 1-D sine/cosine positional embeddings.

    embed_dim: output dimension for each position; must be even.
    pos: tensor of positions to be encoded. It is flattened, so any shape
        holding M elements is accepted (typically size (M,)). Integer or
        float64 positions are converted to float32 before encoding.
    gsd, ref_gsd: every frequency is multiplied by ``gsd / ref_gsd``; the
        defaults (1, 1) leave the frequencies unscaled.
    out: float32 tensor of shape (M, D) on ``pos.device``. Even columns hold
        the sine terms and odd columns hold the cosine terms.

    Raises ValueError if ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be even, got {embed_dim}")

    # Frequencies 1 / 10000^(2i/D) for i in [0, D/2).
    omega = torch.arange(embed_dim // 2, dtype=torch.float, device=pos.device)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega

    # Flatten to (M,) and match omega's dtype: einsum does not promote mixed
    # dtypes, so integer (e.g. torch.arange) or float64 positions would fail.
    pos = pos.reshape(-1).to(omega.dtype)
    out = torch.einsum('m,d->md', pos, omega)  # (M, D/2)

    # Apply the gsd scaling once instead of separately for sin and cos.
    out = gsd / ref_gsd * out

    # Allocate on pos's device so GPU inputs do not produce a CPU embedding.
    # Every column is overwritten below (D is even), so empty() is safe.
    emb = torch.empty(pos.shape[0], embed_dim, dtype=torch.float, device=pos.device)
    emb[:, 0::2] = torch.sin(out)
    emb[:, 1::2] = torch.cos(out)

    return emb
|
|
|
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from a truncated normal.

    Values follow N(mean, std^2) restricted to [a, b]. They are produced by
    inverse-CDF sampling inside ``torch.no_grad()``, so the in-place writes
    are not recorded by autograd.

    Args:
        tensor: tensor to overwrite in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower truncation bound.
        b: upper truncation bound.

    Returns:
        The same ``tensor`` object, after it has been filled.
    """

    def _std_normal_cdf(z):
        # Standard normal CDF expressed through the error function.
        return 0.5 * (1. + math.erf(z / math.sqrt(2.)))

    # Warn when the mean sits far outside the truncation window.
    lower_limit = a - 2 * std
    upper_limit = b + 2 * std
    if mean < lower_limit or mean > upper_limit:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # CDF values at the standardized truncation bounds.
        cdf_lo = _std_normal_cdf((a - mean) / std)
        cdf_hi = _std_normal_cdf((b - mean) / std)

        # Draw uniformly from [2*cdf_lo - 1, 2*cdf_hi - 1]. Applying erfinv and
        # scaling by sqrt(2) turns this into a standard normal truncated to the
        # standardized bounds. Then rescale/shift to the requested std and
        # mean, and finally clamp so every value lies in [a, b].
        (tensor.uniform_(2 * cdf_lo - 1, 2 * cdf_hi - 1)
               .erfinv_()
               .mul_(std * math.sqrt(2.))
               .add_(mean)
               .clamp_(min=a, max=b))
    return tensor
|
|
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place with values from a truncated normal distribution.

    Values are drawn from N(mean, std^2) restricted to the interval [a, b].
    The sampling itself is delegated to ``_no_grad_trunc_normal_``.

    Args:
        tensor: tensor to fill in place.
        mean: mean of the underlying normal distribution (default 0).
        std: standard deviation of the underlying normal distribution (default 1).
        a: lower truncation bound (default -2).
        b: upper truncation bound (default 2).

    Returns:
        The filled ``tensor``.
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
|
|