Commit 1e759d6 · verified · 1 Parent(s): 3508bab
Committed by medmekk (HF Staff)

Upload custom kernels

build/torch-universal/liger_kernels/_ops.py CHANGED
@@ -1,8 +1,8 @@
 import torch
-ops = torch.ops._liger_kernels_20250505100655
+ops = torch.ops._liger_kernels_20250505101012

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_liger_kernels_20250505100655::{op_name}"
+    return f"_liger_kernels_20250505101012::{op_name}"
build/torch-universal/liger_kernels/cross_entropy.py CHANGED
@@ -6,10 +6,10 @@ import torch
 import triton
 import triton.language as tl

-from utils import compare_version
-from utils import element_mul_kernel
-from utils import is_hip
-from utils import infer_device
+from .utils import compare_version
+from .utils import element_mul_kernel
+from .utils import is_hip
+from .utils import infer_device

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
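
The remaining files all receive the same change: absolute `from utils import ...` imports become package-relative `from .utils import ...`. A short sketch of why this matters once the kernels ship as a package (the package name `liger_kernels` is taken from the paths above; the failure mode described is the assumption being illustrated):

import importlib.util

# Relative imports resolve against the containing package:
print(importlib.util.resolve_name(".utils", package="liger_kernels"))
# -> "liger_kernels.utils"

# The old absolute form is looked up as a top-level module on sys.path,
# which generally does not exist once the files live inside the package:
print(importlib.util.resolve_name("utils", package="liger_kernels"))
# -> "utils"
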
build/torch-universal/liger_kernels/dyt.py CHANGED
@@ -4,10 +4,10 @@ import torch
 import triton
 import triton.language as tl

-from utils import calculate_settings
-from utils import compare_version
-from utils import ensure_contiguous
-from utils import infer_device
+from .utils import calculate_settings
+from .utils import compare_version
+from .utils import ensure_contiguous
+from .utils import infer_device

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
build/torch-universal/liger_kernels/fused_linear_cross_entropy.py CHANGED
@@ -1,11 +1,11 @@
 import torch
 import triton

-from cross_entropy import liger_cross_entropy_kernel
-from utils import amp_custom_bwd
-from utils import amp_custom_fwd
-from utils import element_mul_kernel
-from utils import is_hip
+from .cross_entropy import liger_cross_entropy_kernel
+from .utils import amp_custom_bwd
+from .utils import amp_custom_fwd
+from .utils import element_mul_kernel
+from .utils import is_hip

 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
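
The block-size comment above is where `calculate_settings` from the repo's utils comes into play. A hedged reconstruction of the typical pattern (the body below is illustrative, assuming the helper follows the Triton LayerNorm-tutorial heuristic; it is not copied from this repository):

import triton

MAX_FUSED_SIZE = 65536 // 4  # stay far below TRITON_MAX_TENSOR_NUMEL (1048576)

def calculate_settings(n: int):
    # Round the row length up to a power of two, then pick warps by block size.
    block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n))
    num_warps = 4
    if block_size >= 32768:
        num_warps = 32
    elif block_size >= 8192:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    return block_size, num_warps

# Example: a row length of 4096 maps to block_size=4096, num_warps=8.
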
build/torch-universal/liger_kernels/geglu.py CHANGED
@@ -4,9 +4,9 @@ import torch
 import triton
 import triton.language as tl

-from utils import calculate_settings
-from utils import compare_version
-from utils import ensure_contiguous
+from .utils import calculate_settings
+from .utils import compare_version
+from .utils import ensure_contiguous

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
build/torch-universal/liger_kernels/group_norm.py CHANGED
@@ -4,8 +4,8 @@ import torch
 import triton
 import triton.language as tl

-from utils import compare_version
-from utils import ensure_contiguous
+from .utils import compare_version
+from .utils import ensure_contiguous

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
build/torch-universal/liger_kernels/jsd.py CHANGED
@@ -4,8 +4,8 @@ import torch
 import triton
 import triton.language as tl

-from utils import ensure_contiguous
-from utils import infer_device
+from .utils import ensure_contiguous
+from .utils import infer_device


 @triton.jit
build/torch-universal/liger_kernels/kl_div.py CHANGED
@@ -4,9 +4,9 @@ import torch
 import triton
 import triton.language as tl

-from utils import ensure_contiguous
-from utils import is_hip
-from utils import infer_device
+from .utils import ensure_contiguous
+from .utils import is_hip
+from .utils import infer_device


 def get_num_warps(BLOCK_SIZE):
build/torch-universal/liger_kernels/layer_norm.py CHANGED
@@ -5,9 +5,9 @@ import torch
 import triton
 import triton.language as tl

-from utils import calculate_settings
-from utils import compare_version
-from utils import ensure_contiguous
+from .utils import calculate_settings
+from .utils import compare_version
+from .utils import ensure_contiguous

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
build/torch-universal/liger_kernels/rms_norm.py CHANGED
@@ -17,10 +17,10 @@ import torch
 import triton
 import triton.language as tl

-from utils import calculate_settings
-from utils import compare_version
-from utils import ensure_contiguous
-from utils import torch_to_triton_dtype
+from .utils import calculate_settings
+from .utils import compare_version
+from .utils import ensure_contiguous
+from .utils import torch_to_triton_dtype

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
build/torch-universal/liger_kernels/swiglu.py CHANGED
@@ -2,8 +2,8 @@ import torch
 import triton
 import triton.language as tl

-from utils import calculate_settings
-from utils import ensure_contiguous
+from .utils import calculate_settings
+from .utils import ensure_contiguous


 @triton.jit
build/torch-universal/liger_kernels/tvd.py CHANGED
@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from utils import ensure_contiguous
+from .utils import ensure_contiguous

 MAX_FUSED_SIZE = 65536 // 4

torch-ext/liger_kernels/cross_entropy.py CHANGED
@@ -6,10 +6,10 @@ import torch
 import triton
 import triton.language as tl

-from utils import compare_version
-from utils import element_mul_kernel
-from utils import is_hip
-from utils import infer_device
+from .utils import compare_version
+from .utils import element_mul_kernel
+from .utils import is_hip
+from .utils import infer_device

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
torch-ext/liger_kernels/dyt.py CHANGED
@@ -4,10 +4,10 @@ import torch
 import triton
 import triton.language as tl

-from utils import calculate_settings
-from utils import compare_version
-from utils import ensure_contiguous
-from utils import infer_device
+from .utils import calculate_settings
+from .utils import compare_version
+from .utils import ensure_contiguous
+from .utils import infer_device

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
torch-ext/liger_kernels/fused_linear_cross_entropy.py CHANGED
@@ -1,11 +1,11 @@
 import torch
 import triton

-from cross_entropy import liger_cross_entropy_kernel
-from utils import amp_custom_bwd
-from utils import amp_custom_fwd
-from utils import element_mul_kernel
-from utils import is_hip
+from .cross_entropy import liger_cross_entropy_kernel
+from .utils import amp_custom_bwd
+from .utils import amp_custom_fwd
+from .utils import element_mul_kernel
+from .utils import is_hip

 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
torch-ext/liger_kernels/geglu.py CHANGED
@@ -4,9 +4,9 @@ import torch
 import triton
 import triton.language as tl

-from utils import calculate_settings
-from utils import compare_version
-from utils import ensure_contiguous
+from .utils import calculate_settings
+from .utils import compare_version
+from .utils import ensure_contiguous

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
torch-ext/liger_kernels/group_norm.py CHANGED
@@ -4,8 +4,8 @@ import torch
 import triton
 import triton.language as tl

-from utils import compare_version
-from utils import ensure_contiguous
+from .utils import compare_version
+from .utils import ensure_contiguous

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
torch-ext/liger_kernels/jsd.py CHANGED
@@ -4,8 +4,8 @@ import torch
 import triton
 import triton.language as tl

-from utils import ensure_contiguous
-from utils import infer_device
+from .utils import ensure_contiguous
+from .utils import infer_device


 @triton.jit
torch-ext/liger_kernels/kl_div.py CHANGED
@@ -4,9 +4,9 @@ import torch
 import triton
 import triton.language as tl

-from utils import ensure_contiguous
-from utils import is_hip
-from utils import infer_device
+from .utils import ensure_contiguous
+from .utils import is_hip
+from .utils import infer_device


 def get_num_warps(BLOCK_SIZE):
torch-ext/liger_kernels/layer_norm.py CHANGED
@@ -5,9 +5,9 @@ import torch
 import triton
 import triton.language as tl

-from utils import calculate_settings
-from utils import compare_version
-from utils import ensure_contiguous
+from .utils import calculate_settings
+from .utils import compare_version
+from .utils import ensure_contiguous

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
torch-ext/liger_kernels/rms_norm.py CHANGED
@@ -17,10 +17,10 @@ import torch
 import triton
 import triton.language as tl

-from utils import calculate_settings
-from utils import compare_version
-from utils import ensure_contiguous
-from utils import torch_to_triton_dtype
+from .utils import calculate_settings
+from .utils import compare_version
+from .utils import ensure_contiguous
+from .utils import torch_to_triton_dtype

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
torch-ext/liger_kernels/swiglu.py CHANGED
@@ -2,8 +2,8 @@ import torch
 import triton
 import triton.language as tl

-from utils import calculate_settings
-from utils import ensure_contiguous
+from .utils import calculate_settings
+from .utils import ensure_contiguous


 @triton.jit
torch-ext/liger_kernels/tvd.py CHANGED
@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from utils import ensure_contiguous
+from .utils import ensure_contiguous

 MAX_FUSED_SIZE = 65536 // 4
