sparkleman committed on
Commit
ff3952a
·
1 Parent(s): 05b6df6

UPDATE: Merge cuda core from BlinkDL/RWKV-Gradio-1

Browse files
Files changed (11) hide show
  1. Dockerfile +1 -1
  2. app.py +2 -1
  3. cuda/gemm_fp16_cublas.cpp +75 -0
  4. cuda/operators.cu +246 -0
  5. cuda/rwkv5.cu +88 -0
  6. cuda/rwkv5_op.cpp +34 -0
  7. cuda/rwkv6.cu +87 -0
  8. cuda/rwkv6_op.cpp +34 -0
  9. cuda/wrapper.cpp +141 -0
  10. pyproject.toml +2 -0
  11. uv.lock +28 -0
Dockerfile CHANGED
@@ -26,4 +26,4 @@ COPY --chown=user . $HOME/app
26
 
27
  RUN uv sync --frozen --extra cu124
28
 
29
- CMD ["uv","run","app.py","--strategy","cuda fp16","--model_title","RWKV-x070-World-0.1B-v2.8-20241210-ctx4096","--download_repo_id","BlinkDL/rwkv-7-world","--host","0.0.0.0","--port","7860"]
 
26
 
27
  RUN uv sync --frozen --extra cu124
28
 
29
+ CMD ["uv","run","app.py","--strategy","cuda fp16","--model_title","RWKV-x070-World-0.1B-v2.8-20241210-ctx4096","--download_repo_id","BlinkDL/rwkv-7-world","--host","0.0.0.0","--port","7860","--RWKV_CUDA_ON","True"]
app.py CHANGED
@@ -26,6 +26,7 @@ class Config(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=Tr
26
  description="split input into chunks to save VRAM (shorter -> slower, but saves VRAM)",
27
  )
28
  VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
 
29
 
30
 
31
  CONFIG = Config()
@@ -51,7 +52,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
51
  os.environ["RWKV_V7_ON"] = "1" # enable this for rwkv-7 models
52
  os.environ["RWKV_JIT_ON"] = "1"
53
  os.environ["RWKV_CUDA_ON"] = (
54
- "0" # !!! '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!
55
  )
56
 
57
  from rwkv.model import RWKV
 
26
  description="split input into chunks to save VRAM (shorter -> slower, but saves VRAM)",
27
  )
28
  VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
29
+ RWKV_CUDA_ON:bool = Field(False, description="`True` to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!")
30
 
31
 
32
  CONFIG = Config()
 
52
  os.environ["RWKV_V7_ON"] = "1" # enable this for rwkv-7 models
53
  os.environ["RWKV_JIT_ON"] = "1"
54
  os.environ["RWKV_CUDA_ON"] = (
55
+ "1" if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower() else "0"
56
  )
57
 
58
  from rwkv.model import RWKV
