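// Host-side kernel selection for the CUDA FlashAttention backend: the helpers
// below pick a concrete kernel template instantiation based on head size
// (Q->ne[0]), batch size (Q->ne[1]), the requested accumulator precision, and
// the device's compute capability.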
static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q   = dst->src[0];

    const int32_t precision = KQV->op_params[3];

    if (precision != GGML_PREC_DEFAULT) {
        // FP32 accumulators requested: use fewer columns per block for small
        // batches or very large heads.
        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
            constexpr int cols_per_block = 16;
            switch (Q->ne[0]) {
                case 64:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
                    break;
                case 80:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
                    break;
                case 96:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
                    break;
                case 112:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
                    break;
                case 128:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                    break;
                case 256:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
                    break;
                default:
                    GGML_ABORT("fatal error");
                    break;
            }
        } else {
            constexpr int cols_per_block = 32;
            switch (Q->ne[0]) {
                case 64:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
                    break;
                case 80:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
                    break;
                case 96:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
                    break;
                case 112:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
                    break;
                case 128:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                    break;
                // case 256:
                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                //     break;
                default:
                    GGML_ABORT("fatal error");
                    break;
            }
        }
        return;
    }

    // Default (FP16) precision: small batches with head sizes that are a
    // multiple of the warp size get 8 columns per block.
    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
        constexpr int cols_per_block = 8;
        switch (Q->ne[0]) {
            case 64:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
                break;
            case 96:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
                break;
            case 128:
                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
                break;
            case 256:
                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                break;
            default:
                GGML_ABORT("fatal error");
                break;
        }
        return;
    }

    if (Q->ne[1] <= 32) {
        constexpr int cols_per_block = 16;
        switch (Q->ne[0]) {
            case 64:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
                break;
            case 80:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
                break;
            case 96:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
                break;
            case 112:
                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
                break;
            case 128:
                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
                break;
            case 256:
                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                break;
            default:
                GGML_ABORT("fatal error");
                break;
        }
        return;
    }

    // Large batches: 32 columns per block.
    constexpr int cols_per_block = 32;
    switch (Q->ne[0]) {
        case 64:
            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
            break;
        case 80:
            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
            break;
        case 96:
            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
            break;
        case 112:
            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
            break;
        case 128:
            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
            break;
        case 256:
            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

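// The FATTN_VEC_*_CASE macros used below are defined elsewhere in this file.
// A minimal sketch of the assumed expansion (illustrative, not the verbatim
// definition): each case is a guarded early return that forwards to the vec
// kernel instantiated for one (head size, K type, V type) combination:
//
//   #define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
//       if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
//           ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
//           return;                                                             \
//       }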
static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    // Full matrix of supported (head size, K type, V type) combinations; only
    // compiled in when GGML_CUDA_FA_ALL_QUANTS is defined.
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_F16)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)

    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16,  GGML_TYPE_F16)
#else
    // Lean default build: only the most common combinations.
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16,  GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS

    on_no_fattn_vec_case(Q->ne[0]);
}

static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    // Same case matrix as the F16 variant above.
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_F16)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)

    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16,  GGML_TYPE_F16)
#else
    // Lean default build: only the most common combinations.
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16,  GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16,  GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS

    on_no_fattn_vec_case(Q->ne[0]);
}

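// Top-level entry point. Selection order: AMD devices always use the vec
// kernels; devices without fast FP16 arithmetic fall back to the FP32 vec/tile
// kernels; devices without FP16 tensor cores use the FP16 vec/tile kernels;
// batch size 1 prefers the vec kernels; everything else uses the WMMA kernel.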
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q   = dst->src[0];

    ggml_cuda_set_device(ctx.device);
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const int32_t precision = KQV->op_params[3];

    // On AMD the tile kernels perform poorly, use the vec kernel instead:
    if (cc >= CC_OFFSET_AMD) {
        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        }
        return;
    }

    if (!fast_fp16_available(cc)) {
        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
        }
        return;
    }

    if (!fp16_mma_available(cc)) {
        if (Q->ne[1] <= 8) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
        }
        return;
    }

    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
        if (precision == GGML_PREC_DEFAULT) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
            return;
        } else if (Q->ne[0] <= 128) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
            return;
        }
    }

    ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
}

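// A hypothetical call site, assuming the usual ggml-cuda op-dispatch shape
// (the surrounding switch is illustrative, not quoted from the backend):
//
//   case GGML_OP_FLASH_ATTN_EXT:
//       ggml_cuda_flash_attn_ext(ctx, dst);
//       break;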