Text Generation
Transformers
Safetensors
Chinese
English
minicpm
feature-extraction
conversational
custom_code
Instructions to use openbmb/MiniCPM4-MCP with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use openbmb/MiniCPM4-MCP with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="openbmb/MiniCPM4-MCP", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("openbmb/MiniCPM4-MCP", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use openbmb/MiniCPM4-MCP with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "openbmb/MiniCPM4-MCP"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "openbmb/MiniCPM4-MCP",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
Use Docker
docker model run hf.co/openbmb/MiniCPM4-MCP
- SGLang
How to use openbmb/MiniCPM4-MCP with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "openbmb/MiniCPM4-MCP" \
  --host 0.0.0.0 \
  --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "openbmb/MiniCPM4-MCP",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "openbmb/MiniCPM4-MCP" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "openbmb/MiniCPM4-MCP",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
- Docker Model Runner
How to use openbmb/MiniCPM4-MCP with Docker Model Runner:
docker model run hf.co/openbmb/MiniCPM4-MCP
| # coding=utf-8 | |
| # Copyright 2025 The OpenBMB Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import math | |
| from typing import Any, Tuple, Union | |
| from collections import Counter | |
| import torch | |
| import triton | |
| import triton.language as tl | |
| import warnings | |
| from torch import nn | |
def is_hopper_gpu():
    """Return True when CUDA is available and the current device has compute capability major version 9."""
    if not torch.cuda.is_available():
        return False
    # Only the major version decides the answer; the minor version is ignored.
    major_version, _ = torch.cuda.get_device_capability()
    return major_version == 9
def get_compressed_seqlens(
    cu_seqlens: torch.Tensor, kernel_size: int, kernel_stride: int
):
    """Compute per-sequence and cumulative lengths after compression.

    A sequence of length ``L >= kernel_size`` yields
    ``floor((L - kernel_size) / kernel_stride) + 1`` compressed positions;
    a sequence shorter than ``kernel_size`` yields none.

    Args:
        cu_seqlens: Cumulative sequence lengths; consecutive differences are
            the individual sequence lengths.
        kernel_size: Length of one compression window.
        kernel_stride: Step between consecutive compression windows.

    Returns:
        tuple: ``(y_seqlens, y_cu_seqlens)`` - int32 compressed lengths per
        sequence, and their int32 cumulative sum with a leading zero, placed
        on ``cu_seqlens.device``.
    """
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    window_counts = (
        torch.floor((lengths - kernel_size) / kernel_stride).to(torch.int32) + 1
    )
    # Corner case: a sequence shorter than one window is not compressed at all.
    y_seqlens = torch.where(
        lengths < kernel_size, torch.zeros_like(window_counts), window_counts
    )
    leading_zero = torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device)
    running_total = torch.cumsum(y_seqlens, dim=0).to(torch.int32)
    y_cu_seqlens = torch.cat([leading_zero, running_total])
    return y_seqlens, y_cu_seqlens
def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
    """Pick ``num_warps`` and ``num_stages`` for a sparse-attention Triton kernel.

    Args:
        head_dim (int): Size of the head dimension.
        block_size (int): Size of the block in the attention matrix.
        is_hopper_gpu (bool): True for the Hopper table, False for the
            Ampere table.

    Returns:
        tuple: ``(num_warps, num_stages)`` recommended values.
    """
    # Count how many of the two sizes exceed 64.
    num_large = int(head_dim > 64) + int(block_size > 64)
    if num_large == 0:
        # Both small: same recommendation on either architecture.
        return 2, 2
    if is_hopper_gpu and num_large == 1:
        # Hopper uses fewer warps when only one of the sizes is large.
        return 4, 3
    # Both large on Hopper, or at least one large on Ampere.
    return 8, 3
