from typing import Tuple

import torch
from einops import rearrange, repeat


def flash_torch_rotate_half(x, interleaved=False):
    """Rotate feature pairs of ``x`` by 90 degrees: ``(a, b) -> (-b, a)``.

    Args:
        x: Tensor whose last dimension holds the rotary features.
        interleaved: If False, each pair is ``(x[..., i], x[..., i + d/2])``
            (first half / second half of the last dimension). If True, each
            pair is the adjacent elements ``(x[..., 2i], x[..., 2i + 1])``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    # Stack the rotated pair on a new trailing axis, then flatten it back so the
    # output keeps the interleaved (even, odd) layout.
    return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)


def flash_torch_apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """Apply rotary embeddings to the first ``rotary_dim`` features of ``x``.

    Args:
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
        interleaved: Pair layout, as in ``flash_torch_rotate_half``.

    Returns:
        Tensor shaped like ``x``. Features past ``rotary_dim`` are passed
        through unchanged.
    """
    seqlen = x.shape[-3]
    if seqlen < cos.shape[-2]:
        # The tables may be longer than the input (e.g. autoregressive decoding):
        # trim them to the input length. This fixes the AR case but does NOT
        # handle KV-cache slicing -- positions are always taken from offset 0.
        cos = cos[..., :seqlen, :]
        sin = sin[..., :seqlen, :]
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1], (
        f"rotary dim {ro_dim} exceeds head dim {x.shape[-1]}"
    )
    # Expand the half-width tables to the full rotary width (matching the pair
    # layout) and add a broadcast axis for the heads dimension.
    pattern = "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)
    x_rot = x[..., :ro_dim]
    return torch.cat(
        [x_rot * cos + flash_torch_rotate_half(x_rot, interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors.

    The last dimension of ``x`` is split into halves ``(a, b)``. These are
    treated as the complex numbers ``a + i*b`` and multiplied by ``freqs_cis``.
    The real and imaginary parts are written back as the first and second
    halves. The math is done in float32, and the result is cast back to
    ``x.dtype``.

    Args:
        x: 4-D tensor; dims 1 and 2 are swapped for the computation and swapped
            back afterwards. Presumably (batch, seq, heads, head_dim) -- confirm
            against callers.
        freqs_cis: Complex tensor that must broadcast against
            ``x.transpose(1, 2)`` with its last dim halved (assumed
            (seq, head_dim / 2) -- TODO confirm).

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1)
    )
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    # (..., d/2, 2) -> real parts followed by imaginary parts: (..., d, 1).
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)
    return x_out


def rotate_half_(x):
    """Return ``cat(-x2, x1)`` where ``x1, x2`` are the halves of the last dim."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_(x, cos, sin):
    """Apply ``x * cos + rotate_half_(x) * sin``.

    ``cos`` and ``sin`` are indexed as 4-D tables with the sequence on dim 2
    (as built by ``StandaloneRotaryEmbedding``). They are trimmed to
    ``x.shape[-2]`` positions.
    """
    # NOTE: This could probably be moved to Triton

    # Handle a possible sequence length mismatch in between q and k
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]

    return (x * cos) + (rotate_half_(x) * sin)


class StandaloneRotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox


    .. warning:: Please note that this embedding is not registered on purpose, as it is transformative
        (it does not create the embedding dimension) and will likely be picked up (imported) on an ad hoc basis
    """

    def __init__(self, dim_model: int, *_, **__):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable):
        # inv_freq[i] = 1 / 10000^(2i / dim_model).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
        self.register_buffer("inv_freq", inv_freq)

        # Lazily built cos/sin tables, shaped (1, 1, seq_len, 2 * len(inv_freq)).
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        """Return cos/sin tables covering ``x.shape[seq_dimension]`` positions.

        The tables are cached. They are rebuilt when nothing is cached yet, or
        when the sequence length, device, or dtype of ``x`` differs from the
        cached version (possibly due to tracing for instance).
        """
        seq_len = x.shape[seq_dimension]

        if (
            self._cos_cached is None
            or seq_len != self._seq_len_cached
            or self._cos_cached.device != x.device
            or self._cos_cached.dtype != x.dtype
        ):
            self._seq_len_cached = seq_len
            # Build the angle table in at least float32. Casting inv_freq to
            # fp16/bf16 first would lose most of the angle precision for long
            # sequences. Only the final cos/sin values are cast to x.dtype.
            compute_dtype = torch.promote_types(x.dtype, torch.float32)
            t = torch.arange(seq_len, device=x.device, dtype=compute_dtype)
            inv_freq = self.inv_freq.to(device=x.device, dtype=compute_dtype)
            freqs = torch.einsum("i,j->ij", t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)

            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k``; the sequence axis is expected at dim -2.

        The tables are sized from ``k``'s sequence length. ``q`` uses the first
        ``q.shape[-2]`` positions, so ``q`` must not be longer than ``k``.
        """
        cos, sin = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb_(q, cos, sin),
            apply_rotary_pos_emb_(k, cos, sin),
        )