You could have designed state of the art positional encoding

Published November 25, 2024

Gall's Law
A complex system that works is invariably found to have evolved from a simple system that worked
John Gall

This post walks you through the step-by-step discovery of state-of-the-art positional encoding in transformer models. We will achieve this by iteratively improving our approach to encoding position, arriving at Rotary Positional Encoding (RoPE), used in the latest Llama 3.2 release and most modern transformers. This post intends to limit the mathematical knowledge required to follow along, but some basic linear algebra, trigonometry and an understanding of self-attention are expected.

Problem Statement

You shall know a word by the company it keeps
John Rupert Firth

As with all problems, it is best to start by understanding exactly what we are trying to achieve. The self-attention mechanism in transformers is used to understand relationships between tokens in a sequence. Self-attention is a set operation, which means it is permutation equivariant. If we do not enrich self-attention with positional information, many important relationships cannot be determined.

This is best demonstrated by example.

Motivating Example

Consider this sentence with the same word in different positions:

The dog chased another dog

Intuitively, "dog" refers to two different entities. Let's see what happens if we tokenize the sentence, map the tokens to the real token embeddings of Llama 3.2 1B, and pass them through torch.nn.MultiheadAttention.

```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

model_id = "meta-llama/Llama-3.2-1B"  # gated model: you must accept Meta's license on the Hub first
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)

text = "The dog chased another dog"
# The tokenizer prepends a <|begin_of_text|> token, so "dog" sits at indices 2 and 5
tokens = tok(text, return_tensors="pt")["input_ids"]
embeddings = model.embed_tokens(tokens)
hdim = embeddings.shape[-1]

W_q = nn.Linear(hdim, hdim, bias=False)
W_k = nn.Linear(hdim, hdim, bias=False)
W_v = nn.Linear(hdim, hdim, bias=False)
mha = nn.MultiheadAttention(embed_dim=hdim, num_heads=4, batch_first=True)

with torch.no_grad():
    for param in mha.parameters():
        nn.init.normal_(param, std=0.1)  # Initialize weights to be non-negligible

    # No positional information is added anywhere before attention
    output, _ = mha(W_q(embeddings), W_k(embeddings), W_v(embeddings))

dog1_out = output[0, 2]
dog2_out = output[0, 5]
print(f"Dog output identical?: {torch.allclose(dog1_out, dog2_out, atol=1e-6)}")  # True
```

As we can see, without any positional information, the output of a (multi-headed) self-attention operation is identical for the same token in different positions, despite the tokens clearly representing distinct entities. Let's begin designing a method of enhancing self-attention with positional information, so that it can determine relationships between words based on their positions.
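The same effect can be checked without downloading any weights. The following is a minimal sketch under stated assumptions: random vectors stand in for token embeddings, and `nn.MultiheadAttention` is used without any mask. Reordering the input sequence only reorders the outputs, which is the permutation equivariance described in the problem statement.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, dim = 5, 16
# Assumption: random vectors stand in for real token embeddings
x = torch.randn(1, seq_len, dim)
mha = nn.MultiheadAttention(embed_dim=dim, num_heads=4, batch_first=True)

perm = torch.tensor([3, 0, 4, 1, 2])  # an arbitrary reordering of the sequence
x_perm = x[:, perm]

with torch.no_grad():
    out, _ = mha(x, x, x)
    out_perm, _ = mha(x_perm, x_perm, x_perm)

# Permuting the inputs just permutes the outputs: attention alone cannot see order
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True
```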

How should an ideal positional encoding scheme behave?

Desirable Properties

Let's try to define some desirable properties that will make the optimization process as easy as possible.

Property 1 - Unique encoding for each position (across sequences)

Each position needs a unique encoding that remains consistent regardless of sequence length: a token at position 5 should have the same encoding whether the current sequence is of length 10 or 10,000.

Property 2 - Linear relation between two encoded positions

The relationship between positions should be mathematically simple. If we know the encoding for position $p$, it should be straightforward to compute the encoding for position $p+k$, making it easier for the model to learn positional patterns.

If you think about how we represent numbers on a number line, it's easy to understand that 5 is 2 steps away from 3, or that 10 is 5 steps from 15. The same intuitive relationship should exist in our encodings.

Property 3 - Generalizes to longer sequences than those encountered in training

To increase our models' utility in the real world, they should generalize outside their training distribution. Therefore, our encoding scheme needs to be adaptable enough to handle unexpected input lengths, without violating any of our other desirable properties.

Property 4 - Generated by a deterministic process the model can learn

It would be ideal if our positional encodings could be drawn from a deterministic process. This should allow the model to learn the mechanism behind our encoding scheme efficiently.

Property 5 - Extensible to multiple dimensions

With multimodal models becoming the norm, it is crucial that our positional encoding scheme can naturally extend from $1D$ to $nD$. This will allow models to consume data like images or brain scans, which are $2D$ and $4D$ respectively.

Now that we know the ideal properties (henceforth referred to as $Pr_n$), let's start designing and iterating on our encoding scheme.

Integer Position Encoding

The first approach that may jump to mind is simply to add the integer value of the token position to each component of the token embedding, with values ranging from $0 \rightarrow L$, where $L$ is the length of our current sequence.

In the above animation, we create our positional encoding vector for the token $\color{#699C52}\text{chased}$ from the index and add it to our token embedding. The embedding values here are a subset of the real values from Llama 3.2 1B. We can observe that they're clustered around 0. This is desirable to avoid vanishing or exploding gradients during training and therefore is something we'd like to maintain throughout the model.

It's clear that our current naïve approach is going to cause problems. The magnitude of the position value vastly exceeds the actual values of our input. This means the signal-to-noise ratio is very low, and it's hard for the model to separate the semantic information from the positional information.
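To put numbers on the problem, here is a small sketch. The embedding scale of 0.02 is an assumption chosen only to mimic values clustered around 0. It is not taken from Llama.

```python
import torch

torch.manual_seed(0)
seq_len, dim = 1024, 8
# Assumption: small embedding values clustered around 0
token_emb = torch.randn(seq_len, dim) * 0.02
positions = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)

# Naive integer encoding: add the position index to every component (broadcast over dim)
encoded = token_emb + positions

print(f"largest |embedding value|: {token_emb.abs().max().item():.3f}")
print(f"last token after encoding:  {encoded[-1].tolist()}")  # every entry is roughly 1023
```

The semantic signal (hundredths) is buried under a positional offset in the thousands.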

With this new knowledge, a natural follow-up might be to normalize the position value by $\frac{1}{N}$. This constrains the values to between 0 and 1, but introduces another problem. If we choose $N$ to be the length of the current sequence, then the position values will be completely different for sequences of differing lengths, violating $Pr_1$.
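A two-line check makes the violation of $Pr_1$ concrete. The helper name `normalized_position` is hypothetical. It simply divides by the current sequence length $N$.

```python
def normalized_position(pos: int, seq_len: int) -> float:
    """Hypothetical helper: scale the integer position by 1/N, with N the current sequence length."""
    return pos / seq_len

print(normalized_position(5, 10))      # 0.5
print(normalized_position(5, 10_000))  # 0.0005
```

The same position 5 receives two very different encodings depending on the sequence it appears in.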

Is there a better way to ensure our numbers are between 0 and 1? If we thought really hard about this for a while, we might come up with switching from decimal to binary numbers.

Binary Position Encoding

Instead of adding our (potentially normalized) integer position to each component of the embedding, we could instead convert it into its binary representation and s t r e t c h our value out to match our embedding dimension, as demonstrated below.

We've converted the position of interest (252) into its binary representation (11111100) and added each bit to the corresponding component of the token embedding. The least significant bit (LSB) will cycle between 0 and 1 for every subsequent token, whilst the most significant bit (MSB) will cycle every $2^{n-1}$ tokens, where $n$ is the number of bits. You can see the positional encoding vector for different indices in the animation below [^1].
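The sketch below builds this encoding. It is an illustration under two assumptions: the number of bits equals a hypothetical 8-dimensional embedding, and the bits are ordered MSB first to match the 11111100 example.

```python
import torch

def binary_position_encoding(pos: int, n_bits: int) -> torch.Tensor:
    """Binary representation of pos as floats, most significant bit first."""
    bits = [(pos >> b) & 1 for b in reversed(range(n_bits))]
    return torch.tensor(bits, dtype=torch.float32)

print(binary_position_encoding(252, 8))  # tensor([1., 1., 1., 1., 1., 1., 0., 0.])

# The LSB (last entry) flips on every step; the MSB (first entry) flips every 2**(n_bits - 1) steps
lsb = [binary_position_encoding(p, 8)[-1].item() for p in range(4)]
print(lsb)  # [0.0, 1.0, 0.0, 1.0]

# Adding it to a hypothetical 8-dimensional token embedding
token_emb = torch.randn(8) * 0.02
encoded = token_emb + binary_position_encoding(252, 8)
```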

We've solved the value range problem, and we now have unique encodings that are consistent across different sequence lengths. What happens if we plot a low-dimensional version of our token embedding and visualize the addition of our binary positional vector for different values?

We can see that the result is very "jumpy" (as we might expect from the discrete nature of binary). The optimization process likes smooth, continuous and predictable changes. Do we know any functions with similar value ranges that are smooth and continuous?

If we looked around a little, we might notice that both $\sin$ and $\cos$ fit the bill!

Sinusoidal Positional Encoding

The above animation visualizes our position embedding if each component is alternately drawn from $\sin$ and $\cos$ with gradually increasing wavelengths. If you compare it with the previous animation, you'll notice a striking similarity!

We've now arrived at Sinusoidal embeddings, originally defined in the "Attention Is All You Need" paper. Let's look at the equations:

$$\begin{aligned} PE_{(pos,2i)} &= \sin\left(\frac{pos}{10000^{2i/d}}\right) \\ PE_{(pos,2i+1)} &= \cos\left(\frac{pos}{10000^{2i/d}}\right) \end{aligned}$$

where $pos$ is the token's position index, $i$ indexes the sine/cosine pair (so $2i$ and $2i+1$ are the component indices in the positional encoding vector), and $d$ is the model dimension. $10{,}000$ is the base wavelength (henceforth referred to as $\theta$), which we stretch or compress as a function of the component index. I encourage you to plug in some realistic values to get a feel for this geometric progression.
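A minimal PyTorch sketch of these equations follows, assuming an even model dimension $d$. The function name and the `theta` parameter (defaulting to $10{,}000$) are illustrative choices. The second part prints a few wavelengths $2\pi / \omega$ so you can see the geometric progression.

```python
import math
import torch

def sinusoidal_encoding(seq_len: int, d: int, theta: float = 10_000.0) -> torch.Tensor:
    """PE[pos, 2i] = sin(pos / theta**(2i/d)), PE[pos, 2i+1] = cos(pos / theta**(2i/d)); d must be even."""
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)
    two_i = torch.arange(0, d, 2, dtype=torch.float32)             # 0, 2, 4, ..., d-2
    omega = 1.0 / theta ** (two_i / d)                             # (d/2,) frequencies
    pe = torch.zeros(seq_len, d)
    pe[:, 0::2] = torch.sin(pos * omega)  # even components
    pe[:, 1::2] = torch.cos(pos * omega)  # odd components
    return pe

pe = sinusoidal_encoding(seq_len=128, d=64)
print(pe.shape)  # torch.Size([128, 64])

# Wavelengths 2*pi / omega grow geometrically from 2*pi towards theta * 2*pi
omega = 1.0 / 10_000.0 ** (torch.arange(0, 64, 2, dtype=torch.float64) / 64)
wavelengths = 2 * math.pi / omega
print(wavelengths[:3], wavelengths[-1])
```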

There are a few parts of this equation that are confusing at first glance. How did the authors choose $10{,}000$? Why are we using $\sin$ and $\cos$ for even and odd components respectively?

It seems that using $10{,}000$ for the base wavelength was determined experimentally [^2]. Deciphering the usage of both $\sin$ and $\cos$ is more involved, but crucial for our iterative approach to understanding. The key here is our desire for a linear relation between two encoded positions ($Pr_2$). To understand how using $\sin$ and $\cos$ in tandem produces this linear relation, we will have to dive into some trigonometry.

Consider a sequence of sine and cosine pairs, each associated with a frequency $\omega_i$. Our goal is to find a linear transformation matrix $\mathbf{M}$ that can shift these sinusoidal functions by a fixed offset $k$:

$$\mathbf{M} \cdot \begin{bmatrix} \sin(\omega_i p) \\ \cos(\omega_i p) \end{bmatrix} = \begin{bmatrix} \sin(\omega_i(p + k)) \\ \cos(\omega_i(p + k)) \end{bmatrix}$$

The frequencies $\omega_i$ follow a geometric progression that decreases with dimension index $i$, defined as:

$$\omega_i = \frac{1}{10000^{2i/d}}$$

To find this transformation matrix, we can express it as a general 2×2 matrix with unknown coefficients $u_1$, $v_1$, $u_2$, and $v_2$:

$$\begin{bmatrix} u_1 & v_1 \\ u_2 & v_2 \end{bmatrix} \cdot \begin{bmatrix} \sin(\omega_i p) \\ \cos(\omega_i p) \end{bmatrix} = \begin{bmatrix} \sin(\omega_i(p+k)) \\ \cos(\omega_i(p+k)) \end{bmatrix}$$

By applying the trigonometric addition theorem to the right-hand side, we can expand this into:

$$\begin{bmatrix} u_1 & v_1 \\ u_2 & v_2 \end{bmatrix} \cdot \begin{bmatrix} \sin(\omega_i p) \\ \cos(\omega_i p) \end{bmatrix} = \begin{bmatrix} \sin(\omega_i p)\cos(\omega_i k) + \cos(\omega_i p)\sin(\omega_i k) \\ \cos(\omega_i p)\cos(\omega_i k) - \sin(\omega_i p)\sin(\omega_i k) \end{bmatrix}$$

This expansion gives us a system of two equations by matching coefficients:

$$\begin{aligned} u_1\sin(\omega_i p) + v_1\cos(\omega_i p) &= \cos(\omega_i k)\sin(\omega_i p) + \sin(\omega_i k)\cos(\omega_i p) \\ u_2\sin(\omega_i p) + v_2\cos(\omega_i p) &= -\sin(\omega_i k)\sin(\omega_i p) + \cos(\omega_i k)\cos(\omega_i p) \end{aligned}$$

By comparing terms with $\sin(\omega_i p)$ and $\cos(\omega_i p)$ on both sides, we can solve for the unknown coefficients:

$$\begin{aligned} u_1 &= \cos(\omega_i k) & v_1 &= \sin(\omega_i k) \\ u_2 &= -\sin(\omega_i k) & v_2 &= \cos(\omega_i k) \end{aligned}$$

These solutions give us our final transformation matrix $\mathbf{M_k}$:

$$\mathbf{M_k} = \begin{bmatrix} \cos(\omega_i k) & \sin(\omega_i k) \\ -\sin(\omega_i k) & \cos(\omega_i k) \end{bmatrix}$$
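Here is a quick numerical check of this derivation, using arbitrarily chosen illustrative values ($d = 64$, pair index $i = 3$, offset $k = 5$). A single $\mathbf{M_k}$ moves the $[\sin, \cos]$ pair from $p$ to $p + k$ for several different starting positions. It also satisfies $\mathbf{M_k}^T \mathbf{M_k} = I$ with determinant 1. The helper names are hypothetical.

```python
import math
import torch

def shift_matrix(omega: float, k: float) -> torch.Tensor:
    """M_k from the derivation: maps [sin(omega p), cos(omega p)] to [sin(omega (p+k)), cos(omega (p+k))]."""
    c, s = math.cos(omega * k), math.sin(omega * k)
    return torch.tensor([[c, s], [-s, c]], dtype=torch.float64)

def sin_cos(omega: float, p: float) -> torch.Tensor:
    """The [sin(omega p), cos(omega p)] pair for one frequency."""
    return torch.tensor([math.sin(omega * p), math.cos(omega * p)], dtype=torch.float64)

d, i, k = 64, 3, 5  # illustrative values
omega = 1.0 / 10_000 ** (2 * i / d)
M = shift_matrix(omega, k)

# One matrix works for every starting position p: the shift depends only on k
for p in (0, 17, 1000):
    assert torch.allclose(M @ sin_cos(omega, p), sin_cos(omega, p + k))

# M_k is orthogonal with determinant 1
assert torch.allclose(M.T @ M, torch.eye(2, dtype=torch.float64))
assert math.isclose(torch.linalg.det(M).item(), 1.0)
print("M_k shifts by k for every p, and is orthogonal with det 1")
```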

If you've done any game programming before, you might notice that the result of our derivation is oddly familiar. That's right, it's the rotation matrix! [^3]

So the encoding scheme designed by Noam Shazeer in "Attention Is All You Need" was already encoding relative position as a rotation back in 2017! It took another 4 years to go from Sinusoidal Encoding to RoPE, despite rotations already being on the table...

Absolute vs Relative Position Encoding

With the knowledge in hand that rotations are important here, let's return to our motivating example and try to discover some intuitions for our next iteration.

$$\begin{array}{lccccc} & \text{The} & \text{dog} & \text{chased} & \text{another} & \text{dog} \\ \text{Absolute position} & 0 & 1 & 2 & 3 & 4 \\ \text{Relative to chased} & -2 & -1 & 0 & 1 & 2 \end{array}$$

Above we can see the absolute positions of our tokens, and the relative positions from $\color{#699C52}\text{chased}$ to every other token. With Sinusoidal Encoding, we generated a separate vector which represents the absolute position, and using some trigonometric trickery we were able to encode relative positions.
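We can verify that claim numerically. Applying $\sin a \sin b + \cos a \cos b = \cos(a - b)$ to each pair gives $PE(p) \cdot PE(p+k) = \sum_i \cos(\omega_i k)$, so the dot product depends only on the offset $k$ and not on the absolute position $p$. The sketch below reuses the sinusoidal formula with $d = 64$ and $k = 3$ (arbitrary illustrative choices) and checks this for a few starting positions.

```python
import torch

d, theta, k = 64, 10_000.0, 3
omega = 1.0 / theta ** (torch.arange(0, d, 2, dtype=torch.float64) / d)  # (d/2,)

def pe(pos: int) -> torch.Tensor:
    """Sinusoidal encoding of one position: sin at even components, cos at odd ones."""
    out = torch.empty(d, dtype=torch.float64)
    out[0::2] = torch.sin(pos * omega)
    out[1::2] = torch.cos(pos * omega)
    return out

dots = torch.stack([pe(p) @ pe(p + k) for p in (0, 10, 200, 5000)])
closed_form = torch.cos(omega * k).sum()  # sum_i cos(omega_i * k)
print(dots)         # four numerically identical values
print(closed_form)  # the same value
```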

When we're trying to understand these sentences, does it matter that this word is the 2157th word in this blog post? Or do we care about its relationship to the words around it? The absolute position of a word rarely matters for meaning - what matters is how words relate to each other.

Positional encoding in context

From this point on, it's key to consider positional encoding in the context of self attention. To reiterate, the self-attention mechanism enables the model to weigh the importance of different elements in an input sequence and dynamically adjust their influence on the output.

$$\text{Attn}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
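For reference, here is a direct translation of this formula into PyTorch. The function name `attention` is my own. The comparison uses `torch.nn.functional.scaled_dot_product_attention`, which is available in PyTorch 2.0 and later and applies the same $1/\sqrt{d_k}$ scaling by default.

```python
import math
import torch
import torch.nn.functional as F

def attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """softmax(Q K^T / sqrt(d_k)) V for tensors shaped (..., seq_len, d_k)."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (..., seq_len, seq_len)
    weights = torch.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ V

torch.manual_seed(0)
Q, K, V = (torch.randn(1, 2, 5, 16) for _ in range(3))  # (batch, heads, seq_len, d_k)
out = attention(Q, K, V)
ref = F.scaled_dot_product_attention(Q, K, V)           # PyTorch's built-in version
print(torch.allclose(out, ref, atol=1e-5))
```

The positional encoding we are after has to influence the `scores` line, since that is the only place where tokens interact.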

In all our previous iterations, we've generated a separate positional encoding vector and added it to our token embedding prior to our $Q$, $K$ and $V$ projections. By adding the positional information directly to our token embedding, we are polluting the semantic information with the positional information. We should instead try to encode position without modifying the norm of the vector. Shifting from additive to multiplicative encoding is the key.

Think of attention as a dictionary lookup: when looking up a word (query) in our dictionary (keys), nearby words should have more influence than distant ones. The influence of one token upon another is determined by the $QK^T$ dot product, so that's exactly where we should focus our positional encoding!

$$\vec{a} \cdot \vec{b} = |\vec{a}| |\vec{b}| \cos \theta$$

The geometric interpretation of the dot product shown above (where $\theta$ is the angle between the two vectors, not the base wavelength from earlier) gives us a key insight. We can modulate the result of the dot product of two vectors purely by increasing or decreasing the angle between them. Furthermore, rotating a vector has absolutely zero impact on its norm, which encodes the semantic information of our token.
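Here is a tiny numerical illustration of both claims, using an arbitrary pair of 2D vectors chosen only for illustration. Rotating $\vec{b}$ leaves its norm at $\sqrt{10}$, while its dot product with $\vec{a}$ changes with the angle.

```python
import math
import torch

def rotate_2d(v: torch.Tensor, angle: float) -> torch.Tensor:
    """Rotate a 2D vector counterclockwise by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return torch.tensor([[c, -s], [s, c]], dtype=v.dtype) @ v

a = torch.tensor([1.0, 2.0], dtype=torch.float64)
b = torch.tensor([3.0, -1.0], dtype=torch.float64)

norms, dots = [], []
for angle in (0.0, math.pi / 4, math.pi / 2, math.pi):
    b_rot = rotate_2d(b, angle)
    norms.append(b_rot.norm().item())
    dots.append((a @ b_rot).item())
    print(f"angle={angle:.3f}  |b_rot|={norms[-1]:.4f}  a.b_rot={dots[-1]:+.4f}")
```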

So now that we know where to focus our attention, and have seen from another angle why a rotation might be a sensible "channel" in which to encode our positional information, let's put it all together!

Rotary Positional Encoding

Rotary Positional Encoding or RoPE was defined in the RoFormer paper (Jianlin Su designed it independently on his blog here and here). While it may seem like voodoo if you skip to the end result, by thinking about Sinusoidal Encoding in the context of self-attention (and more specifically dot products), we can see how it all comes together.

Much like in Sinusoidal Encoding, we decompose our vectors $\mathbf{q}$ or $\mathbf{k}$ (instead of the pre-projection $\mathbf{x}$) into 2D pairs. Rather than encoding absolute position directly by adding a vector drawn from sinusoidal functions of slowly decreasing frequencies, we cut to the chase and encode relative position by multiplying each pair by a rotation matrix.

Let $\mathbf{q}$ or $\mathbf{k}$ be our input vector at position $p$. We create a block diagonal matrix where $\mathbf{M_i}$ is the rotation matrix for the $i$-th component pair:

$$R(\mathbf{q}, p) = \begin{pmatrix} \mathbf{M_1} & & & \\ & \mathbf{M_2} & & \\ & & \ddots & \\ & & & \mathbf{M_{d/2}} \end{pmatrix} \begin{pmatrix} q_1 \\ q_2 \\ \vdots \\ q_d \end{pmatrix}$$

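Much the same as Sinusoidal Encoding, $\mathbf{M_i}$ is simply a 2D rotation, now by the angle $\omega_i p$. Here it is written in the convention of the RoFormer paper, which is the transpose of $\mathbf{M_k}$ derived earlier: the same rotation in the opposite direction. Either direction works as long as $\mathbf{q}$ and $\mathbf{k}$ are rotated the same way.

$$\mathbf{M_i} = \begin{bmatrix} \cos(\omega_i p) & -\sin(\omega_i p) \\ \sin(\omega_i p) & \cos(\omega_i p) \end{bmatrix}$$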

In practice, we don't use a matrix multiplication to compute RoPE as it would be computationally inefficient with such a sparse matrix. Instead, we can directly apply the rotations to pairs of elements independently, taking advantage of the regular pattern in the computation:

$$R(\mathbf{q}, p) = \begin{pmatrix} q_1 \\ q_2 \\ q_3 \\ q_4 \\ \vdots \\ q_{d-1} \\ q_d \end{pmatrix} \otimes \begin{pmatrix} \cos p\omega_1 \\ \cos p\omega_1 \\ \cos p\omega_2 \\ \cos p\omega_2 \\ \vdots \\ \cos p\omega_{d/2} \\ \cos p\omega_{d/2} \end{pmatrix} + \begin{pmatrix} -q_2 \\ q_1 \\ -q_4 \\ q_3 \\ \vdots \\ -q_d \\ q_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \sin p\omega_1 \\ \sin p\omega_1 \\ \sin p\omega_2 \\ \sin p\omega_2 \\ \vdots \\ \sin p\omega_{d/2} \\ \sin p\omega_{d/2} \end{pmatrix}$$

where $\otimes$ denotes element-wise multiplication.
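Here is a minimal PyTorch sketch of the element-wise form. It assumes an even dimension $d$, adjacent components paired as $(q_1, q_2), (q_3, q_4), \dots$, and the base $\theta = 10{,}000$ from Sinusoidal Encoding. The names `rope` and `rope_matrix` are hypothetical. As a sanity check, the code compares the element-wise form against an explicit block-diagonal matrix built from 2×2 blocks $\begin{bmatrix}\cos & -\sin \\ \sin & \cos\end{bmatrix}$, which is the rotation direction the element-wise formula implies. It also confirms that the score between a rotated query and key depends only on their relative offset.

```python
import math
import torch

def rope(x: torch.Tensor, p: float, theta: float = 10_000.0) -> torch.Tensor:
    """Element-wise RoPE: x * cos(p * omega) + (-x2, x1, -x4, x3, ...) * sin(p * omega)."""
    d = x.shape[-1]
    omega = 1.0 / theta ** (torch.arange(0, d, 2, dtype=x.dtype) / d)  # (d/2,) frequencies
    cos = torch.repeat_interleave(torch.cos(p * omega), 2)              # cos w1, cos w1, cos w2, ...
    sin = torch.repeat_interleave(torch.sin(p * omega), 2)
    pairs = x.reshape(*x.shape[:-1], d // 2, 2)
    # (-x2, x1, -x4, x3, ...): swap each pair and negate its new first element
    swapped = torch.stack((-pairs[..., 1], pairs[..., 0]), dim=-1).flatten(-2)
    return x * cos + swapped * sin

def rope_matrix(d: int, p: float, theta: float = 10_000.0) -> torch.Tensor:
    """Explicit (sparse, inefficient) block-diagonal rotation matrix, for comparison only."""
    blocks = []
    for i in range(d // 2):
        w = 1.0 / theta ** (2 * i / d)
        c, s = math.cos(p * w), math.sin(p * w)
        blocks.append(torch.tensor([[c, -s], [s, c]], dtype=torch.float64))
    return torch.block_diag(*blocks)

torch.manual_seed(0)
d = 64
q = torch.randn(d, dtype=torch.float64)
k = torch.randn(d, dtype=torch.float64)

# 1) The element-wise form matches the block-diagonal matrix product
same_as_matrix = torch.allclose(rope(q, 7), rope_matrix(d, 7) @ q)
# 2) The score depends only on the offset between positions (3 in both cases)
s1 = rope(q, 10) @ rope(k, 7)
s2 = rope(q, 103) @ rope(k, 100)
# 3) Rotation leaves the vector norm unchanged
norm_kept = torch.allclose(rope(q, 42).norm(), q.norm())
print(same_as_matrix, torch.allclose(s1, s2), norm_kept)  # True True True
```

Real implementations typically precompute the cos/sin tables once and apply them to whole `(batch, heads, seq_len, d)` tensors. Some, such as the Hugging Face Llama code, pair component $j$ with component $j + d/2$ (a "rotate half" layout) instead of adjacent components. That layout is a fixed permutation of the same scheme.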

That's all there is to it! By artfully applying our rotations to 2D chunks of $\mathbf{q}$ and $\mathbf{k}$ prior to their dot product, and switching from additive to multiplicative encoding, we can get a substantial performance boost in evaluations [^4].

Extending RoPE to $n$-Dimensions

We've explored the $1D$ case for RoPE and by this point I hope you've gained an intuitive understanding of an admittedly unintuitive component of transformers. Finally, let's explore extending it to higher dimensions, such as images.

A natural first intuition could be to directly use the $\begin{bmatrix} x \\ y \end{bmatrix}$ coordinate pairs from the image. This might seem intuitive; after all, we were almost arbitrarily pairing up our components previously. However, this would be a mistake!

In the $1D$ case, we encode the relative position $m - n$ through a rotation of pairs of values from our input vector. For $2D$ data, we need to encode both horizontal and vertical relative positions, say $m - n$ and $i - j$, independently. RoPE's brilliance lies in how it handles multiple dimensions. Instead of trying to encode all positional information in a single rotation, we pair components within the same dimension and rotate each pair by the position along that dimension only; otherwise we would be intermixing the $x$ and $y$ offset information. By handling each dimension independently, we maintain the natural structure of the space. This can generalize to as many dimensions as required!
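One way to realize this is sketched here under the assumption of an "axial" split: the first half of the components is rotated by the column index $x$, and the second half by the row index $y$. This reuses the 1D rotation unchanged. The function names are hypothetical, and the half/half split is a design choice, not something specified above. The check shows that the score depends only on $(\Delta x, \Delta y)$, and that swapping the horizontal and vertical offsets gives a different score.

```python
import torch

def rope_1d(x: torch.Tensor, p: float, theta: float = 10_000.0) -> torch.Tensor:
    """1D RoPE on adjacent pairs (x1, x2), (x3, x4), ... at position p."""
    d = x.shape[-1]
    omega = 1.0 / theta ** (torch.arange(0, d, 2, dtype=x.dtype) / d)
    cos = torch.repeat_interleave(torch.cos(p * omega), 2)
    sin = torch.repeat_interleave(torch.sin(p * omega), 2)
    pairs = x.reshape(*x.shape[:-1], d // 2, 2)
    swapped = torch.stack((-pairs[..., 1], pairs[..., 0]), dim=-1).flatten(-2)
    return x * cos + swapped * sin

def rope_2d(x: torch.Tensor, px: float, py: float, theta: float = 10_000.0) -> torch.Tensor:
    """Hypothetical axial 2D RoPE: first half rotated by px, second half by py (d divisible by 4)."""
    half = x.shape[-1] // 2
    return torch.cat((rope_1d(x[..., :half], px, theta), rope_1d(x[..., half:], py, theta)), dim=-1)

torch.manual_seed(0)
q = torch.randn(64, dtype=torch.float64)
k = torch.randn(64, dtype=torch.float64)

# Same (dx, dy) = (2, 3) offset at two different places in the image -> same score
s_a = rope_2d(q, 3, 5) @ rope_2d(k, 1, 2)
s_b = rope_2d(q, 13, 25) @ rope_2d(k, 11, 22)
# Swapped offsets (dx, dy) = (3, 2) -> generally a different score
s_c = rope_2d(q, 3, 5) @ rope_2d(k, 0, 3)
print(torch.allclose(s_a, s_b), torch.allclose(s_a, s_c))
```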

The future of positional encoding

Is RoPE the final incarnation of positional encoding? This recent paper from DeepMind analyses RoPE in depth and highlights some fundamental problems. TL;DR: RoPE isn't a perfect solution. Models mostly focus on the lower frequencies, and removing the rotation for a certain percentage of the lowest frequencies improves performance on Gemma 2B!

I anticipate some future breakthroughs, perhaps taking inspiration from signal processing with ideas like wavelets or hierarchical implementations. As models are increasingly quantized for deployment, I'd also expect to see some innovation in encoding schemes that remain robust under low-precision arithmetic.

Conclusion

Positional encoding has been, and continues to be, treated as an afterthought in transformers. I believe we should view it differently: self-attention has an Achilles' heel that has been repeatedly patched.

I hope this blog post showed you that you too could have discovered state-of-the-art positional encoding, despite it being unintuitive at first. In a follow-up post I'd love to explore practical implementation details for RoPE in order to maximise performance.

This post was originally published here.

References

[^1]: Binary and Sinusoidal animations are reproductions of animations contained in this video.

[^2]: Using $\theta = 10000$ gives us $2\pi \cdot 10000$ unique positions, or a theoretical upper bound on the context length of ~63,000.

[^3]: Pieces of this post are based on this fantastic post by Amirhossein Kazemnejad.

[^4]: For empirical evidence, see this great post by EleutherAI.