# Evaluated once at import time; the launch wrappers in this file pass it to
# get_num_warps_stages to choose kernel launch parameters.
IS_HOPPER_GPU = is_hopper_gpu()
@triton.jit
def forward_kernel(
    q_ptr,  # Q: n x h x d
    k_ptr,  # K: n x h x d
    v_ptr,  # V: n x h x d
    o_ptr,  # O: n x h x d
    lse_ptr,  # LSE: h x n
    # size and stride at compression
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_on,
    stride_oh,
    stride_od,
    stride_lh,
    stride_ln,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    """Forward pass of causal attention from queries to compressed keys/values.

    Each program handles one (sequence, query head, query block) triple, as
    given by program ids 0, 1 and 2. Compressed key ``j`` is treated as
    sitting at position ``j * kernel_stride + kernel_size - 1`` of the
    original sequence, and a query attends to it only if the query position
    is at or after that position. The softmax is accumulated block by block
    over keys with a running maximum, and the result is written to ``o_ptr``
    together with the base-2 log-sum-exp of the scaled scores in ``lse_ptr``.

    Must be compiled with ``triton.jit``: the host wrapper launches it as
    ``forward_kernel[grid](...)`` and the body uses ``tl`` device builtins.
    """
    # 1.44269504 ~= log2(e): folding it into the scale lets the softmax use
    # exp2/log2 instead of exp/log.
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_q = tl.program_id(2)
    # query heads are grouped: NUM_SHARE_Q_HEADS query heads share one kv head
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # skip the first kernel_size - 1 query positions, because they do not
    # attend to any keys (the first key sits at position kernel_size - 1)
    q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
    if q_start_in_seq >= q_len:
        return
    # init qkv pointer
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_qn, stride_qd),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # K is addressed transposed (HEAD_DIM x k_len) so tl.dot(q, k) needs no transpose
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(HEAD_DIM, k_len),
        strides=(stride_kd, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_vn, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # load q
    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init statistics
    off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
    # position in the original sequence of each compressed key of a block
    off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
    m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
    lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
    acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32)
    # attention
    lo = 0
    # keys past this index are masked for every query of this block
    hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
    for i in range(lo, hi, BLOCK_SIZE_K):
        i = tl.multiple_of(i, BLOCK_SIZE_K)
        # load k
        k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero")
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        # causal mask: -inf where the key position lies after the query position
        qk += tl.where(
            off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
        )
        qk += tl.dot(q, k) * qk_scale
        # compute m_ij and l_ij
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp2(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        # scale acc_o
        acc_o_scale = tl.exp2(m_i - m_ij)
        acc_o = acc_o * acc_o_scale[:, None]
        # load v and update acc_o
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # update statistics
        m_i = m_ij
        lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
        # update ptrs
        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))
        v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
    # final scale
    acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
    # save output
    o_ptrs = tl.make_block_ptr(
        base=o_ptr + q_start * stride_on + pid_h * stride_oh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_on, stride_od),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
    # save lse
    l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln
    tl.store(l_ptrs, lse_i, mask=off_q < q_len)
