KV Caching Explained: Optimizing Transformer Inference Efficiency

Community Article Published January 30, 2025

Introduction

When AI models generate text, they repeat many of the same calculations at every step, which slows things down. Key-Value (KV) caching is a technique that speeds up this process by storing the keys and values that the attention layers computed for previous tokens. Instead of recomputing them from scratch, the model reuses what it has already calculated, making text generation much faster and more efficient.

In this blog post, we’ll break down KV caching in an easy-to-understand way, explain why it’s useful, and show how it helps AI models work faster.

Prerequisites

To fully grasp the content, readers should be familiar with:

  1. Transformer Architecture: Familiarity with components such as the attention mechanism.
  2. Autoregressive Modeling: Understanding of how models like GPT generate sequences.
  3. Linear Algebra Basics: Concepts like matrix multiplication and transposition, which are essential for understanding attention computations.
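For reference, the computation that KV caching accelerates is standard scaled dot-product attention, where matrix multiplication and transposition appear directly. $M$ is the causal mask used by autoregressive models, and the second line shows the single row needed when generating token $t$ with a cache (the mask drops out because row $t$ may attend to every position $1..t$):

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M\right) V,
\qquad
M_{ij} = \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases}

o_t = \mathrm{softmax}\!\left(\frac{q_t K_{1:t}^{\top}}{\sqrt{d_k}}\right) V_{1:t}
```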

This 👉 blog should cover most of the prerequisites needed for this article.

Some of the most essential takeaways:
  • the attention weights have a shape of $[\text{batch}, h, \mathrm{Seq}_{\mathrm{len}}, \mathrm{Seq}_{\mathrm{len}}]$
  • masked multi-head attention allows each token to be represented by itself and all the previous tokens.
  • to generate a new token, the model needs to look at all the previous tokens and their representations, each of which is built from its own preceding tokens.
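To make the first two takeaways concrete, here is a minimal PyTorch sketch using random tensors and hypothetical sizes. It computes masked attention weights, shows that they have shape $[\text{batch}, h, \mathrm{Seq}_{\mathrm{len}}, \mathrm{Seq}_{\mathrm{len}}]$, and applies a causal mask so that no token attends to a later position.

```python
import torch

torch.manual_seed(0)
batch, h, seq_len, d_k = 2, 4, 6, 8  # hypothetical sizes for illustration

q = torch.randn(batch, h, seq_len, d_k)
k = torch.randn(batch, h, seq_len, d_k)

# Scaled dot-product scores: one score per (query position, key position) pair.
scores = q @ k.transpose(-2, -1) / d_k**0.5  # [batch, h, seq_len, seq_len]

# Causal mask: True above the diagonal, i.e. where a token would look at a later token.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float("-inf"))

# Each row is a distribution over the token itself and all previous tokens.
weights = torch.softmax(scores, dim=-1)
print(weights.shape)  # torch.Size([2, 4, 6, 6])
```

Every entry above the diagonal of each $\mathrm{Seq}_{\mathrm{len}} \times \mathrm{Seq}_{\mathrm{len}}$ slice is exactly zero, which is why a token's representation depends only on itself and the tokens before it.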

Standard Inference and the Rise of KV Caching

When a model generates text, it looks at all the previous tokens to predict the next one. Without caching, it recomputes the keys and values of every previous token at each step, repeating the same calculations and slowing generation down.

KV caching removes this overlap by remembering the calculations from previous steps: during inference, it stores the intermediate states of the attention layers (the keys and values) and reuses them for every later token.
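A back-of-the-envelope count shows where the redundancy comes from. This sketch uses a deliberately simplified cost model (an assumption, not a measurement): it counts only key/value projections for one layer and ignores attention scores, MLPs, and the prompt length.

```python
def kv_projections(n_tokens: int, use_cache: bool) -> int:
    """Count key/value projection computations for one layer (simplified model).

    Without a cache, step t recomputes K and V for all t tokens seen so far.
    With a cache, step t only computes K and V for the one new token.
    """
    if use_cache:
        return n_tokens  # each token is projected exactly once
    return sum(t for t in range(1, n_tokens + 1))  # 1 + 2 + ... + n


for n in (10, 100, 1000):
    no_cache = kv_projections(n, use_cache=False)
    cached = kv_projections(n, use_cache=True)
    print(f"n={n:>4}: without cache={no_cache:>7}, with cache={cached:>5}")
```

Under this model the counts are 55 vs 10, 5050 vs 100, and 500500 vs 1000: the uncached cost grows quadratically in the number of generated tokens, the cached cost linearly.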

How Does KV Caching Work?

Step-by-Step Process

  1. First Generation: When the model sees the first input, it calculates its keys and values and stores them in the cache.
  2. Next Words: For each new word, the model retrieves the stored keys and values and appends the new token's key and value instead of starting over.
  3. Efficient Attention Computation: The model calculates attention using the cached $K$ and $V$ along with the new token's $Q$ (query) to compute the output.
  4. Update Input: The model adds the newly generated token to the input and goes back to step 2 until generation is finished.
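The four steps above can be sketched for a single attention head. This is a toy example under explicit assumptions: the projection matrices `W_q`, `W_k`, `W_v` and the token embeddings `x` are random stand-ins for a real model, and the "generated" tokens are simply the next rows of `x`. The sketch checks that attending with the cached $K$ and $V$ plus only the new query gives the same output as recomputing causal attention over the whole prefix.

```python
import torch

torch.manual_seed(0)
d_model, d_k, n_steps = 16, 8, 5  # hypothetical sizes

# Hypothetical single-head projection weights, scaled to keep scores small.
W_q = torch.randn(d_model, d_k) / d_model**0.5
W_k = torch.randn(d_model, d_k) / d_model**0.5
W_v = torch.randn(d_model, d_k) / d_model**0.5
x = torch.randn(n_steps, d_model)  # stand-in embeddings for tokens 1..n


def attend(q, K, V):
    """Scaled dot-product attention of query rows q over keys K and values V."""
    weights = torch.softmax(q @ K.T / d_k**0.5, dim=-1)
    return weights @ V


K_cache = torch.empty(0, d_k)
V_cache = torch.empty(0, d_k)
cached_outputs, full_outputs = [], []

for t in range(n_steps):
    x_new = x[t : t + 1]  # the newest token, shape [1, d_model]

    # Steps 1-2: compute K and V only for the new token and append them to the cache.
    K_cache = torch.cat([K_cache, x_new @ W_k], dim=0)
    V_cache = torch.cat([V_cache, x_new @ W_v], dim=0)

    # Step 3: only the new token's query is needed.
    cached_outputs.append(attend(x_new @ W_q, K_cache, V_cache))

    # Standard inference for comparison: recompute Q, K, V for the whole prefix
    # with a causal mask, then keep the last row (the newest token's output).
    prefix = x[: t + 1]
    Q, K, V = prefix @ W_q, prefix @ W_k, prefix @ W_v
    mask = torch.triu(torch.ones(t + 1, t + 1, dtype=torch.bool), diagonal=1)
    scores = (Q @ K.T / d_k**0.5).masked_fill(mask, float("-inf"))
    full_outputs.append((torch.softmax(scores, dim=-1) @ V)[-1:])
    # Step 4: the loop moves on to the next token.

for c, f in zip(cached_outputs, full_outputs):
    assert torch.allclose(c, f, atol=1e-5)
print("cached and full-recompute outputs match for all", n_steps, "steps")
```

The cached path does one projection per token and one query row of attention per step, while the comparison path redoes the projections and the full score matrix for the entire prefix every time.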

The process is illustrated below:

Token 1: [K1, V1] ➔ Cache: [K1, V1]
Token 2: [K2, V2] ➔ Cache: [K1, K2], [V1, V2]
...
Token n: [Kn, Vn] ➔ Cache: [K1, K2, ..., Kn], [V1, V2, ..., Vn]
[Figure: KV Caching vs. Standard Inference]

In the figure above we used $d_k = 5$ for better visuals; note that this number is usually significantly larger in real models.
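The growth of the cache in the illustration can also be traced as tensor shapes. In this hypothetical sketch each token contributes one key row and one value row of width $d_k = 5$, and random values stand in for real projections.

```python
import torch

d_k = 5  # same small width as in the illustration; real models use larger values
keys, values = torch.empty(0, d_k), torch.empty(0, d_k)

for token_idx in range(1, 5):
    k_new, v_new = torch.randn(1, d_k), torch.randn(1, d_k)  # K_n, V_n for token n
    keys = torch.cat([keys, k_new], dim=0)  # cache now holds [K1, ..., Kn]
    values = torch.cat([values, v_new], dim=0)  # cache now holds [V1, ..., Vn]
    print(f"Token {token_idx}: cache K {tuple(keys.shape)}, V {tuple(values.shape)}")
```

After token $n$ both cached tensors have shape $(n, d_k)$: one row per token seen so far.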

Comparison: KV Caching vs. Standard Inference

Here’s how KV caching compares to standard inference:

| Feature | Standard Inference | KV Caching |
|---|---|---|
| Computation per Word | The model repeats the same calculations for every word. | The model reuses past calculations for faster results. |
| Memory Usage | Uses less memory at each step, but memory grows with longer texts. | Uses extra memory to store past keys and values; the cache grows linearly with sequence length. |
| Speed | Gets slower as the text gets longer because it repeats work. | Stays much faster on longer texts by avoiding repeated work, although attention over the growing cache still costs more per token. |
| Efficiency | High computational cost and slower response times. | Faster and more efficient since the model remembers past work. |
| Handling Long Texts | Struggles with long texts due to repeated calculations. | Well suited to long texts, as long as the cache fits in memory. |

KV caching makes a big difference in speed and efficiency, especially for long texts. By saving and reusing past calculations, it avoids the need to start over each time, making it much faster than the regular way of generating text.
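The extra memory has a simple closed form. For every layer, every attention head, and every token, the cache holds one key vector and one value vector of size $d_k$, so a rough estimate is $2 \times n_{\text{layers}} \times h \times d_k \times \mathrm{Seq}_{\mathrm{len}} \times \text{batch} \times \text{bytes per element}$. The sketch below evaluates it; the configuration numbers are hypothetical and not those of any particular model, and 2 bytes per element assumes fp16/bf16 storage.

```python
def kv_cache_bytes(n_layers, n_heads, d_k, seq_len, batch, bytes_per_elem=2):
    """Approximate KV cache size in bytes; the leading factor 2 counts keys and values.

    bytes_per_elem=2 assumes fp16/bf16 storage.
    """
    return 2 * n_layers * n_heads * d_k * seq_len * batch * bytes_per_elem


# Hypothetical configuration for illustration only.
size = kv_cache_bytes(n_layers=24, n_heads=32, d_k=64, seq_len=2048, batch=1)
print(f"{size / 2**20:.0f} MiB")  # 384 MiB
```

Because every factor enters linearly, doubling the sequence length or the batch size doubles the cache, which is the memory cost the table above refers to.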

Practical Implementation

This is a simplified example of implementing KV caching in PyTorch:

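# Minimal KV cache in PyTorch (assumes tensors laid out as [batch, seq_len, d])
import torch


class KVCache:
    def __init__(self):
        # Nothing is cached until the first update.
        self.cache = {"key": None, "value": None}

    def update(self, key, value):
        if self.cache["key"] is None:
            # First step: store the keys and values of the initial input.
            self.cache["key"] = key
            self.cache["value"] = value
        else:
            # Later steps: append the new keys/values along the sequence axis (dim=1).
            self.cache["key"] = torch.cat([self.cache["key"], key], dim=1)
            self.cache["value"] = torch.cat([self.cache["value"], value], dim=1)

    def get_cache(self):
        return self.cache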
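One design note on the class above: `torch.cat` allocates a new tensor and copies the whole history at every step. A common alternative, shown here as a hedged sketch, preallocates a buffer of fixed maximum length and writes each new key/value into the next free slot, similar in spirit to a static cache. The `StaticKVCache` class and its arguments are hypothetical illustrations, not a transformers API.

```python
import torch


class StaticKVCache:
    """Hypothetical preallocated KV cache for tensors laid out as [batch, seq_len, d]."""

    def __init__(self, batch, max_len, d, dtype=torch.float32):
        self.keys = torch.zeros(batch, max_len, d, dtype=dtype)
        self.values = torch.zeros(batch, max_len, d, dtype=dtype)
        self.length = 0  # number of positions filled so far

    def update(self, key, value):
        """Write new keys/values of shape [batch, n_new, d] into the next free slots."""
        n_new = key.shape[1]
        end = self.length + n_new
        if end > self.keys.shape[1]:
            raise ValueError("cache is full: increase max_len")
        self.keys[:, self.length:end] = key
        self.values[:, self.length:end] = value
        self.length = end
        # Return views of the filled part only; the history is not copied.
        return self.keys[:, :end], self.values[:, :end]


cache = StaticKVCache(batch=1, max_len=8, d=4)
prompt_k, prompt_v = torch.randn(1, 3, 4), torch.randn(1, 3, 4)
k, v = cache.update(prompt_k, prompt_v)  # e.g. a 3-token prompt
new_k, new_v = torch.randn(1, 1, 4), torch.randn(1, 1, 4)
k, v = cache.update(new_k, new_v)  # one generated token
print(k.shape)  # torch.Size([1, 4, 4])
```

The trade-off is that memory for `max_len` positions is reserved up front, even if generation stops early, in exchange for avoiding a reallocation and copy on every step.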

When using the transformers library, this behavior is enabled by default through the use_cache parameter. You can also choose among several caching methods through the cache_implementation parameter. Here is a minimal example:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('HuggingFaceTB/SmolLM2-1.7B')
model = AutoModelForCausalLM.from_pretrained('HuggingFaceTB/SmolLM2-1.7B').cuda()

tokens = tokenizer.encode("The red cat was", return_tensors="pt").cuda()
output = model.generate(
    tokens, max_new_tokens=300, use_cache=True  # True by default
)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]

We benchmarked the code above with and without KV caching on a T4 GPU and got the following results:

| With KV Caching | Standard Inference | Speedup |
|---|---|---|
| 11.7 s | 1 min 1 s | ~5.21x faster |

Conclusion

KV caching is a simple but powerful technique that helps AI models generate text faster and more efficiently. By remembering past calculations instead of repeating them, it reduces the time and effort needed to predict new words. While it does require extra memory, this method is especially useful for long conversations, where it keeps generation fast and efficient.

Understanding KV caching can help developers and AI enthusiasts build faster, smarter, and more scalable language models for real-world applications.

I would like to extend my sincerest gratitude to Aritra Roy Gosthipaty 🤗 for his invaluable support, feedback, and dedication in developing this blog post.