cuda/gemm_fp16_cublas.cpp ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <cublas_v2.h>
2
+ #include <cuda.h>
3
+ #include <cuda_fp16.h>
4
+ #include <cuda_runtime.h>
5
+ #include <torch/extension.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <ATen/cuda/CUDAContext.h>
8
+
9
+ #define CUBLAS_CHECK(condition) \
10
+ for (cublasStatus_t _cublas_check_status = (condition); \
11
+ _cublas_check_status != CUBLAS_STATUS_SUCCESS;) \
12
+ throw std::runtime_error("cuBLAS error " + \
13
+ std::to_string(_cublas_check_status) + " at " + \
14
+ std::to_string(__LINE__));
15
+
16
+ #define CUDA_CHECK(condition) \
17
+ for (cudaError_t _cuda_check_status = (condition); \
18
+ _cuda_check_status != cudaSuccess;) \
19
+ throw std::runtime_error( \
20
+ "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
21
+ " at " + std::to_string(__LINE__));
22
+
23
+ /*
24
+ NOTE: blas gemm is column-major by default, but we need row-major output.
25
+ The data of row-major, transposed matrix is exactly the same as the
26
+ column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
27
+ */
28
+ void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
29
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
30
+ const auto cuda_data_type = CUDA_R_16F;
31
+ const auto cuda_c_data_type =
32
+ c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
33
+ const auto compute_type = CUDA_R_32F;
34
+ const float sp_alpha = 1.f;
35
+ // swap a and b, and use CUBLAS_OP_N. see the notes above
36
+ std::swap(a, b);
37
+ const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
38
+ const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
39
+ // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
40
+ // negative axis is used because of the existence of batch matmul.
41
+ const int m = a.size(-1);
42
+ const int k = a.size(-2);
43
+ const int n = b.size(-2);
44
+ const int cublas_lda = m;
45
+ const int cublas_ldb = k;
46
+ const int cublas_ldc = m;
47
+ cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
48
+
49
+ #if CUDA_VERSION >= 11000
50
+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
51
+ #else
52
+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
53
+ #endif
54
+ const float sp_beta = 0.f;
55
+ if (a.sizes().size() == 2 && b.sizes().size() == 2) {
56
+ CUBLAS_CHECK(cublasGemmEx(
57
+ cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
58
+ a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
59
+ cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
60
+ compute_type, algo));
61
+ } else {
62
+ // batch matmul
63
+ assert(a.sizes().size() == 3 && b.sizes().size() == 3);
64
+
65
+ const long long int cublas_stride_a = m * k;
66
+ const long long int cublas_stride_b = k * n;
67
+ const long long int cublas_stride_c = m * n;
68
+ CUBLAS_CHECK(cublasGemmStridedBatchedEx(
69
+ cublas_handle, cublas_trans_a, cublas_trans_b, m,
70
+ n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
71
+ cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
72
+ &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
73
+ a.size(0), compute_type, algo));
74
+ }
75
+ }
cuda/operators.cu ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+ #include "ATen/ATen.h"
4
+ #include <cuda_fp16.h>
5
+ #define MIN_VALUE (-1e38)
6
+ typedef at::Half fp16;
7
// Reinterpret an at::Half pointer as the CUDA-native __half pointer type;
// both are 16-bit IEEE half with bit-compatible layouts.
__half *cast(fp16 *ptr) {
    return reinterpret_cast<__half *>(ptr);
}
10
+
11
// RWKV-4 WKV recurrence. One thread per (batch, channel) pair: each thread
// scans its channel's full time series of length T, writing y and carrying
// the recurrent state (aa = numerator, bb = denominator, pp = running max
// exponent) in and out of global memory so sequences can be processed in
// chunks. pp keeps every exp() argument non-positive (log-sum-exp style
// overflow protection).
template <typename F>
__global__ void kernel_wkv_forward(const int B, const int T, const int C,
                               const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;                  // batch index
    const int _c = idx % C;                  // channel index
    const int _offset = _b * T * C + _c;     // start of this channel's time series
    const int _state_offset = _b * C + _c;   // this channel's state slot

    float u = _u[_c];   // per-channel bonus applied to the current token
    float w = _w[_c];   // per-channel decay term (added in log space below)
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    // Load carried state once; written back after the scan.
    float aa = _aa[_state_offset];
    float bb = _bb[_state_offset];
    float pp = _pp[_state_offset];
    for (int i = 0; i < T; i++) {
        const int ii = i * C;  // consecutive timesteps are C elements apart
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        // Output: weighted average including the u-bonus for the current token.
        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));
        // State update: decay old state by w, fold in the current token,
        // and rebase the shared exponent pp.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    // Persist the recurrent state for the next chunk of the sequence.
    _aa[_state_offset] = aa;
    _bb[_state_offset] = bb;
    _pp[_state_offset] = pp;
}
51
+
52
// Host launcher for kernel_wkv_forward: one thread per (batch, channel).
// Block size is min(C, 32); B*C must divide evenly by it (checked by the
// assert below, which is compiled out under NDEBUG).
// NOTE(review): no cudaGetLastError() after the launch — launch-config errors
// would only surface at the next synchronizing CUDA call.
template <typename F>
void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
    dim3 threadsPerBlock( min(C, 32) );
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
}

// Explicit instantiations for the two element types dispatched by wrapper.cpp.
template void cuda_wkv_forward<fp16>(
    int B, int T, int C,
    float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
    float *aa, float *bb, float *pp);
template void cuda_wkv_forward<float>(
    int B, int T, int C,
    float *w, float *u, float *k, float *v, float *y,
    float *aa, float *bb, float *pp);
68
+
69
// y = x @ dequant(w): fp32 activations against int8-quantized weights.
// Each weight is dequantized on the fly as
//   (w[j][k] + 0.5) * rx[k] * ry[j] + mx[k] + my[j]
// (per-column scale/offset rx/mx, per-row scale/offset ry/my).
// One thread computes one output element y[i][k]; 2D grid over (B, M),
// bounds-checked since the grid may overshoot both dimensions.
__global__ void kernel_mm_seq_fp32i8(
    const int B, const int N, const int M,
    const float *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y, const int y_stride) {

    const int i = blockIdx.x * blockDim.x + threadIdx.x;  // row of x and y
    const int k = blockIdx.y * blockDim.y + threadIdx.y;  // column of w and y

    if (i < B && k < M) {
        float y_local = 0;
        for (int j = 0; j < N; ++j) {
            y_local += x[i * x_stride + j] * (
                (float(w[j * w_stride + k]) + 0.5f)
                * rx[k] * ry[j] + mx[k] + my[j]
            );
        }
        y[i * y_stride + k] = y_local;
    }
}
93
+
94
// Primary template is declaration-only; only the float and fp16
// specializations are defined (and used via wrapper.cpp's dtype dispatch).
template <typename F>
void cuda_mm8_seq(int B, int N, int M,
                  F *x, int x_stride,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  F *y, int y_stride);