@triton.jit
def backward_sum_o_do(
    o_ptr,  # O: n x h x d
    do_ptr,  # dO: n x h x d
    delta_ptr,  # D: h x n
    o_len,
    HEAD_DIM,
    stride_on,
    stride_oh,
    stride_od,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dh,
    stride_dn,
    BLOCK_SIZE_O: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Compute ``delta[h, n] = sum_d O[n, h, d] * dO[n, h, d]``.

    Each program handles ``BLOCK_SIZE_O`` rows of one head (program ids 0
    and 1). Rows at or beyond ``o_len`` and columns at or beyond ``HEAD_DIM``
    are masked out. The products are accumulated in float32.

    Must be compiled with ``triton.jit``: the host wrapper launches it as
    ``backward_sum_o_do[grid](...)``.
    """
    pid_n = tl.program_id(0)
    pid_h = tl.program_id(1)
    off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
    off_d = tl.arange(0, BLOCK_SIZE_D)
    o = tl.load(
        o_ptr
        + off_n[:, None] * stride_on
        + pid_h * stride_oh
        + off_d[None, :] * stride_od,
        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    do = tl.load(
        do_ptr
        + off_n[:, None] * stride_don
        + pid_h * stride_doh
        + off_d[None, :] * stride_dod,
        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    # row-wise dot product of O and dO
    delta = tl.sum(o * do, axis=1)
    tl.store(
        delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len
    )
@triton.jit
def backward_dkdv(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dk_ptr,  # DK: sh x n x kh x d
    dv_ptr,  # DV: sh x n x kh x d
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dks,
    stride_dkn,
    stride_dkh,
    stride_dkd,
    stride_dvs,
    stride_dvn,
    stride_dvh,
    stride_dvd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    """Backward pass for the compressed keys and values (dK, dV).

    Each program handles one (sequence, query head, key block) triple, keeps
    its K/V block loaded, and loops over the query blocks that can attend to
    it. Probabilities are recomputed from the stored base-2 log-sum-exp.
    Results are written to the ``pid_sh`` slice of ``dk_ptr``/``dv_ptr``
    (index of the query head within its kv-head group); the host wrapper
    sums over that leading axis afterwards.

    Must be compiled with ``triton.jit``: the host wrapper launches it as
    ``backward_dkdv[grid](...)``.
    """
    # 1.44269504 ~= log2(e), so exp2 can be used to recompute probabilities
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    pid_sh = pid_h % NUM_SHARE_Q_HEADS
    pid_k = tl.program_id(2)
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if BLOCK_SIZE_K * pid_k >= k_len:
        return
    # init pointers
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_kn, stride_kd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dk_ptrs = tl.make_block_ptr(
        base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
        shape=(k_len, HEAD_DIM),
        strides=(stride_dkn, stride_dkd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_vn, stride_vd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dv_ptrs = tl.make_block_ptr(
        base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
        shape=(k_len, HEAD_DIM),
        strides=(stride_dvn, stride_dvd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q)
    # position in the original sequence of each compressed key of this block
    off_k = (
        pid_k * BLOCK_SIZE_K * kernel_stride
        + tl.arange(0, BLOCK_SIZE_K) * kernel_stride
        + kernel_size
        - 1
    )
    # load k v and keep in SRAM
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dk dv
    dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
    dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
    # first query position that can see the first key of this block
    q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1
    # Q and dO are addressed transposed (HEAD_DIM x q_len)
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(HEAD_DIM, q_len),
        strides=(stride_qd, stride_qn),
        offsets=(0, q_lo),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
        order=(0, 1),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
        shape=(HEAD_DIM, q_len),
        strides=(stride_dod, stride_don),
        offsets=(0, q_lo),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
        order=(0, 1),
    )
    # delta and lse are read as [1, BLOCK_SIZE_Q] rows so they broadcast over keys
    d_ptrs = tl.make_block_ptr(
        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
        shape=(1, q_len),
        strides=(0, stride_dn),
        offsets=(0, q_lo),
        block_shape=(1, BLOCK_SIZE_Q),
        order=(1, 0),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
        shape=(1, q_len),
        strides=(0, stride_ln),
        offsets=(0, q_lo),
        block_shape=(1, BLOCK_SIZE_Q),
        order=(0, 1),
    )
    # loop for q blocks
    for i in range(q_lo, q_len, BLOCK_SIZE_Q):
        # load
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk
        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
        qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float("-inf"))
        qk += tl.dot(k, q) * qk_scale
        # compute p, ds
        # [BLOCK_SIZE_K, BLOCK_SIZE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
        p = tl.exp2(qk - lse)
        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
        dp = tl.dot(v, do)
        ds = sm_scale * p * (dp - d)
        # cast dtype
        p = p.to(do.dtype)
        ds = ds.to(q.dtype)
        # update dk and dv
        # [BLOCK_SIZE_K, BLOCK_SIZE_Q] @ [BLOCK_SIZE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM]
        dk += tl.dot(ds, tl.trans(q))
        dv += tl.dot(p, tl.trans(do))
        # increment pointers
        q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q))
        do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q))
        lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q))
        d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q))
    # save dk dv
    tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def backward_dq(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dq_ptr,
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dqn,
    stride_dqh,
    stride_dqd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    """Backward pass for the queries (dQ).

    Each program handles one (sequence, query head, query block) triple,
    keeps its Q, dO, lse and delta blocks loaded, and loops over the key
    blocks that can be attended to. Probabilities are recomputed from the
    stored base-2 log-sum-exp, using the same causal mask as the forward
    kernel in this file.

    Must be compiled with ``triton.jit``: the host wrapper launches it as
    ``backward_dq[grid](...)``.
    """
    # 1.44269504 ~= log2(e), so exp2 can be used to recompute probabilities
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_q = tl.program_id(2)
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # skip the first kernel_size - 1 query positions, because they do not
    # attend to any keys (the first key sits at position kernel_size - 1)
    q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
    if q_start_in_seq >= q_len:
        return
    # init pointers
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_qn, stride_qd),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dq_ptrs = tl.make_block_ptr(
        base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_dqn, stride_dqd),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_kn, stride_kd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # V is addressed transposed (HEAD_DIM x k_len) so tl.dot(do, v) needs no transpose
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(HEAD_DIM, k_len),
        strides=(stride_vd, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_don, stride_dod),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # delta and lse are read as [BLOCK_SIZE_Q, 1] columns so they broadcast over keys
    d_ptrs = tl.make_block_ptr(
        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
        shape=(q_len, 1),
        strides=(stride_dn, stride_dh),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
        shape=(q_len, 1),
        strides=(stride_ln, stride_lh),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
    # position in the original sequence of each compressed key of a block
    off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
    # load q, do, lse, delta, and keep in SRAM
    q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
    do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
    lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
    d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dq
    dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32)
    lo = 0
    # keys past this index are masked for every query of this block
    hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
    for i in range(lo, hi, BLOCK_SIZE_K):
        # load
        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        # causal mask: -inf where the key position lies after the query position
        qk += tl.where(
            off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
        )
        qk += tl.dot(q, tl.trans(k)) * qk_scale
        # compute p, ds
        p = tl.exp2(qk - lse)
        dp = tl.dot(do, v)
        ds = sm_scale * p * (dp - d)
        # cast dtype
        ds = ds.to(q.dtype)
        # update dq
        dq += tl.dot(ds, k)
        # increment pointers
        k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0))
        v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K))
    # save dq
    tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
def _compressed_attention_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: torch.Tensor,
    max_seqlen_k: torch.Tensor,
    sm_scale: float,
):
    """Validate inputs, allocate outputs and launch ``forward_kernel``.

    Args:
        q: Queries, unpacked here as ``(q_len, num_q_heads, head_dim)``.
        k: Compressed keys, unpacked as ``(k_len, num_k_heads, head_dim)``.
        v: Compressed values, unpacked as ``(v_len, num_v_heads, head_dim)``.
        kernel_size: Compression window size, forwarded to the kernel.
        kernel_stride: Compression window stride, forwarded to the kernel.
        cu_seqlens_q: int32 cumulative query lengths; its length minus one
            is used as the batch size.
        cu_seqlens_k: int32 cumulative compressed-key lengths.
        max_seqlen_q: Longest query sequence; sizes the launch grid.
        max_seqlen_k: Not used by this function.
        sm_scale: Softmax scale, forwarded to the kernel.

    Returns:
        tuple: ``(o, lse)`` - ``o`` has the shape and dtype of ``q`` and is
        zero-initialised; ``lse`` is float32 ``(num_q_heads, q_len)`` and is
        initialised to ``-inf``.
    """
    # dtype check
    assert k.dtype == q.dtype and v.dtype == q.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    # shape
    q_len, num_q_heads, _ = q.shape
    k_len, num_k_heads, _ = k.shape
    v_len, num_v_heads, head_dim = v.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    assert k_len == v_len and q_len > k_len
    # gqa: every group of query heads shares one kv head
    assert num_k_heads == num_v_heads
    assert num_q_heads % num_k_heads == 0
    num_share_q_heads = num_q_heads // num_k_heads
    # output tensors
    o = torch.zeros_like(q)
    lse = torch.full(
        (num_q_heads, q_len),
        fill_value=-torch.inf,
        dtype=torch.float32,
        device=q.device,
    )
    # launch configuration
    block_q = 128
    block_k = 128
    block_d = triton.next_power_of_2(head_dim)
    num_warps, num_stages = get_num_warps_stages(head_dim, block_q, IS_HOPPER_GPU)

    def grid(META):
        # one program per (sequence, query head, query block)
        return (
            batch_size,
            num_q_heads,
            triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
        )

    forward_kernel[grid](
        q,
        k,
        v,
        o,
        lse,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        # strides, in the order the kernel signature expects
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *o.stride(),
        *lse.stride(),
        BLOCK_SIZE_Q=block_q,
        BLOCK_SIZE_K=block_k,
        BLOCK_SIZE_D=block_d,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o, lse
def _compressed_attention_bwd(
    o: torch.Tensor,
    do: torch.Tensor,
    lse: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: torch.Tensor,
    max_seqlen_k: torch.Tensor,
    sm_scale: float,
):
    """Launch the three backward kernels and return ``(dq, dk, dv)``.

    The work is done in three stages:

    1. ``backward_sum_o_do`` fills ``delta`` (float32, ``num_o_heads x o_len``).
    2. ``backward_dkdv`` writes one dK/dV slice per query head of a kv-head
       group; the slices are then summed over that leading axis.
    3. ``backward_dq`` fills ``dq``, which has the shape and dtype of ``q``.

    Args:
        o: Forward output.
        do: Gradient with respect to ``o``.
        lse: Log-sum-exp tensor saved by the forward pass.
        q, k, v: Forward inputs.
        kernel_size, kernel_stride: Compression parameters, forwarded to the
            kernels.
        cu_seqlens_q, cu_seqlens_k: Cumulative sequence lengths.
        max_seqlen_q, max_seqlen_k: Longest query / key sequence; they size
            the launch grids.
        sm_scale: Softmax scale.

    Returns:
        tuple: ``(dq, dk, dv)``.
    """
    # Unpack every shape (each input must be 3-D); head_dim is taken from o.
    _, num_q_heads, _ = q.shape
    k_len, num_k_heads, _ = k.shape
    _, _, _ = v.shape
    o_len, num_o_heads, head_dim = o.shape
    num_share_q_heads = num_q_heads // num_k_heads
    batch_size = cu_seqlens_q.shape[0] - 1
    block_d = triton.next_power_of_2(head_dim)

    # ---- stage 1: delta = rowsum(o * do) ----
    delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
    block_o = 256
    num_warps, num_stages = get_num_warps_stages(head_dim, block_o, IS_HOPPER_GPU)

    def delta_grid(META):
        return (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads)

    backward_sum_o_do[delta_grid](
        o,
        do,
        delta,
        o_len,
        head_dim,
        *o.stride(),
        *do.stride(),
        *delta.stride(),
        BLOCK_SIZE_O=block_o,
        BLOCK_SIZE_D=block_d,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    # ---- stage 2: dk, dv ----
    # One slice per query head sharing a kv head; reduced after the launch.
    partial_shape = (num_share_q_heads, k_len, num_k_heads, head_dim)
    dk = torch.zeros(*partial_shape, device=k.device, dtype=k.dtype)
    dv = torch.zeros(*partial_shape, device=k.device, dtype=k.dtype)
    block_q, block_k = 64, 128
    num_warps, num_stages = get_num_warps_stages(head_dim, block_k, IS_HOPPER_GPU)

    def dkdv_grid(META):
        return (
            batch_size,
            num_q_heads,
            triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
        )

    backward_dkdv[dkdv_grid](
        q,
        k,
        v,
        lse,
        delta,
        do,
        dk,
        dv,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *lse.stride(),
        *delta.stride(),
        *do.stride(),
        *dk.stride(),
        *dv.stride(),
        BLOCK_SIZE_Q=block_q,
        BLOCK_SIZE_K=block_k,
        BLOCK_SIZE_D=block_d,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    dk = dk.sum(0)
    dv = dv.sum(0)

    # ---- stage 3: dq ----
    dq = torch.zeros_like(q)
    block_q, block_k = 128, 64
    num_warps, num_stages = get_num_warps_stages(head_dim, block_q, IS_HOPPER_GPU)

    def dq_grid(META):
        return (
            batch_size,
            num_q_heads,
            triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
        )

    backward_dq[dq_grid](
        q,
        k,
        v,
        lse,
        delta,
        do,
        dq,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *lse.stride(),
        *delta.stride(),
        *do.stride(),
        *dq.stride(),
        BLOCK_SIZE_Q=block_q,
        BLOCK_SIZE_K=block_k,
        BLOCK_SIZE_D=block_d,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return dq, dk, dv
class CompressedAttention(torch.autograd.Function):
    """Autograd wrapper around the compressed-attention Triton kernels.

    ``forward`` and ``backward`` are static methods, as required for
    ``torch.autograd.Function`` subclasses; use ``CompressedAttention.apply``.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kernel_size: int,
        kernel_stride: int,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: torch.Tensor,
        max_seqlen_k: torch.Tensor,
        sm_scale=None,
    ):
        """Run the forward kernel and stash what ``backward`` needs.

        Args:
            ctx: Autograd context.
            q, k, v: bfloat16 or float16 tensors sharing one dtype.
            kernel_size, kernel_stride: Compression parameters.
            cu_seqlens_q, cu_seqlens_k: int32 cumulative sequence lengths.
            max_seqlen_q, max_seqlen_k: Longest query / key sequence.
            sm_scale: Softmax scale; defaults to ``1 / sqrt(q.shape[-1])``.

        Returns:
            tuple: ``(o, lse)`` as produced by ``_compressed_attention_fwd``.
        """
        # dtype check
        assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
        assert q.dtype == k.dtype and k.dtype == v.dtype
        assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
        # softmax scale
        if sm_scale is None:
            sm_scale = 1 / math.sqrt(q.shape[-1])
        o, lse = _compressed_attention_fwd(
            q,
            k,
            v,
            kernel_size,
            kernel_stride,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            sm_scale,
        )
        # tensors go through save_for_backward; plain values are stored on ctx
        ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k)
        ctx.sm_scale = sm_scale
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.kernel_size = kernel_size
        ctx.kernel_stride = kernel_stride
        return o, lse

    @staticmethod
    def backward(ctx, do: torch.Tensor, *args) -> Any:
        """Compute gradients for ``q``, ``k`` and ``v``.

        Args:
            ctx: Autograd context populated by ``forward``.
            do: Gradient with respect to the output ``o``.
            *args: Gradients for the remaining forward outputs (``lse``);
                they are ignored.

        Returns:
            tuple: ``(dq, dk, dv)`` followed by ``None`` for each of the
            seven remaining ``forward`` inputs.
        """
        q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
        max_seqlen_q = ctx.max_seqlen_q
        max_seqlen_k = ctx.max_seqlen_k
        sm_scale = ctx.sm_scale
        kernel_size = ctx.kernel_size
        kernel_stride = ctx.kernel_stride
        dq, dk, dv = _compressed_attention_bwd(
            o,
            do,
            lse,
            q,
            k,
            v,
            kernel_size,
            kernel_stride,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            sm_scale,
        )
        return dq, dk, dv, None, None, None, None, None, None, None
