// picocreator's picture
// Upload 13 files
// 33b8599 verified
#include <cuda_bf16.h>
#include <assert.h>
using bf = __nv_bfloat16;
__device__ inline float to_float(const bf & u) { return __bfloat162float(u); }
__device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); }
typedef bf * __restrict__ F_;
__global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, bf* y_, float* s_, float* sa_) {
constexpr int C = _C_;
int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
float state[C] = {0};
__shared__ float q[C], k[C], w[C], a[C], b[C];
for (int t = 0; t < T; t++) {
int ind = bb*T*H*C + t*H*C + hh * C + i;
__syncthreads();
q[i] = to_float(q_[ind]);
w[i] = __expf(-__expf(to_float(w_[ind])));
k[i] = to_float(k_[ind]);
a[i] = to_float(a_[ind]);
b[i] = to_float(b_[ind]);
__syncthreads();
float sa = 0;
#pragma unroll
for (int j = 0; j < C; j++) {
sa += a[j] * state[j];
}
sa_[ind] = sa;
float v = to_float(v_[ind]);
float y = 0;
#pragma unroll
for (int j = 0; j < C; j++) {
float& s = state[j];
s = s * w[j] + sa * b[j] + k[j] * v;
y += s * q[j];
}
y_[ind] = to_bf(y);
if ((t+1)%_CHUNK_LEN_ == 0) {
int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i;
#pragma unroll
for (int j = 0; j < C; j++) {
s_[base + j*C] = state[j];
}
}
}
}
__global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
constexpr int C = _C_;
int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
__shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
float qi, wi, ki, ai, bi, dyi;
for (int t = T-1; t >= 0; t--) {
int ind = bb*T*H*C + t*H*C + hh * C + i;
__syncthreads();
q[i] = qi = to_float(q_[ind]);
float wi_fac = -__expf(to_float(w_[ind]));
w[i] = wi = __expf(wi_fac);
k[i] = ki = to_float(k_[ind]);
a[i] = ai = to_float(a_[ind]);
b[i] = bi = to_float(b_[ind]);
v[i] = to_float(v_[ind]);
dy[i] = dyi = to_float(dy_[ind]);
sa[i] = sa_[ind];
__syncthreads();
if ((t+1)%_CHUNK_LEN_ == 0) {
int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C;
#pragma unroll
for (int j = 0; j < C; j++) {
stateT[j] = s_[base + j];
}
}
float dq = 0;
#pragma unroll
for (int j = 0; j < C; j++) {
dq += stateT[j]*dy[j];
}
dq_[ind] = to_bf(dq);
float iwi = 1.0f/wi;
#pragma unroll
for (int j = 0; j < C; j++) {
stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
dstate[j] += dyi * q[j];
dstateT[j] += qi * dy[j];
}
float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
#pragma unroll
for (int j = 0; j < C; j++) {
dw += dstateT[j]*stateT[j];
dk += dstateT[j]*v[j];
dv += dstate[j]*k[j];
dSb += dstate[j]*b[j];
db += dstateT[j]*sa[j];
}
dw_[ind] = to_bf(dw * wi * wi_fac);
dk_[ind] = to_bf(dk);
dv_[ind] = to_bf(dv);
db_[ind] = to_bf(db);
__syncthreads();
dSb_shared[i] = dSb;
__syncthreads();
float da = 0;
#pragma unroll
for (int j = 0; j < C; j++) {
da += stateT[j]*dSb_shared[j];
}
da_[ind] = to_bf(da);
#pragma unroll
for (int j = 0; j < C; j++) {
dstate[j] = dstate[j]*w[j] + dSb * a[j];
dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j];
}
}
}
void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa) {
forward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,y,s,sa);
}
void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da) {
assert(T%_CHUNK_LEN_ == 0);
backward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,dy,s,sa,dw,dq,dk,dv,dz,da);
}