// fp32 launcher: 1x128 thread blocks tiled over (B rows, M columns),
// grid sized with ceil-division so partial tiles are covered.
template <>
void cuda_mm8_seq<float>(int B, int N, int M,
                         float *x, int x_stride,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y, int y_stride) {
    dim3 blockSize(1, 128);
    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
        B, N, M, x, x_stride, w, w_stride,
        mx, rx, my, ry, y, y_stride);
}
115
+
116
// fp16 variant of kernel_mm_seq_fp32i8: same dequantization formula
// (w + 0.5) * rx[k] * ry[j] + mx[k] + my[j], but all half inputs are widened
// to float and the dot product is accumulated in fp32 to limit rounding
// error before the single final narrowing store.
__global__ void kernel_mm_seq_fp16i8(
    const int B, const int N, const int M,
    const __half *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    __half *__restrict__ const y, const int y_stride) {

    const int i = blockIdx.x * blockDim.x + threadIdx.x;  // row of x and y
    const int k = blockIdx.y * blockDim.y + threadIdx.y;  // column of w and y

    if (i < B && k < M) {
        float y_local = 0;   // fp32 accumulator
        for (int j = 0; j < N; ++j) {
            y_local += __half2float(x[i * x_stride + j]) * (
                (float(w[j * w_stride + k]) + 0.5f)
                * __half2float(rx[k]) * __half2float(ry[j])
                + __half2float(mx[k]) + __half2float(my[j])
            );
        }
        y[i * y_stride + k] = __float2half(y_local);
    }
}
141
+
142
// fp16 launcher: identical tiling to the float specialization; cast() bridges
// at::Half pointers to the __half pointers the kernel expects.
template <>
void cuda_mm8_seq<fp16>(int B, int N, int M,
                        fp16 *x, int x_stride,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        fp16 *y, int y_stride) {
    dim3 blockSize(1, 128);
    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_seq_fp16i8<<<gridSize, blockSize>>>(
        B, N, M, cast(x), x_stride, w, w_stride,
        cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride);
}
155
+
156
// mm8_one: single-row (vector x matrix) variant. The inner dimension N is
// split across MM8_ONE_JSPLIT blocks along grid.x; each block computes a
// partial dot product for its slice [j0, j1) and combines them with
// atomicAdd into y[k] — so y is an accumulator the caller is expected to
// zero beforehand (presumably; verify against the Python-side caller).
#define MM8_ONE_JSPLIT 24
#define MM8_ONE_TILE 1024

__global__ void kernel_mm_one_fp32i8(
    const int N, const int M,
    const float *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y) {

    const int k = blockIdx.y * blockDim.y + threadIdx.y;   // output column
    // This block's slice of the inner dimension, clamped to N.
    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));

    if (k < M) {
        float y_local = 0;
        for (int j = j0; j < j1; ++j) {
            // Same on-the-fly dequantization as the _seq kernels.
            y_local += x[j] * (
                (float(w[j * w_stride + k]) + 0.5f)
                * rx[k] * ry[j] + mx[k] + my[j]
            );
        }
        // One atomic per thread merges the partial sums across the j-split.
        atomicAdd(&y[k], y_local);
    }
}
184
+
185
// Primary template is declaration-only; float and fp16 specializations
// below are the only definitions. Note y is always fp32 — it is the
// atomicAdd accumulation target regardless of input dtype.
template <typename F>
void cuda_mm8_one(int N, int M,
                  F *x,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  float *y);