@triton.jit
def score_kernel(
    q_ptr,
    k_ptr,
    lse_ptr,
    s_ptr,
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_lh,
    stride_ln,
    stride_sh,
    stride_sq,
    stride_sk,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    """Write attention probabilities summed over the query heads of a kv head.

    Program id 0 encodes (sequence, kv head); ids 1 and 2 select the query
    block and the key block. For every query head in the group, the
    probability ``exp2(q.k * scale * log2(e) - lse)`` is recomputed from the
    stored base-2 log-sum-exp, zeroed where the causal mask forbids
    attention, and accumulated. The sum is stored to ``s_ptr``.

    Uses ``tl`` device builtins, so it must be compiled with ``triton.jit``.
    """
    # 1.44269504 ~= log2(e), so exp2 can be used to recompute probabilities
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_bkh = tl.program_id(0)
    pid_b = pid_bkh // NUM_KV_HEADS
    pid_kh = pid_bkh % NUM_KV_HEADS
    pid_q = tl.program_id(1)
    pid_k = tl.program_id(2)
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len:
        return
    # init k pointer and load k
    # K is addressed transposed (HEAD_DIM x k_len) so tl.dot(q, k) needs no transpose
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(HEAD_DIM, k_len),
        strides=(stride_kd, stride_kn),
        offsets=(0, pid_k * BLOCK_SIZE_K),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
    off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
    # a query may attend to compressed key j only from position
    # j * kernel_stride + kernel_size - 1 onwards
    causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :]
    # init score
    s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
    # loop over gqa heads
    for h in range(NUM_SHARE_Q_HEADS):
        pid_h = pid_kh * NUM_SHARE_Q_HEADS + h
        q_ptrs = tl.make_block_ptr(
            base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
            shape=(q_len, HEAD_DIM),
            strides=(stride_qn, stride_qd),
            offsets=(pid_q * BLOCK_SIZE_Q, 0),
            block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
            order=(1, 0),
        )
        lse_ptrs = tl.make_block_ptr(
            base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
            shape=(q_len, 1),
            strides=(stride_ln, stride_lh),
            offsets=(pid_q * BLOCK_SIZE_Q, 0),
            block_shape=(BLOCK_SIZE_Q, 1),
            order=(0, 1),
        )
        # load q and lse
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.dot(q, k) * qk_scale
        # compute score; masked entries contribute exactly 0
        s += tl.where(causal_mask, tl.exp2(qk - lse), 0)
    # save output
    s_ptrs = tl.make_block_ptr(
        base=s_ptr + pid_kh * stride_sh + q_start * stride_sq,
        shape=(q_len, k_len),
        strides=(stride_sq, stride_sk),
        offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K),
        order=(1, 0),
    )
    tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1))
