|
|
|
|
|
import torch |
|
|
|
|
|
def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    use_scaled: bool = False,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build the rotary-embedding cos/sin lookup table.

    For position ``p`` in ``[0, end)`` and frequency index ``k`` in
    ``[0, dim // 2)`` the rotation angle is ``p / theta ** (2 * k / dim)``.

    Args:
        dim: Size of the rotated feature dimension; ``dim // 2`` frequencies
            are produced.
        end: Number of positions to tabulate.
        theta: Base of the geometric frequency progression.
        use_scaled: Accepted for interface compatibility; this function does
            not read it, so no frequency scaling is applied.
            NOTE(review): confirm whether scaling is expected to happen here.
        dtype: Element type of the returned table.

    Returns:
        Tensor of shape ``(end, dim // 2, 2)`` and type ``dtype`` where
        ``[..., 0]`` holds the cosine and ``[..., 1]`` the sine of each angle.
    """
    # Angles are always computed in at least float32.  Half-precision types
    # cannot represent large integer positions exactly, so building the
    # position index directly in ``dtype`` would yield wrong angles for long
    # sequences.  The table is cast to ``dtype`` once, at the very end.
    compute_dtype = torch.float64 if dtype == torch.float64 else torch.float32

    exponents = torch.arange(0, dim, 2, dtype=compute_dtype)[: (dim // 2)] / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, dtype=compute_dtype)

    # (end, 1) * (1, dim // 2) -> (end, dim // 2)
    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)

    # cos/sin are the real/imaginary parts of exp(1j * angles); computing them
    # directly avoids a complex intermediate tensor.
    table = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)
    return table.to(dtype)
|
|
|
|
|
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
    position_ids: torch.Tensor,
    num_heads: int,
    rot_dim: int = 32,
    interleave: bool = False,
) -> torch.Tensor:
    """Rotate the first ``rot_dim`` features of ``x`` by position-dependent angles.

    Args:
        x: Input whose dimension 1 has size ``num_heads`` and whose last
            dimension holds the features; presumably laid out as
            ``(batch, num_heads, seq, head_dim)`` -- TODO confirm with callers.
        freqs_cis: Table whose second-to-last dimension has ``rot_dim // 2``
            entries and whose last dimension is ``(cos, sin)``; it is indexed
            along its first dimension by ``position_ids``.
        position_ids: Indices into ``freqs_cis``.  A 1-D tensor is shared by
            every batch entry; with a 4-D ``x``, a 2-D ``(batch, seq)`` tensor
            supplies separate positions per batch entry.
        num_heads: Expected size of ``x.shape[1]``.
        rot_dim: Number of leading features that are rotated; the remaining
            features are passed through unchanged.
        interleave: If true, ``(real, imag)`` pairs are read from adjacent
            features ``(2k, 2k + 1)``; otherwise the real parts are the first
            half of the rotated slice and the imaginary parts the second half.

    Returns:
        Tensor with the dtype of ``x``: the rotated features followed by the
        untouched ``x[..., rot_dim:]``.  The rotated features are always
        written back pair-interleaved ``(r0, i0, r1, i1, ...)``.
        NOTE(review): with ``interleave=False`` the output layout therefore
        differs from the half-split input layout -- confirm this is intended.

    Raises:
        AssertionError: If ``rot_dim`` does not match ``freqs_cis`` or
            ``num_heads`` does not match ``x``.
    """
    assert rot_dim == freqs_cis.shape[-2] * 2, (
        "rot_dim must equal twice freqs_cis.shape[-2]"
    )
    assert num_heads == x.shape[1], "num_heads must equal x.shape[1]"

    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]

    if interleave:
        # Upcast and regroup once, then take both components from the result.
        pairs = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)
        xq_r, xq_i = pairs[..., 0], pairs[..., 1]
    else:
        half = x_rot.shape[-1] // 2
        xq_r, xq_i = x_rot[..., :half], x_rot[..., half:]

    freqs_cos = freqs_cis[..., 0][position_ids, :]
    freqs_sin = freqs_cis[..., 1][position_ids, :]
    if x.dim() == 4 and freqs_cos.dim() == 3:
        # Per-batch positions: (batch, seq, rot_dim // 2).  Insert the head
        # axis so the table broadcasts across heads instead of gaining two
        # leading axes that cannot line up with a 4-D ``x``.
        freqs_cos = freqs_cos.unsqueeze(1)
        freqs_sin = freqs_sin.unsqueeze(1)
    else:
        # Shared positions: add two leading singleton axes.
        freqs_cos = freqs_cos.unsqueeze(0).unsqueeze(0)
        freqs_sin = freqs_sin.unsqueeze(0).unsqueeze(0)

    # Complex multiplication (r + i*j) * (cos + sin*j), kept in real form.
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)

    return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)
|
|