// fp32 launcher: grid.x carries the MM8_ONE_JSPLIT inner-dimension split,
// grid.y tiles the M output columns in MM8_ONE_TILE-thread blocks.
template <>
void cuda_mm8_one<float>(int N, int M,
                         float *x,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y) {
    dim3 blockSize(1, MM8_ONE_TILE);
    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
        N, M, x, w, w_stride,
        mx, rx, my, ry, y);
}
206
+
207
// fp16 variant of kernel_mm_one_fp32i8: half inputs widened to float,
// partial sums accumulated and atomically merged into the fp32 output y.
__global__ void kernel_mm_one_fp16i8(
    const int N, const int M,
    const __half *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    float *__restrict__ const y) {

    const int k = blockIdx.y * blockDim.y + threadIdx.y;   // output column
    // This block's slice of the inner dimension, clamped to N.
    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));

    if (k < M) {
        float y_local = 0;   // fp32 partial sum for this j-slice
        for (int j = j0; j < j1; ++j) {
            y_local += __half2float(x[j]) * (
                (float(w[j * w_stride + k]) + 0.5f)
                * __half2float(rx[k]) * __half2float(ry[j])
                + __half2float(mx[k]) + __half2float(my[j])
            );
        }
        atomicAdd(&y[k], y_local);
    }
}
233
+
234
// fp16 launcher: same grid layout as the float specialization; cast()
// bridges at::Half pointers to __half. Output y stays fp32.
template <>
void cuda_mm8_one<fp16>(int N, int M,
                        fp16 *x,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        float *y) {
    dim3 blockSize(1, MM8_ONE_TILE);
    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_one_fp16i8<<<gridSize, blockSize>>>(
        N, M, cast(x), w, w_stride,
        cast(mx), cast(rx), cast(my), cast(ry), y);
}
cuda/rwkv5.cu ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+ #include "ATen/ATen.h"
4
+ typedef at::BFloat16 bf16;
5
+ typedef at::Half fp16;
6
+ typedef float fp32;
7
+
8
// RWKV-5 forward kernel. Launch layout: one block per (batch, head), _N_
// threads per block, where _N_ is the head size supplied as a compile-time
// macro (blockDim.x == _N_ is assumed throughout). Thread i owns output
// channel i of its head and keeps its _N_-wide state row in registers;
// r/k/u/w for the whole head are staged through shared memory each timestep.
// The float4 casts in the inner loop require _N_ to be a multiple of 4.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;   // batch index
    const int h = blockIdx.x % H;   // head index
    const int i = threadIdx.x;      // channel within the head
    _w += h*_N_;
    _u += h*_N_;
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!! (offset ignores b; upstream callers run B == 1)

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];

    // Per-thread state row lives in registers for the whole sequence.
    float state[_N_];
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    // u and w are time-invariant in RWKV-5: load them once, cooperatively.
    __syncthreads();
    u[i] = float(_u[i]);
    w[i] = _w[i];
    __syncthreads();

    // t walks this (batch, head, channel)'s positions through the flattened
    // [B, T, C] layout, stepping one timestep (C elements) at a time.
    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        // Barriers bracket the shared-memory refresh of r/k for this step.
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        // Vectorized over 4 channels at once: y += r * (u * k*v + state),
        // then state = state * w + k*v.
        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
    // Write the carried state back for the next call.
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}
+
74
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
75
+ {
76
+ assert(H*_N_ == C);
77
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
78
+ }
79
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
80
+ {
81
+ assert(H*_N_ == C);
82
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
83
+ }
84
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
85
+ {
86
+ assert(H*_N_ == C);
87
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
88
+ }
cuda/rwkv5_op.cpp ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "ATen/ATen.h"
3
+ #include <c10/cuda/CUDAGuard.h>
4
+ typedef at::BFloat16 bf16;
5
+ typedef at::Half fp16;
6
+ typedef float fp32;
7
+
8
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
9
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
10
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
11
+
12
// Torch-facing wrappers: pin the CUDA device to the one holding `state`,
// then forward raw data pointers to the matching CUDA launcher.
// Expected dtypes: state and w fp32; r/k/v/u/y in the function's dtype
// (data_ptr<T>() throws on mismatch).
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