def _get_attention_score(
    q: torch.Tensor,  # [total_query_len, num_q_heads, head_dim]
    k: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]
    lse: torch.Tensor,  # [num_q_heads, total_query_len]
    kernel_size: int,
    kernel_stride: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
) -> torch.Tensor:
    """Launch ``score_kernel`` and return the per-kv-head attention scores.

    Args:
        q: packed queries, ``[total_query_len, num_q_heads, head_dim]``,
            float16 or bfloat16.
        k: packed keys, ``[total_key_len, num_k_heads, head_dim]``, same
            dtype as ``q``.
        lse: float32 base-2 log-sum-exp, ``[num_q_heads, total_query_len]``.
        kernel_size: forwarded to the kernel (used in its causal mask).
        kernel_stride: forwarded to the kernel (used in its causal mask).
        cu_seqlens_q: int32 cumulative query lengths, ``[batch_size + 1]``.
        cu_seqlens_k: int32 cumulative key lengths, ``[batch_size + 1]``.
        max_seqlen_q: longest query sequence; sizes the launch grid.
        max_seqlen_k: longest key sequence; sizes the grid and the output.
        sm_scale: softmax scale; ``None`` selects ``1 / sqrt(head_dim)``.

    Returns:
        float32 tensor of shape ``[num_k_heads, total_query_len, max_seqlen_k]``.
    """
    # --- dtype validation -------------------------------------------------
    assert q.dtype in (torch.bfloat16, torch.float16)
    assert k.dtype == q.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    # lse must be log2(sum(exp(qk*scale))), not the natural-log variant.
    assert lse.dtype == torch.float32

    # --- shapes -------------------------------------------------------------
    total_q, num_q_heads, head_dim = q.shape
    total_k, num_k_heads, head_dim = k.shape
    num_seqs = cu_seqlens_q.shape[0] - 1
    assert total_q > total_k
    if sm_scale is None:
        sm_scale = 1 / math.sqrt(head_dim)

    # --- grouped-query layout ----------------------------------------------
    assert num_q_heads % num_k_heads == 0
    q_heads_per_kv = num_q_heads // num_k_heads

    # Output buffer; tiles the kernel skips stay at zero.
    out = torch.zeros(
        (num_k_heads, total_q, max_seqlen_k),
        dtype=torch.float32,
        device=q.device,
    )

    # --- launch ---------------------------------------------------------------
    block_q = 128
    block_k = 128
    block_d = triton.next_power_of_2(head_dim)
    launch_grid = (
        num_seqs * num_k_heads,
        triton.cdiv(max_seqlen_q, block_q),
        triton.cdiv(max_seqlen_k, block_k),
    )
    score_kernel[launch_grid](
        q,
        k,
        lse,
        out,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        q_heads_per_kv,
        head_dim,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        lse.stride(0),
        lse.stride(1),
        out.stride(0),
        out.stride(1),
        out.stride(2),
        BLOCK_SIZE_Q=block_q,
        BLOCK_SIZE_K=block_k,
        BLOCK_SIZE_D=block_d,
        num_warps=8,
        num_stages=3,
    )
    return out
