arxiv:2410.05258

Differential Transformer

Published on Oct 7
· Submitted by unilm on Oct 8
#1 Paper of the day

Abstract

Transformer tends to overallocate attention to irrelevant context. In this work, we introduce Diff Transformer, which amplifies attention to the relevant context while canceling noise. Specifically, the differential attention mechanism calculates attention scores as the difference between two separate softmax attention maps. The subtraction cancels noise, promoting the emergence of sparse attention patterns. Experimental results on language modeling show that Diff Transformer outperforms Transformer in various settings of scaling up model size and training tokens. More intriguingly, it offers notable advantages in practical applications, such as long-context modeling, key information retrieval, hallucination mitigation, in-context learning, and reduction of activation outliers. By being less distracted by irrelevant context, Diff Transformer can mitigate hallucination in question answering and text summarization. For in-context learning, Diff Transformer not only enhances accuracy but is also more robust to order permutation, which was considered as a chronic robustness issue. The results position Diff Transformer as a highly effective and promising architecture to advance large language models.

Community


Does the differential transformer get rid of the attention sink?

·
Paper author

We observe that Diff Transformer allocates lower attention scores to attention sinks, i.e., the first few tokens in the sequence.
Specifically, in the language modeling task, Diff Transformer allocates less than 5% of the attention scores to the BOS token, while Transformer allocates about 25%. For the key information retrieval task, please refer to Figure 1 in the paper. We find that models attend to the BOS token more when there is less useful information in the context.

Great stuff. I would love to see comparisons against MöbiusAttention, which learns to forget... but this seems way more computationally efficient.

·
Paper author

Thanks for pointing out this paper. We will look into it.

It is a neat approach, but one that comes with a tradeoff, IIUC: doubling the key heads.

I wonder if a different approach without that issue exists. For instance, using max(0, exp(x)-1) instead of exp(x) in the softmax attention formula. That way when the query is orthogonal to the key (or worse), it does not contribute.
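The alternative floated in this comment can be sketched in a few lines. This is only an illustration of the commenter's idea, not something evaluated in the paper; the function name `shifted_exp_weights` and the all-zeros fallback when no score is positive are assumptions made for this example, and no numerical stabilization is applied.

```python
import math

def shifted_exp_weights(scores):
    """Normalize max(0, exp(x) - 1) over the scores instead of exp(x)."""
    # Scores <= 0 (orthogonal or opposed query/key) map to exactly zero.
    raw = [max(0.0, math.expm1(s)) for s in scores]
    total = sum(raw)
    if total == 0.0:
        # Assumption for this sketch: no positive score means no attention at all.
        return [0.0 for _ in scores]
    return [r / total for r in raw]

weights = shifted_exp_weights([1.0, 0.0, -2.0])
no_match = shifted_exp_weights([-1.0, -3.0])
```

Unlike softmax, the weights here can be exactly zero, and they need not sum to one when every score is non-positive.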

·
Paper author

In Diff Transformer, we split heads instead of doubling heads. No extra QK projection parameters are introduced. The heads of Q and K are split into two groups and computed in pairs. Within a pair, they share the same V with dimension 2d. With this design, we match the FLOPs and parameter count of Transformer.
Using max(0, exp(x)-1) might be an approach that solves the problem. We didn't try this because we believe the properties of exp() are important for learning.
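As a rough illustration of the split-head pairing described in this reply, here is a minimal pure-Python sketch of a single differential attention pair. It assumes a fixed scalar `lam`, no causal mask, and no normalization or multi-head plumbing; the helper names (`attention_map`, `diff_attention`) are invented for this example and are not taken from the authors' code.

```python
import math

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def attention_map(q, k):
    """Row-wise softmax(q k^T / sqrt(d)) for one group of heads."""
    d = len(q[0])
    k_t = [list(col) for col in zip(*k)]
    return [softmax([s / math.sqrt(d) for s in row]) for row in matmul(q, k_t)]

def diff_attention(q1, k1, q2, k2, v, lam):
    """Difference of two softmax maps applied to one shared value matrix."""
    a1 = attention_map(q1, k1)
    a2 = attention_map(q2, k2)
    weights = [[x - lam * y for x, y in zip(r1, r2)] for r1, r2 in zip(a1, a2)]
    return matmul(weights, v), weights

# Two tokens; each Q/K group has dimension d = 2, the shared V has dimension 2d = 4.
q1 = [[1.0, 0.0], [0.0, 1.0]]
k1 = [[1.0, 0.0], [0.0, 1.0]]
q2 = [[0.5, 0.5], [0.5, 0.5]]
k2 = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
out, weights = diff_attention(q1, k1, q2, k2, v, lam=0.5)
```

In this sketch each row of `weights` sums to `1 - lam` rather than 1, because both softmax maps are individually normalized before the subtraction.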

Great work! I just wonder whether you have any idea why the two learned attention maps tend to cancel noise rather than signal. For instance, if attention 1 learns S + N_1 and attention 2 learns S + N_2 (where S is signal and N_1, N_2 are different noises), subtracting the two cancels the signal S, while the noise becomes N_1 - N_2, which could be more complicated. Is there any reason why the model would not do this instead?

·
Paper author
•
edited Oct 8

It's a good question. Our observation is that the model knows what is signal and what is noise. Notice that attention_1 and attention_2 are both calculated with learnable parameters, so they can "perceive" each other during training and adjust to each other to achieve a lower loss. The result is that the model chooses to preserve signal and cancel out noise as long as we give it the chance to do so. For a single softmax, it is difficult to learn the same solution because of its formulation and gradient properties.


Exciting work! Do the authors plan to release model weights on Hugging Face?

·
Paper author

No, we won't release model weights on Hugging Face.

What if it is applied to accelerated linear attention variants?

·
Paper author

That's an interesting problem to explore. We haven't tried that yet. We will look into it in the future.

Can you provide an intuition why \lambda is re-parameterized as the form shown in the paper?

·
Paper author

Sure. lambda is multiplied with a softmax, where softmax = exp(qk) / Sigma(exp(qk)). The parameters in lambda learn at the same rate as the other parameters in the model, so lambda should take a similar form to softmax. That's why lambda = exp(lambda_q * lambda_k) + lambda_init. Moreover, to enable lambda to learn values smaller than lambda_init, we add a second term, i.e., lambda = exp(lambda_q1 * lambda_k1) - exp(lambda_q2 * lambda_k2) + lambda_init.
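The final formula in this reply can be written out as a small function. This sketch assumes the lambda_q and lambda_k terms are vectors combined by a dot product; the reply itself only writes a product, so treat that as an assumption, along with the function name `reparam_lambda` and the example value 0.8 for lambda_init.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def reparam_lambda(lq1, lk1, lq2, lk2, lambda_init):
    """lambda = exp(lq1 . lk1) - exp(lq2 . lk2) + lambda_init"""
    return math.exp(dot(lq1, lk1)) - math.exp(dot(lq2, lk2)) + lambda_init

# With all-zero parameters both exponentials equal 1, so lambda starts at lambda_init.
at_init = reparam_lambda([0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], 0.8)

# The second exponential lets lambda fall below lambda_init.
below_init = reparam_lambda([0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [1.0, 0.0], 0.8)
```

Without the subtracted term, lambda could never drop below lambda_init, since exp() is always positive.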

What kind of hardware was required to train this, and how did the tokens per second output compare with transformers?

·
Paper author

There are no special hardware requirements if you use the naive implementation. If you use flashdiff, refer to the FlashAttention repo (https://github.com/Dao-AILab/flash-attention) for hardware and datatype requirements.
Our speed test was performed on Nvidia H100-80GB GPU cards, and we measured throughput (tokens per second). The same cards and environment were used for both Diff Transformer and Transformer.

The work looks exciting and I really like the motivation coming from noise cancellation!
I have a few questions -

  1. Won't this model let the post-attention weight (softmax(...) - \lambda * softmax(...)) for some value vectors be negative? Is that a design choice? One explanation does come to mind i.e. wanting to get opposing contributions from some tokens specifically but I am unsure if this is desired.

  2. This recent work (https://arxiv.org/pdf/2410.01104) shows that attention will disperse given a few conditions (see Lemma 2.1, Page 3). Do you think differential attention is any different? If I understand the proposal correctly, I think it still satisfies Lemma 2.1 with some minor modifications in the proof.

Thanks again for your wonderful work!

·
Paper author
•
edited Oct 10

  1. Yes, there are some negative values in the post-subtraction weights, and that's what we want. The design expands the representation space of the attention weights, which improves modeling capability. The model is free to allocate positive or negative values to tokens.

  2. If I understand correctly, Diff can break the property in Lemma 2.1. In the paper, Equation 4 points out that the values of a single softmax output have a positive lower bound, as the input logits can't reach negative infinity. However, by taking the difference of two softmax outputs, the output range includes 0, which means the attention weights are no longer O(1/n). This breaks Lemma 2.1. Diff can generate 0 as an attention value if it wants and assign it to unwanted context, while leaving almost all attention for key information.
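Both answers can be checked on a toy example. The logits and lambda values below are hand-picked for illustration and do not come from the paper; the point is only that a single softmax stays strictly positive, while a difference of two softmax outputs can reach exactly zero or go negative.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

a1 = softmax([math.log(4.0), 0.0, 0.0])  # about [2/3, 1/6, 1/6]
a2 = softmax([0.0, 0.0, 0.0])            # [1/3, 1/3, 1/3]

# With lambda = 0.5 the two distractor tokens cancel to (numerically) zero.
zeroed = [x - 0.5 * y for x, y in zip(a1, a2)]

# With a larger lambda the same tokens receive negative weight.
negative = [x - 0.8 * y for x, y in zip(a1, a2)]
```

Here `zeroed` is approximately `[0.5, 0, 0]`: all remaining weight sits on the first token, which neither `a1` nor `a2` could achieve alone with finite logits.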

Here's my summary of this paper:


⚡️ Most important breakthrough this month: Differential Transformer vastly improves attention ⇒ better retrieval and fewer hallucinations!

Thought that self-attention could not be improved anymore?

Microsoft researchers have dropped a novel "differential attention" mechanism that amplifies focus on relevant context while canceling out noise. It sounds like a free lunch, but it does really seem to vastly improve LLM performance!

Key insights:

🧠 Differential attention computes the difference between two separate softmax attention maps, canceling out noise and promoting sparse attention patterns

🔥 DIFF Transformer outperforms standard Transformers while using 35-40% fewer parameters or training tokens

📏 Scales well to long contexts up to 64K tokens, leveraging increasing context length more effectively

🔎 Dramatically improves key information retrieval, enhancing in-context learning, and possibly reducing risk of hallucinations 🤯

🔢 Reduces activation outliers, potentially enabling lower-bit quantization without performance drop!

⚙️ Can be directly implemented using existing FlashAttention kernels

This new architecture could lead to much more capable LLMs, with vastly improved strengths in long-context understanding and factual accuracy.

But they didn’t release weights on the Hub: let’s wait for the community to train the first open-weights DiffTransformer! 🚀

This paper was a great read. We wrote a summary blog covering this paper and a few others:

  1. TPI LLM
  2. Differential Transformer
  3. ARIA

You can find it here. Please give it a read :)

I have just one burning question about this paper: is this architecture compatible with the attention mechanism described in "Selective Attention Improves Transformer"?

·
Paper author

Hi, we haven't tried combining them. Diff Transformer and Selective Attention are proposed from different perspectives and solve different problems. I believe they are compatible.

The proposed approach sounds intriguing. Thanks for your work!

Can you provide any intuition and/or theoretical justification on why vanilla softmax attention fails to deal with noisy tokens in a proper way? Where are the weaknesses in its structure that prevent ignoring irrelevant tokens in a sequence and concentrating on the essential ones?

·
Paper author

Hi, you can refer to a recent paper, "softmax is not enough (for sharp out-of-distribution)" (https://arxiv.org/abs/2410.01104).
In simple terms: 1. Softmax can't produce zero scores, by definition; 2. Producing near-zero scores requires a wide input range, which harms backpropagation through softmax. That's why the model fails to cancel out irrelevant tokens.
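The two points in this reply can be made concrete with a small calculation. This is a back-of-the-envelope sketch, not a result from either paper: it assumes logits are confined to a symmetric range [-bound, bound], and both function names are made up for this example.

```python
import math

def min_softmax_weight(n, bound):
    """Smallest weight softmax can give one of n tokens with logits in [-bound, bound].

    The minimum occurs when that token sits at -bound and the other n - 1 sit at +bound.
    """
    return 1.0 / (1.0 + (n - 1) * math.exp(2.0 * bound))

def gap_for_weight(eps):
    """Logit gap a 2-token softmax needs to push one weight down to eps."""
    return math.log((1.0 - eps) / eps)

floor_small_range = min_softmax_weight(n=1000, bound=1.0)
floor_wide_range = min_softmax_weight(n=1000, bound=5.0)
gap_for_one_in_a_million = gap_for_weight(1e-6)
```

The floor is always positive for finite logits (point 1), and pushing it toward zero requires the logit range, and hence the gap between scores, to grow (point 2).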