// Expose the wrappers both as a pybind11 extension module and as a
// TorchScript-visible op library under the "rwkv5" namespace.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16");
    m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16");
    m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32");
}
TORCH_LIBRARY(rwkv5, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
cuda/rwkv6.cu ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+ #include "ATen/ATen.h"
4
+ typedef at::BFloat16 bf16;
5
+ typedef at::Half fp16;
6
+ typedef float fp32;
7
+
8
// RWKV-6 forward kernel. Same launch layout as rwkv5 (one block per
// (batch, head), _N_ threads; _N_ = head size macro, multiple of 4 for the
// float4 casts). The key difference from RWKV-5: the decay w is
// time-varying — it is indexed by t and re-staged into shared memory every
// timestep, whereas u remains per-head constant.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;   // batch index
    const int h = blockIdx.x % H;   // head index
    const int i = threadIdx.x;      // channel within the head
    _u += h*_N_;
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!! (offset ignores b; upstream callers run B == 1)

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];

    // Per-thread state row lives in registers for the whole sequence.
    float state[_N_];
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    // u is time-invariant: load once, cooperatively.
    __syncthreads();
    u[i] = float(_u[i]);
    __syncthreads();

    // t walks this (batch, head, channel)'s positions through the flattened
    // [B, T, C] layout, stepping one timestep (C elements) at a time.
    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        // Barriers bracket the per-step shared-memory refresh of w/r/k.
        __syncthreads();
        w[i] = _w[t];            // time-varying decay (RWKV-6)
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        // Vectorized over 4 channels: y += r * (u * k*v + state),
        // then state = state * w + k*v.
        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
    // Write the carried state back for the next call.
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}
72
+
73
// Host launchers: one block per (batch, head), _N_ threads each. The
// host-side assert checks the C == H * _N_ contract but compiles out under
// NDEBUG. r/k/v/u/y vary by dtype; state and w are always fp32.
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
cuda/rwkv6_op.cpp ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "ATen/ATen.h"
3
+ #include <c10/cuda/CUDAGuard.h>
4
+ typedef at::BFloat16 bf16;
5
+ typedef at::Half fp16;
6
+ typedef float fp32;
7
+
8
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
9
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
10
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
11
+
12
// Torch-facing wrappers: pin the CUDA device to the one holding `state`,
// then forward raw data pointers to the matching CUDA launcher.
// Expected dtypes: state and w fp32; r/k/v/u/y in the function's dtype
// (data_ptr<T>() throws on mismatch).
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