# Launched as ``_transform_score_kernel[grid](...)``, which requires a Triton
# JIT function.
@triton.jit
def _transform_score_kernel(
    s_ptr,  # score, shape: [num_heads, q_len, k_len]
    bs_ptr,  # block wise score: [num_heads, q_len, num_k_block]
    offs,
    cu_seqlens_q,
    # shape
    num_heads,
    num_offs,
    max_k_len,
    max_blocks,
    pad_len,
    # kernel & block size
    block_size,
    block_stride,  # block_size // kernel_stride
    init_blocks,
    local_blocks,
    # stride
    stride_sh,
    stride_sq,
    stride_sk,
    stride_bsh,
    stride_bsq,
    stride_bsk,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_O: tl.constexpr,
):
    """Pool per-key scores into per-block scores and apply selection overrides.

    Grid layout: axis 0 enumerates ``batch * num_heads``, axis 1 the query
    tiles and axis 2 the tiles of key blocks. For each (query, block) pair the
    kernel gathers ``num_offs`` consecutive score columns starting at
    ``block_index * block_stride - pad_len``, multiplies them by the weights
    read from ``offs``, and reduces them with ``max``.

    Two overrides are then applied, in this order:

    1. blocks ``b`` with ``b <= query_block < b + local_blocks`` are set to
       ``-inf``;
    2. blocks with index below ``init_blocks`` and the block equal to the
       query's own block are set to ``+inf`` (this wins over step 1).

    Results are written to ``bs_ptr`` for block indices below ``max_blocks``.
    """
    pid_bh = tl.program_id(0)
    pid_b = pid_bh // num_heads
    pid_h = pid_bh % num_heads
    pid_q = tl.program_id(1)
    pid_k = tl.program_id(2)
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    # Index of the first key *block* handled by this program.
    k_start = pid_k * BLOCK_SIZE_K
    # The grid may be larger than this sequence needs; skip surplus q tiles.
    if pid_q * BLOCK_SIZE_Q >= q_len:
        return
    # load weight; lanes beyond num_offs read 0
    off_o = tl.arange(0, BLOCK_SIZE_O)
    w = tl.load(offs + off_o, mask=off_o < num_offs, other=0)
    # load score
    off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
    # First score column of each block, shifted back by pad_len ...
    off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len
    # ... expanded to every window offset: shape [BO, BK].
    off_k = off_k[None, :] + off_o[:, None]
    s_ptrs = (
        s_ptr
        + q_start * stride_sq
        + pid_h * stride_sh
        + off_q[:, None, None] * stride_sq
        + off_k[None, :, :] * stride_sk
    )
    # weighted max, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK]
    # Out-of-range rows/columns are read as 0.
    s = tl.load(
        s_ptrs,
        mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len),
        other=0,
    )
    s = s * w[None, :, None]
    s = tl.max(s, axis=1)
    # init mask and local mask
    off_bq = off_q // block_size  # block index that each query falls into
    off_bk = tl.arange(0, BLOCK_SIZE_K)
    s = tl.where(
        # For local blocks: set to negative infinity (exclude from topk)
        (off_bq[:, None] >= (off_bk + k_start)[None, :])
        & (off_bq[:, None] < (off_bk + k_start)[None, :] + local_blocks),
        float("-inf"),
        s,
    )
    # Keep the original conditions for init_blocks and query location as infinity
    s = tl.where(
        (off_bk[None, :] < init_blocks - k_start)
        # Force blocks where the query is located to have infinite score (always include in topk)
        | (off_bq[:, None] == (off_bk + k_start)[None, :]),
        float("inf"),
        s,
    )
    # store block wise score
    bs_ptrs = (
        bs_ptr
        + q_start * stride_bsq
        + k_start * stride_bsk
        + pid_h * stride_bsh
        + off_q[:, None] * stride_bsq
        + off_bk[None, :] * stride_bsk
    )
    tl.store(
        bs_ptrs,
        s,
        mask=(off_q < q_len)[:, None] & (off_bk < max_blocks - k_start)[None, :],
    )
