update kernels
kernels/cache_autogptq_cuda_256.cpp → cache_autogptq_cuda_256.cpp
RENAMED (file without changes)

kernels/cache_autogptq_cuda_kernel_256.cu → cache_autogptq_cuda_kernel_256.cu
RENAMED (file without changes)
kernels/cpp_kernels.py → cpp_kernels.py
RENAMED

@@ -50,6 +50,6 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
 
 extra_flags = []
 
-cache_autogptq_cuda_256_sources = ["./kernels/cache_autogptq_cuda_256.cpp",
-    "./kernels/cache_autogptq_cuda_kernel_256.cu"]
+cache_autogptq_cuda_256_sources = ["./cache_autogptq_cuda_256.cpp",
+    "./cache_autogptq_cuda_kernel_256.cu"]
 cache_autogptq_cuda_256 = _cpp_extention_load_helper("cache_autogptq_cuda_256", cache_autogptq_cuda_256_sources, extra_flags)
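For context on what this file does: `_cpp_extention_load_helper` (spelling as in the repo) appears to be a thin wrapper around PyTorch's JIT extension loader, and the change above only moves where the CUDA sources are expected to live, from `kernels/` to the repo root. Below is a minimal sketch, assuming the helper wraps `torch.utils.cpp_extension.load`; the compiler flags shown are illustrative and not taken from the file.

# Hypothetical sketch of a cpp_kernels.py-style loader; helper name matches the
# diff, but the flag values are placeholders, not the repo's exact contents.
from torch.utils import cpp_extension

def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
    # JIT-compiles the listed .cpp/.cu sources and returns them as a Python module.
    return cpp_extension.load(
        name=name,
        sources=sources,
        extra_cflags=["-O3"],                 # illustrative host compiler flag
        extra_cuda_cflags=extra_cuda_flags,   # e.g. ["-O3", "--use_fast_math"]
        verbose=False,
    )

extra_flags = []

# After the rename, the sources sit next to cpp_kernels.py at the repo root.
cache_autogptq_cuda_256_sources = [
    "./cache_autogptq_cuda_256.cpp",
    "./cache_autogptq_cuda_kernel_256.cu",
]
cache_autogptq_cuda_256 = _cpp_extention_load_helper(
    "cache_autogptq_cuda_256", cache_autogptq_cuda_256_sources, extra_flags
)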
modeling_qwen.py
CHANGED

@@ -32,11 +32,6 @@ except ImportError:
     rearrange = None
 from torch import nn
 
-try:
-    from kernels.cpp_kernels import cache_autogptq_cuda_256
-except ImportError:
-    cache_autogptq_cuda_256 = None
-
 SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7

@@ -294,14 +289,21 @@ class QWenAttention(nn.Module):
         self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
         self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
 
+        if config.use_cache_quantization and config.use_cache_kernel:
+            from .cpp_kernels import cache_autogptq_cuda_256
+            try:
+                self.cache_kernels = cache_autogptq_cuda_256
+            except ImportError:
+                self.cache_kernels = None
+
     def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
         device = query.device
         if self.use_cache_quantization:
             qk, qk_scale, qk_zero = key
-            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
+            if self.use_cache_kernel and self.cache_kernels is not None:
                 shape = query.shape[:-1] + (qk.shape[-2],)
                 attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
-                cache_autogptq_cuda_256.vecquant8matmul_batched_faster_old(
+                self.cache_kernels.vecquant8matmul_batched_faster_old(
                     query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
                     qk.transpose(-1, -2).contiguous(),
                     attn_weights,

@@ -353,10 +355,10 @@
 
         if self.use_cache_quantization:
             qv, qv_scale, qv_zero = value
-            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
+            if self.use_cache_kernel and self.cache_kernels is not None:
                 shape = attn_weights.shape[:-1] + (query.shape[-1],)
                 attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
-                cache_autogptq_cuda_256.vecquant8matmul_batched_column_compression_faster_old(
+                self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
                     attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
                     qv.contiguous(), # dtype: int32
                     attn_output,

@@ -1022,15 +1024,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if config.use_flash_attn:
             _import_flash_attn()
 
-
-        if hasattr(config, 'use_cache_quantization') and config.use_cache_quantization:
-            config.use_flash_attn = False
-        if hasattr(config, 'use_cache_kernel') and config.use_cache_kernel:
-            try:
-                from kernels.cpp_kernels import cache_autogptq_cuda_256
-            except ImportError:
-                cache_autogptq_cuda_256 = None
-
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
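Taken together, the modeling_qwen.py changes make the quantized-KV-cache kernel a per-module attribute (`self.cache_kernels`) imported lazily from `.cpp_kernels` at the repo root, instead of a module-level import from `kernels.cpp_kernels`. A rough usage sketch follows, assuming the usual `transformers` remote-code loading path and that `use_cache_quantization` / `use_cache_kernel` can be passed as config overrides to `from_pretrained`; the checkpoint name is illustrative.

# Illustrative only: load a Qwen checkpoint with the quantized KV cache and the
# fused CUDA kernel enabled via the config flags referenced in the diff above.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen-7B-Chat-Int4"  # placeholder checkpoint name

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,        # so the repo's modeling_qwen.py is actually used
    use_cache_quantization=True,   # store the KV cache quantized to uint8
    use_cache_kernel=True,         # dequantize/matmul with cache_autogptq_cuda_256
).eval()

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))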