
Flex SAE Kernels


Fused Triton implementations of the TopK and HierarchicalTopK sparse autoencoder (SAE) decoder losses described in the paper "Train One Sparse Autoencoder Across Multiple Sparsity Budgets to Preserve Interpretability and Accuracy".

This work has been accepted to EMNLP 2025.
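For reference, the hierarchical loss computed by the torch reference in example.py can be written as follows. Here $W_i$ is the $i$-th row of the decoder weight ($F \times D$), $\mathbf{b}$ is the bias, $i_{n,j}$ and $v_{n,j}$ are the index and value of the $j$-th selected latent for batch row $n$, and $x_n$ is the target. The TopK line is our assumption (the plain mean-squared error of the full $K$-term reconstruction); the exact normalization used by triton_topk_sae_loss is not stated in this card.

```latex
\hat{x}_n^{(k)} = \mathbf{b} + \sum_{j=1}^{k} v_{n,j}\, W_{i_{n,j}}, \qquad k = 1, \dots, K
\mathcal{L}_{\mathrm{HierarchicalTopK}} = \frac{1}{BKD} \sum_{n=1}^{B} \sum_{k=1}^{K} \bigl\lVert \hat{x}_n^{(k)} - x_n \bigr\rVert_2^2
\mathcal{L}_{\mathrm{TopK}} = \frac{1}{BD} \sum_{n=1}^{B} \bigl\lVert \hat{x}_n^{(K)} - x_n \bigr\rVert_2^2
```

The hierarchical objective penalizes every prefix $k \le K$ of the reconstruction, not only the full one, which is what lets a single SAE be evaluated at several sparsity budgets.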

What is released?

  • A fast TopK SAE kernel (a slightly modified version of the one in xformers): torch-ext/flex_sae/topk_kernels.py
  • Fast HierarchicalTopK kernels (see our paper): torch-ext/flex_sae/hierarchical_kernels.py

Quickstart

The kernels can be loaded from the Hugging Face Hub and have the following signatures:

from kernels import get_kernel

flex = get_kernel("t-tech/flex-sae")

top_k_kernel = flex.triton_topk_sae_loss
hierarchical_top_k_kernel = flex.triton_hierarchical_sae_loss

# B -- batch size, K -- top-k, F -- dictionary size, D -- model hidden dim
# Both kernels take the same arguments and return a scalar loss tensor.
loss = top_k_kernel(
    indices,  # [B, K], torch.long indices of the selected latents
    weight,   # [F, D], decoder dictionary
    vals,     # [B, K], values of the selected latents
    bias,     # [D], decoder bias
    target,   # [B, D], reconstruction target
)

loss = hierarchical_top_k_kernel(indices, weight, vals, bias, target)  # same shapes as above
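To make the argument shapes concrete, here is a small NumPy sketch of what the two losses compute on CPU. The hierarchical function is a port of hierarchical_sae_loss from example.py; the TopK function is an assumption (plain MSE of the full K-term reconstruction), and both function names are hypothetical, not part of the flex-sae API.

```python
import numpy as np


def topk_sae_loss_ref(indices, weight, vals, bias, target):
    """Hypothetical TopK reference (assumed normalization): MSE of the full reconstruction."""
    emb = weight[indices]                                  # [B, K, D] gathered dictionary rows
    recon = bias + (emb * vals[..., None]).sum(axis=1)     # [B, D]
    return np.mean((recon - target) ** 2)


def hierarchical_sae_loss_ref(indices, weight, vals, bias, target):
    """NumPy port of hierarchical_sae_loss in example.py: MSE averaged over all K prefixes."""
    emb = weight[indices]                                  # [B, K, D]
    recon_cum = bias + np.cumsum(emb * vals[..., None], axis=1)  # [B, K, D] prefix reconstructions
    diff = recon_cum - target[:, None, :]                  # broadcast the target over the K prefixes
    return np.mean(diff ** 2)


rng = np.random.default_rng(0)
B, K, F, D = 4, 8, 32, 16
indices = rng.integers(0, F, size=(B, K))
weight = rng.standard_normal((F, D))
vals = np.abs(rng.standard_normal((B, K)))
bias = rng.standard_normal(D)
target = rng.standard_normal((B, D))

print(topk_sae_loss_ref(indices, weight, vals, bias, target))
print(hierarchical_sae_loss_ref(indices, weight, vals, bias, target))
```

Under the assumed normalization, the hierarchical loss equals the TopK loss averaged over every prefix length k = 1..K, and the two coincide when K = 1.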

Overview

  • torch-ext/flex_sae/ contains the Triton kernels alongside torch reference implementations.
  • tests/ hosts CUDA-backed property tests that ensure numerical parity across dtypes and kernels.
  • build.toml and flake.nix integrate the project with the Hugging Face kernel-builder.

The Triton kernels target CUDA GPUs. They aim to close the latency gap between the TopK and HierarchicalTopK decoders while keeping peak memory flat as K grows, instead of letting it scale with K as in the torch reference.
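The torch reference in example.py builds a [B, K, D] tensor of prefix reconstructions, so its memory grows with K. The sketch below illustrates, under the assumption that only the loss value is needed, how the same number can be obtained while holding just a running [B, D] reconstruction. It is not the algorithm of the Triton kernels (which are fused on the GPU and also compute gradients), and all function names are hypothetical.

```python
import numpy as np


def hierarchical_loss_materialized(indices, weight, vals, bias, target):
    """Mirrors the example.py reference: builds the full [B, K, D] prefix tensor."""
    recon_cum = bias + np.cumsum(weight[indices] * vals[..., None], axis=1)  # [B, K, D]
    return np.mean((recon_cum - target[:, None, :]) ** 2)


def hierarchical_loss_streaming(indices, weight, vals, bias, target):
    """Hypothetical memory-flat variant: only a running [B, D] reconstruction is kept."""
    B, K = indices.shape
    D = weight.shape[1]
    recon = np.tile(bias, (B, 1))  # empty prefix: just the bias, [B, D]
    total = 0.0
    for k in range(K):
        recon = recon + vals[:, k, None] * weight[indices[:, k]]  # add the k-th selected latent
        total += np.sum((recon - target) ** 2)                   # squared error of this prefix
    return total / (B * K * D)  # same normalization as the mean over [B, K, D]


def tensor_gib(*shape, bytes_per_elem=4):
    """Size in GiB of a dense tensor (float32 by default)."""
    return int(np.prod(shape, dtype=np.int64)) * bytes_per_elem / 2**30


rng = np.random.default_rng(0)
B, K, F, D = 8, 16, 64, 32
indices = rng.integers(0, F, size=(B, K))
weight = rng.standard_normal((F, D))
vals = np.abs(rng.standard_normal((B, K)))
bias = rng.standard_normal(D)
target = rng.standard_normal((B, D))

print(hierarchical_loss_materialized(indices, weight, vals, bias, target))
print(hierarchical_loss_streaming(indices, weight, vals, bias, target))
# One [B, K, D] float32 tensor at example.py's sizes vs. one [B, D] running buffer
print(tensor_gib(2048, 256, 1024), tensor_gib(2048, 1024))
```

At example.py's sizes (B = 2048, K = 256, D = 1024), a single float32 [B, K, D] tensor takes 2 GiB, while the running [B, D] buffer takes 8 MiB regardless of K.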

Example

You can find example usage in example.py.

# /// script
# dependencies = [
#   "torch",
#   "numpy",
#   "kernels",
# ]
# ///

import torch
import numpy as np
from kernels import get_kernel

flex = get_kernel("t-tech/flex-sae")  # fused Triton kernels from the Hub

@torch.compile(fullgraph=True)
def hierarchical_sae_loss(
    indices: torch.Tensor,  # [B, K]
    weight: torch.Tensor,  # [F, D]
    vals: torch.Tensor,  # [B, K]
    bias: torch.Tensor,  # [D]
    target: torch.Tensor,  # [B, D]
) -> torch.Tensor:
    emb = weight[indices].to(torch.float32)  # [B, K, D]
    # Reconstruction after adding each of the K selected latents in turn: [B, K, D]
    recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)
    diff = recon_cum - target.to(torch.float32).unsqueeze(1)  # [B, K, D]
    loss = diff.pow(2).mean()  # MSE averaged over all K prefixes
    return loss


B = 2048
K = 256
F = 1024 * 128
D = 1024
WARMUP = 5
NUM_ITER = 100
dtype = torch.float32

vals = None
decoder = None
bias = None
target = None
indices = None


def init_parameters():
    global vals, decoder, bias, target, indices
    vals = torch.randn(B, K, dtype=dtype, device="cuda").abs().requires_grad_()
    decoder = torch.randn(F, D, dtype=dtype, device="cuda", requires_grad=True)
    bias = torch.randn(D, dtype=dtype, device="cuda", requires_grad=True)
    target = torch.randn(B, D, dtype=dtype, device="cuda")
    indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")


timing_kernel = []
timing_vanilla = []
torch.cuda.reset_peak_memory_stats()
loss_kernel_list = torch.zeros((NUM_ITER,))
loss_vanilla_list = torch.zeros((NUM_ITER,))


def zero_grad():
    vals.grad = None
    decoder.grad = None
    bias.grad = None
    torch.cuda.empty_cache()


for i in range(NUM_ITER + WARMUP):
    init_parameters()
    start_kernel = torch.cuda.Event(enable_timing=True)
    end_kernel = torch.cuda.Event(enable_timing=True)
    start_vanilla = torch.cuda.Event(enable_timing=True)
    end_vanilla = torch.cuda.Event(enable_timing=True)

    start_kernel.record()
    loss_kernel = flex.triton_hierarchical_sae_loss(indices, decoder, vals, bias, target)
    loss_kernel.backward()
    end_kernel.record()

    zero_grad()
    start_vanilla.record()
    loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
    loss_vanilla.backward()
    end_vanilla.record()
    if i >= WARMUP:
        torch.cuda.synchronize()
        timing_kernel.append(start_kernel.elapsed_time(end_kernel))
        timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
        loss_kernel_list[i-WARMUP] = loss_kernel.detach()
        loss_vanilla_list[i-WARMUP] = loss_vanilla.detach()
    zero_grad()

# Compare the losses from every timed iteration, not just the last one.
if torch.allclose(loss_kernel_list, loss_vanilla_list):
    print("✅ Outputs are close! Everything is good! 🎉")
else:
    print("❌ Outputs mismatch... ⚠️🤔")


print(f"🦎 Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} ± {np.std(timing_kernel):.4f} ms")
print(f"🔥 Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} ± {np.std(timing_vanilla):.4f} ms")
print(f"🚀 Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x")

Run it with uv run https://huggingface.co/t-tech/flex-sae/resolve/main/example.py (a CUDA GPU is required).

Performance

Benchmarks were collected on a workload with dictionary size $F = 65\,536$, embedding dimension $D = 2304$, and sparsity budgets $K \in \{32, 64, 128\}$. Latency is reported as time per training step (milliseconds) and memory as peak device usage (GiB).

| Decoder backend | K=32 (ms / GiB) | K=64 (ms / GiB) | K=128 (ms / GiB) |
|---|---|---|---|
| Pure torch-compiled: TopK | 8.787 / 2.92 | 11.746 / 2.92 | 18.877 / 2.93 |
| Pure torch-compiled: HierarchicalTopK | 12.824 / 6.29 | 23.379 / 10.79 | 43.851 / 19.80 |
| Triton kernels: TopK | 5.576 / 2.92 | 6.339 / 2.92 | 7.961 / 2.93 |
| Triton kernels: HierarchicalTopK | 6.696 / 2.92 | 7.995 / 2.92 | 10.609 / 2.93 |

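Across all evaluated sparsity budgets, the fused Triton HierarchicalTopK kernel uses the same peak memory as the TopK kernels (2.92 to 2.93 GiB) and is consistently faster than the torch-compiled reference: about 1.9x at K = 32, rising to about 4.1x at K = 128, where the compiled reference also needs 19.80 GiB. It is faster than the torch-compiled TopK decoder at every K as well.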
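The ratios quoted above can be recomputed from the table. This snippet only does arithmetic on the published numbers; it does not run any kernel.

```python
# Latency (ms) and peak memory (GiB) copied from the Performance table.
torch_hier_ms = {32: 12.824, 64: 23.379, 128: 43.851}
triton_hier_ms = {32: 6.696, 64: 7.995, 128: 10.609}
triton_topk_ms = {32: 5.576, 64: 6.339, 128: 7.961}
torch_hier_gib = {32: 6.29, 64: 10.79, 128: 19.80}
triton_hier_gib = {32: 2.92, 64: 2.92, 128: 2.93}

summary = {}
for k in (32, 64, 128):
    summary[k] = (
        torch_hier_ms[k] / triton_hier_ms[k],    # speedup over torch-compiled HierarchicalTopK
        torch_hier_gib[k] / triton_hier_gib[k],  # peak-memory reduction factor
        triton_hier_ms[k] / triton_topk_ms[k],   # remaining latency overhead vs. Triton TopK
    )
    s, m, g = summary[k]
    print(f"K={k}: {s:.2f}x faster, {m:.2f}x less peak memory, {g:.2f}x the Triton TopK latency")
```

The last column shows the remaining latency gap to the Triton TopK kernel, which grows slowly with K while memory stays flat.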

License & Attribution

  • All files except torch-ext/flex_sae/topk_kernels.py are released under the Apache License 2.0.
  • torch-ext/flex_sae/topk_kernels.py includes code adapted from Facebook Research's memory project, originally published under the Creative Commons Attribution-NonCommercial 4.0 International License. That component therefore remains available for non-commercial use only; see NOTICE for details.

Citation

@misc{balagansky2025trainsparseautoencodermultiple,
      title={Train One Sparse Autoencoder Across Multiple Sparsity Budgets to Preserve Interpretability and Accuracy},
      author={Nikita Balagansky and Yaroslav Aksenov and Daniil Laptev and Vadim Kurochkin and Gleb Gerasimov and Nikita Koryagin and Daniil Gavrilov},
      year={2025},
      eprint={2505.24473},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2505.24473},
}