def transform_score(
    score: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    init_blocks: int = 1,
    local_blocks: int = 2,
) -> torch.Tensor:
    """Convert per-key scores into per-block scores for top-k selection.

    Args:
        score: ``[num_k_heads, total_query_len, max_key_len]`` scores.
        kernel_size: compression kernel size.
        kernel_stride: compression kernel stride.
        block_size: key/value block size used for selection.
        cu_seqlens_q: cumulative query lengths, ``[batch_size + 1]``.
        cu_seqlens_k: unused; kept for interface compatibility.
        max_seqlen_q: longest query sequence; determines the block count.
        max_seqlen_k: unused; kept for interface compatibility.
        init_blocks: number of leading blocks forced to ``+inf``.
        local_blocks: width of the local window forced to ``-inf``.

    Returns:
        float32 tensor ``[num_k_heads, total_query_len, max_blocks]`` where
        ``max_blocks = ceil(max_seqlen_q / block_size)``.

    Raises:
        ValueError: if ``kernel_size`` or ``block_size`` is smaller than
            ``kernel_stride`` (the pooling window would be empty).
    """
    num_k_heads, total_query_len, max_key_len = score.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    kernels_per_window = kernel_size // kernel_stride
    block_stride = block_size // kernel_stride
    if kernels_per_window < 1 or block_stride < 1:
        raise ValueError(
            "kernel_size and block_size must both be >= kernel_stride, got "
            f"kernel_size={kernel_size}, block_size={block_size}, "
            f"kernel_stride={kernel_stride}"
        )
    pad_len = kernels_per_window - 1
    max_blocks = math.ceil(max_seqlen_q / block_size)
    block_score = torch.zeros(
        num_k_heads,
        total_query_len,
        max_blocks,
        dtype=torch.float32,
        device=score.device,
    )
    # Number of distinct values of i + j for i in [0, kernels_per_window) and
    # j in [0, block_stride): this is the pooling window width in score
    # columns. Pooling uses max, so only the window width matters and the
    # weights are all ones.
    num_offs = kernels_per_window + block_stride - 1
    weights = torch.ones(num_offs, dtype=score.dtype, device=score.device)
    BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks))
    BLOCK_SIZE_O = triton.next_power_of_2(num_offs)
    BLOCK_SIZE_Q = 8
    # The q axis is sized by the packed total length; the kernel exits early
    # for tiles beyond each sequence's own length.
    grid = (
        num_k_heads * batch_size,
        triton.cdiv(total_query_len, BLOCK_SIZE_Q),
        triton.cdiv(max_blocks, BLOCK_SIZE_K),
    )
    _transform_score_kernel[grid](
        score,
        block_score,
        weights,
        cu_seqlens_q,
        num_k_heads,
        num_offs,
        max_key_len,
        max_blocks,
        pad_len,
        block_size,
        block_stride,
        init_blocks,
        local_blocks,
        score.stride(0),
        score.stride(1),
        score.stride(2),
        block_score.stride(0),
        block_score.stride(1),
        block_score.stride(2),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_O=BLOCK_SIZE_O,
        num_warps=8,
        num_stages=3,
    )
    return block_score