// Expose the wrappers both as a pybind11 extension module and as a
// TorchScript-visible op library under the "rwkv6" namespace.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
    m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
    m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
}
TORCH_LIBRARY(rwkv6, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
cuda/wrapper.cpp ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "ATen/ATen.h"
3
+ #include <iostream>
4
+ #include <c10/cuda/CUDAGuard.h>
5
+
6
+ typedef at::Half fp16;
7
+
8
+ template <typename F>
9
+ void cuda_wkv_forward(int B, int T, int C,
10
+ float *w, float *u, F *k, F *v, F *y,
11
+ float *aa, float *bb, float *pp);
12
+ template <typename F>
13
+ void cuda_mm8_seq(int B, int N, int M,
14
+ F *x, int x_stride,
15
+ uint8_t *w, int w_stride,
16
+ F *mx, F *rx,
17
+ F *my, F *ry,
18
+ F *y, int y_stride);
19
+ template <typename F>
20
+ void cuda_mm8_one(int N, int M,
21
+ F *x,
22
+ uint8_t *w, int w_stride,
23
+ F *mx, F *rx,
24
+ F *my, F *ry,
25
+ float *y);
26
+
27
// Dispatch the WKV forward kernel by k's dtype (fp16 or fp32); w, u, aa, bb,
// pp are always fp32. The device guard pins the CUDA device to w's device.
// NOTE(review): the default branch uses assert(), which is compiled out under
// NDEBUG — an unsupported dtype would then fall through silently doing nothing.
void wkv_forward(int64_t B, int64_t T, int64_t C,
                 torch::Tensor &w, torch::Tensor &u,
                 torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
                 torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
    switch (k.scalar_type()) {
    case c10::ScalarType::Half:
        cuda_wkv_forward(B, T, C,
                         w.data_ptr<float>(), u.data_ptr<float>(),
                         k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
                         aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
        break;
    case c10::ScalarType::Float:
        cuda_wkv_forward(B, T, C,
                         w.data_ptr<float>(), u.data_ptr<float>(),
                         k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
                         aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
        break;
    default:
        assert(false && "Only FP16 and FP32 are currently supported");
    }
}
49
+
50
// Batched int8 matmul dispatcher (y = x @ dequant(w)), dtype-switched on x
// (fp16 or fp32). The stride asserts require the innermost dimension of each
// 2-D tensor, and the whole of each 1-D scale/offset vector, to be contiguous
// — only the outer (row) stride is forwarded to the kernel. Asserts compile
// out under NDEBUG.
void mm8_seq(int64_t B, int64_t N, int64_t M,
             torch::Tensor &x, torch::Tensor &w,
             torch::Tensor &mx, torch::Tensor &rx,
             torch::Tensor &my, torch::Tensor &ry,
             torch::Tensor &y) {
    assert(x.stride(1) == 1);
    assert(w.stride(1) == 1);
    assert(mx.stride(0) == 1 && rx.stride(0) == 1);
    assert(my.stride(0) == 1 && ry.stride(0) == 1);
    assert(y.stride(1) == 1);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
    switch (x.scalar_type()) {
    case c10::ScalarType::Half:
        cuda_mm8_seq(
            B, N, M,
            x.data_ptr<fp16>(), x.stride(0),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
            my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
            y.data_ptr<fp16>(), y.stride(0));
        break;
    case c10::ScalarType::Float:
        cuda_mm8_seq(
            B, N, M,
            x.data_ptr<float>(), x.stride(0),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<float>(), rx.data_ptr<float>(),
            my.data_ptr<float>(), ry.data_ptr<float>(),
            y.data_ptr<float>(), y.stride(0));
        break;
    default:
        assert(false && "Only FP16 and FP32 are currently supported");
    }
}
84
// Single-row int8 matvec dispatcher, dtype-switched on x (fp16 or fp32).
// y is always fp32: it is the atomicAdd accumulation target in the kernels,
// so the caller presumably zero-initializes it — verify against the
// Python-side caller. Stride asserts (contiguity of vectors and inner dims)
// compile out under NDEBUG.
void mm8_one(int64_t N, int64_t M,
             torch::Tensor &x, torch::Tensor &w,
             torch::Tensor &mx, torch::Tensor &rx,
             torch::Tensor &my, torch::Tensor &ry,
             torch::Tensor &y) {
    assert(x.stride(0) == 1);
    assert(w.stride(1) == 1);
    assert(mx.stride(0) == 1 && rx.stride(0) == 1);
    assert(my.stride(0) == 1 && ry.stride(0) == 1);
    assert(y.stride(0) == 1);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
    switch (x.scalar_type()) {
    case c10::ScalarType::Half:
        cuda_mm8_one(
            N, M,
            x.data_ptr<fp16>(),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
            my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
            y.data_ptr<float>());
        break;
    case c10::ScalarType::Float:
        cuda_mm8_one(
            N, M,
            x.data_ptr<float>(),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<float>(), rx.data_ptr<float>(),
            my.data_ptr<float>(), ry.data_ptr<float>(),
            y.data_ptr<float>());
        break;
    default:
        assert(false && "Only FP16 and FP32 are currently supported");
    }
}
118
+
119
using torch::Tensor;

// gemm_fp16_cublas lives in gemm_fp16_cublas.cpp; builds can exclude the
// cuBLAS path entirely by defining DISABLE_CUBLAS_GEMM.
#ifndef DISABLE_CUBLAS_GEMM
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
#endif

// Expose the ops both as a pybind11 extension module and as a
// TorchScript-visible op library under the "rwkv" namespace.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("wkv_forward", &wkv_forward, "wkv forward");
    m.def("mm8_seq", &mm8_seq, "mm8 seq");
    m.def("mm8_one", &mm8_one, "mm8 one");
#ifndef DISABLE_CUBLAS_GEMM
    m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
#endif
}

TORCH_LIBRARY(rwkv, m) {
    m.def("wkv_forward", wkv_forward);
    m.def("mm8_seq", mm8_seq);
    m.def("mm8_one", mm8_one);
#ifndef DISABLE_CUBLAS_GEMM
    m.def("gemm_fp16_cublas", gemm_fp16_cublas);
#endif
}
pyproject.toml CHANGED
@@ -8,11 +8,13 @@ dependencies = [
8
  "fastapi[standard]>=0.115.11",
9
  "huggingface-hub>=0.29.1",
10
  "loguru>=0.7.3",
 
11
  "numpy>=2.2.3",
12
  "pydantic>=2.10.6",
13
  "pydantic-settings>=2.8.1",
14
  "pynvml>=12.0.0",
15
  "rwkv==0.8.28",
 
16
  "snowflake-id>=1.0.2",
17
  ]
18
 
 
8
  "fastapi[standard]>=0.115.11",
9
  "huggingface-hub>=0.29.1",
10
  "loguru>=0.7.3",
11
+ "ninja>=1.11.1.3",
12
  "numpy>=2.2.3",
13
  "pydantic>=2.10.6",
14
  "pydantic-settings>=2.8.1",
15
  "pynvml>=12.0.0",
16
  "rwkv==0.8.28",
17
+ "setuptools>=75.8.2",
18
  "snowflake-id>=1.0.2",
19
  ]
20
 
uv.lock CHANGED
@@ -446,6 +446,30 @@ wheels = [
446
  { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263 },
447
  ]
448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
  [[package]]
450
  name = "numpy"
451
  version = "2.2.3"
@@ -915,11 +939,13 @@ dependencies = [
915
  { name = "fastapi", extra = ["standard"] },
916
  { name = "huggingface-hub" },
917
  { name = "loguru" },
 
918
  { name = "numpy" },
919
  { name = "pydantic" },
920
  { name = "pydantic-settings" },
921
  { name = "pynvml" },
922
  { name = "rwkv" },
 
923
  { name = "snowflake-id" },
924
  ]
925
 
@@ -940,11 +966,13 @@ requires-dist = [
940
  { name = "fastapi", extras = ["standard"], specifier = ">=0.115.11" },
941
  { name = "huggingface-hub", specifier = ">=0.29.1" },
942
  { name = "loguru", specifier = ">=0.7.3" },
 
943
  { name = "numpy", specifier = ">=2.2.3" },
944
  { name = "pydantic", specifier = ">=2.10.6" },
945
  { name = "pydantic-settings", specifier = ">=2.8.1" },
946
  { name = "pynvml", specifier = ">=12.0.0" },
947
  { name = "rwkv", specifier = "==0.8.28" },
 
948
  { name = "snowflake-id", specifier = ">=1.0.2" },
949
  { name = "torch", marker = "extra == 'cpu'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "rwkv-hf-space", extra = "cpu" } },
950
  { name = "torch", marker = "extra == 'cu113'", index = "https://download.pytorch.org/whl/cu113", conflict = { package = "rwkv-hf-space", extra = "cu113" } },
 
446
  { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263 },
447
  ]
448
 
449
+ [[package]]
450
+ name = "ninja"
451
+ version = "1.11.1.3"
452
+ source = { registry = "https://pypi.org/simple" }
453
+ sdist = { url = "https://files.pythonhosted.org/packages/bd/8f/21a2701f95b7d0d5137736561b3427ece0c4a1e085d4a223b92d16ab7d8b/ninja-1.11.1.3.tar.gz", hash = "sha256:edfa0d2e9d7ead1635b03e40a32ad56cc8f56798b6e2e9848d8300b174897076", size = 129532 }
454
+ wheels = [
455
+ { url = "https://files.pythonhosted.org/packages/ea/ba/0069cd4a83d68f7b0308be70e219b15d675e50c8ea28763a3f0373c45bfc/ninja-1.11.1.3-py3-none-macosx_10_9_universal2.whl", hash = "sha256:2b4879ea3f1169f3d855182c57dcc84d1b5048628c8b7be0d702b81882a37237", size = 279132 },
456
+ { url = "https://files.pythonhosted.org/packages/72/6b/3805be87df8417a0c7b21078c8045f2a1e59b34f371bfe4cb4fb0d6df7f2/ninja-1.11.1.3-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:bc3ebc8b2e47716149f3541742b5cd8e0b08f51013b825c05baca3e34854370d", size = 472101 },
457
+ { url = "https://files.pythonhosted.org/packages/6b/35/a8e38d54768e67324e365e2a41162be298f51ec93e6bd4b18d237d7250d8/ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a27e78ca71316c8654965ee94b286a98c83877bfebe2607db96897bbfe458af0", size = 422884 },
458
+ { url = "https://files.pythonhosted.org/packages/2f/99/7996457319e139c02697fb2aa28e42fe32bb0752cef492edc69d56a3552e/ninja-1.11.1.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2883ea46b3c5079074f56820f9989c6261fcc6fd873d914ee49010ecf283c3b2", size = 157046 },
459
+ { url = "https://files.pythonhosted.org/packages/6d/8b/93f38e5cddf76ccfdab70946515b554f25d2b4c95ef9b2f9cfbc43fa7cc1/ninja-1.11.1.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c4bdb9fd2d0c06501ae15abfd23407660e95659e384acd36e013b6dd7d8a8e4", size = 180014 },
460
+ { url = "https://files.pythonhosted.org/packages/7d/1d/713884d0fa3c972164f69d552e0701d30e2bf25eba9ef160bfb3dc69926a/ninja-1.11.1.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:114ed5c61c8474df6a69ab89097a20749b769e2c219a452cb2fadc49b0d581b0", size = 157098 },
461
+ { url = "https://files.pythonhosted.org/packages/c7/22/ecb0f70e77c9e22ee250aa717a608a142756833a34d43943d7d658ee0e56/ninja-1.11.1.3-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7fa2247fce98f683bc712562d82b22b8a0a5c000738a13147ca2d1b68c122298", size = 130089 },
462
+ { url = "https://files.pythonhosted.org/packages/ec/a6/3ee846c20ab6ad95b90c5c8703c76cb1f39cc8ce2d1ae468956e3b1b2581/ninja-1.11.1.3-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:a38c6c6c8032bed68b70c3b065d944c35e9f903342875d3a3218c1607987077c", size = 372508 },
463
+ { url = "https://files.pythonhosted.org/packages/95/0d/aa44abe4141f29148ce671ac8c92045878906b18691c6f87a29711c2ff1c/ninja-1.11.1.3-py3-none-musllinux_1_1_i686.whl", hash = "sha256:56ada5d33b8741d298836644042faddebc83ee669782d661e21563034beb5aba", size = 419369 },
464
+ { url = "https://files.pythonhosted.org/packages/f7/ec/48bf5105568ac9bd2016b701777bdd5000cc09a14ac837fef9f15e8d634e/ninja-1.11.1.3-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:53409151da081f3c198bb0bfc220a7f4e821e022c5b7d29719adda892ddb31bb", size = 420304 },
465
+ { url = "https://files.pythonhosted.org/packages/18/e5/69df63976cf971a03379899f8520a036c9dbab26330b37197512aed5b3df/ninja-1.11.1.3-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:1ad2112c2b0159ed7c4ae3731595191b1546ba62316fc40808edecd0306fefa3", size = 416056 },
466
+ { url = "https://files.pythonhosted.org/packages/6f/4f/bdb401af7ed0e24a3fef058e13a149f2de1ce4b176699076993615d55610/ninja-1.11.1.3-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:28aea3c1c280cba95b8608d50797169f3a34280e3e9a6379b6e340f0c9eaeeb0", size = 379725 },
467
+ { url = "https://files.pythonhosted.org/packages/bd/68/05e7863bf13128c61652eeb3ec7096c3d3a602f32f31752dbfb034e3fa07/ninja-1.11.1.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b6966f83064a88a51693073eea3decd47e08c3965241e09578ef7aa3a7738329", size = 434881 },
468
+ { url = "https://files.pythonhosted.org/packages/bd/ad/edc0d1efe77f29f45bbca2e1dab07ef597f61a88de6e4bccffc0aec2256c/ninja-1.11.1.3-py3-none-win32.whl", hash = "sha256:a4a3b71490557e18c010cbb26bd1ea9a0c32ee67e8f105e9731515b6e0af792e", size = 255988 },
469
+ { url = "https://files.pythonhosted.org/packages/03/93/09a9f7672b4f97438aca6217ac54212a63273f1cd3b46b731d0bb22c53e7/ninja-1.11.1.3-py3-none-win_amd64.whl", hash = "sha256:04d48d14ea7ba11951c156599ab526bdda575450797ff57c6fdf99b2554d09c7", size = 296502 },
470
+ { url = "https://files.pythonhosted.org/packages/d9/9d/0cc1e82849070ff3cbee69f326cb48a839407bcd15d8844443c30a5e7509/ninja-1.11.1.3-py3-none-win_arm64.whl", hash = "sha256:17978ad611d8ead578d83637f5ae80c2261b033db0b493a7ce94f88623f29e1b", size = 270571 },
471
+ ]
472
+
473
  [[package]]
474
  name = "numpy"
475
  version = "2.2.3"
 
939
  { name = "fastapi", extra = ["standard"] },
940
  { name = "huggingface-hub" },
941
  { name = "loguru" },
942
+ { name = "ninja" },
943
  { name = "numpy" },
944
  { name = "pydantic" },
945
  { name = "pydantic-settings" },
946
  { name = "pynvml" },
947
  { name = "rwkv" },
948
+ { name = "setuptools" },
949
  { name = "snowflake-id" },
950
  ]
951
 
 
966
  { name = "fastapi", extras = ["standard"], specifier = ">=0.115.11" },
967
  { name = "huggingface-hub", specifier = ">=0.29.1" },
968
  { name = "loguru", specifier = ">=0.7.3" },
969
+ { name = "ninja", specifier = ">=1.11.1.3" },
970
  { name = "numpy", specifier = ">=2.2.3" },
971
  { name = "pydantic", specifier = ">=2.10.6" },
972
  { name = "pydantic-settings", specifier = ">=2.8.1" },
973
  { name = "pynvml", specifier = ">=12.0.0" },
974
  { name = "rwkv", specifier = "==0.8.28" },
975
+ { name = "setuptools", specifier = ">=75.8.2" },
976
  { name = "snowflake-id", specifier = ">=1.0.2" },
977
  { name = "torch", marker = "extra == 'cpu'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "rwkv-hf-space", extra = "cpu" } },
978
  { name = "torch", marker = "extra == 'cu113'", index = "https://download.pytorch.org/whl/cu113", conflict = { package = "rwkv-hf-space", extra = "cu113" } },