def _topk_block_idx(
    q: torch.Tensor,
    k: torch.Tensor,
    lse: torch.Tensor,
    q_idx: torch.Tensor,
    topk: int,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
    init_blocks: int,
    local_blocks: int,
) -> torch.Tensor:
    """Score the given heads, pool to blocks, and return sorted top-k block ids."""
    # recompute score
    score = _get_attention_score(
        q,
        k,
        lse,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        sm_scale,
    )
    # transform score to block-wise score
    score = transform_score(
        score,
        kernel_size,
        kernel_stride,
        block_size,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        init_blocks,
        local_blocks,
    )
    # get topk; cannot request more blocks than exist
    topk = min(topk, score.shape[-1])
    topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
    # Blocks at or beyond the query's own block are marked invalid.
    topk_idx[topk_idx >= q_idx[None, :, None]] = -1
    return topk_idx.to(torch.int32)


def compressed_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float = None,
    init_blocks: int = 1,
    local_blocks: int = 2,
    parallel_topk_compute: Union[str, bool] = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.

    Args:
        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        kernel_size (int): kernel size in compress_key_value
        kernel_stride (int): stride of compress_key_value
        block_size (int): key value block size for topk sparse attention.
        topk (int): number of blocks for each query.
        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
        max_seqlen_q (int): max q len of the batch. Computed from cu_seqlens_q when None.
        max_seqlen_k (int): max k len of the batch. Computed from cu_seqlens_k when None.
        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
        init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
        local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
        parallel_topk_compute (str or bool, optional): Only set it to False when the sequence length is too long. This can avoid a current bug.
            Defaults to "auto": False when the total packed query length is greater than 32k and True otherwise.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention.
        topk_idx is int32 with shape [num_kv_heads, total_q_len, topk'] (topk' = min(topk, number of blocks)),
        sorted along the last dim, with -1 marking entries at or beyond the query's own block.
        topk_idx is None when topk <= 0.
    """
    if max_seqlen_q is None:
        max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
    if max_seqlen_k is None:
        max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
    attn_output, lse = CompressedAttention.apply(
        q,
        k,
        v,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        sm_scale,
    )
    # do not select topk index
    if topk <= 0:
        warnings.warn("topk <= 0, returned topk_idx will be None")
        return attn_output, None
    assert topk >= init_blocks  # + local_blocks

    # Block selection is index bookkeeping only; keep it out of autograd.
    with torch.no_grad():
        num_k_heads, num_q_heads = k.shape[1], q.shape[1]
        num_shared_q_heads = num_q_heads // num_k_heads
        # Position of every packed query inside its own sequence, expressed
        # as a block index. One host transfer yields all sequence lengths.
        seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist()
        q_idx = torch.cat(
            [torch.arange(n, device=q.device) for n in seqlens_q],
            dim=0,
        )
        q_idx = q_idx // block_size
        # whether to use parallel version
        if parallel_topk_compute == "auto":
            parallel_topk_compute = cu_seqlens_q[-1].item() <= 32768
        shared_args = dict(
            q_idx=q_idx,
            topk=topk,
            kernel_size=kernel_size,
            kernel_stride=kernel_stride,
            block_size=block_size,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            sm_scale=sm_scale,
            init_blocks=init_blocks,
            local_blocks=local_blocks,
        )
        if parallel_topk_compute:
            # parallel version: all kv heads in one launch
            topk_idx = _topk_block_idx(q, k, lse, **shared_args)
        else:
            # non parallel version, avoid some current bugs when sequence length is too long
            # FIXME: need to fix later
            topk_idx_list = []
            for h in range(num_k_heads):
                q_heads = slice(h * num_shared_q_heads, (h + 1) * num_shared_q_heads)
                topk_idx_list.append(
                    _topk_block_idx(
                        q[:, q_heads],
                        k[:, h : h + 1],
                        lse[q_heads],
                        **shared_args,
                    )
                )
            topk_idx = torch.cat(topk_idx_list, dim=0)
    return attn_output, topk_idx