MekkCyber
commited on
Commit
·
81d263b
1
Parent(s):
2fafa6a
adding some activations
Browse files
activation/activation_kernels.cu
CHANGED
|
@@ -44,7 +44,7 @@ __device__ __forceinline__ T gelu_kernel(const T& x) {
|
|
| 44 |
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
|
| 45 |
const float f = (float)x;
|
| 46 |
constexpr float ALPHA = M_SQRT1_2;
|
| 47 |
-
return (T)(f * 0.5f * (1.0f +
|
| 48 |
}
|
| 49 |
|
| 50 |
template <typename T>
|
|
@@ -183,6 +183,7 @@ __global__ void activation_kernel(
|
|
| 183 |
|
| 184 |
namespace vllm {
|
| 185 |
|
|
|
|
| 186 |
template <typename T>
|
| 187 |
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
|
| 188 |
const float x3 = (float)(x * x * x);
|
|
@@ -223,3 +224,22 @@ void gelu_quick(torch::Tensor& out, // [..., d]
|
|
| 223 |
{
|
| 224 |
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
|
| 225 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
|
| 45 |
const float f = (float)x;
|
| 46 |
constexpr float ALPHA = M_SQRT1_2;
|
| 47 |
+
return (T)(f * 0.5f * (1.0f + erf(f * ALPHA)));
|
| 48 |
}
|
| 49 |
|
| 50 |
template <typename T>
|
|
|
|
| 183 |
|
| 184 |
namespace vllm {
|
| 185 |
|
| 186 |
+
|
| 187 |
template <typename T>
|
| 188 |
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
|
| 189 |
const float x3 = (float)(x * x * x);
|
|
|
|
| 224 |
{
|
| 225 |
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
|
| 226 |
}
|
| 227 |
+
|
| 228 |
+
void gelu(torch::Tensor& out, // [..., d]
|
| 229 |
+
torch::Tensor& input,
|
| 230 |
+
std::string approximation) // [..., d]
|
| 231 |
+
{
|
| 232 |
+
if (approximation == "none") {
|
| 233 |
+
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_kernel);
|
| 234 |
+
} else if (approximation == "tanh") {
|
| 235 |
+
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_tanh_kernel);
|
| 236 |
+
} else {
|
| 237 |
+
throw std::invalid_argument("Invalid approximation");
|
| 238 |
+
}
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
void silu(torch::Tensor& out, // [..., d]
|
| 242 |
+
torch::Tensor& input) // [..., d]
|
| 243 |
+
{
|
| 244 |
+
LAUNCH_ACTIVATION_KERNEL(vllm::silu_kernel);
|
| 245 |
+
}
|
torch-ext/activation/__init__.py
CHANGED
|
@@ -30,6 +30,15 @@ def fatrelu_and_mul(out: torch.Tensor, x: torch.Tensor, threshold: float = 0.0)
|
|
| 30 |
return out
|
| 31 |
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
|
| 34 |
ops.gelu_fast(out, x)
|
| 35 |
return out
|
|
|
|
| 30 |
return out
|
| 31 |
|
| 32 |
|
| 33 |
+
def gelu(out: torch.Tensor, x: torch.Tensor, approximation: str = "none") -> None:
|
| 34 |
+
ops.gelu(out, x, approximation)
|
| 35 |
+
return out
|
| 36 |
+
|
| 37 |
+
def silu(out: torch.Tensor, x: torch.Tensor) -> None:
|
| 38 |
+
ops.silu(out, x)
|
| 39 |
+
return out
|
| 40 |
+
|
| 41 |
+
|
| 42 |
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
|
| 43 |
ops.gelu_fast(out, x)
|
| 44 |
return out
|
torch-ext/activation/layers.py
CHANGED
|
@@ -23,6 +23,39 @@ class SiluAndMul(nn.Module):
|
|
| 23 |
ops.silu_and_mul(out, x)
|
| 24 |
return out
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
class MulAndSilu(nn.Module):
|
| 28 |
"""An activation function for SwiGLU.
|
|
|
|
| 23 |
ops.silu_and_mul(out, x)
|
| 24 |
return out
|
| 25 |
|
| 26 |
+
class Silu(nn.Module):
|
| 27 |
+
"""An activation function for SiLU.
|
| 28 |
+
|
| 29 |
+
The function computes x -> silu(x).
|
| 30 |
+
|
| 31 |
+
Shapes:
|
| 32 |
+
x: (num_tokens, d) or (batch_size, seq_len, d)
|
| 33 |
+
return: (num_tokens, d) or (batch_size, seq_len, d)
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
can_torch_compile: bool = True
|
| 37 |
+
|
| 38 |
+
def forward(self, x: torch.Tensor):
|
| 39 |
+
out = torch.empty_like(x)
|
| 40 |
+
ops.silu(out, x)
|
| 41 |
+
return out
|
| 42 |
+
|
| 43 |
+
class Gelu(nn.Module):
|
| 44 |
+
"""An activation function for GELU.
|
| 45 |
+
|
| 46 |
+
The function computes x -> gelu(x).
|
| 47 |
+
|
| 48 |
+
Shapes:
|
| 49 |
+
x: (num_tokens, d) or (batch_size, seq_len, d)
|
| 50 |
+
return: (num_tokens, d) or (batch_size, seq_len, d)
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
can_torch_compile: bool = True
|
| 54 |
+
|
| 55 |
+
def forward(self, x: torch.Tensor, approximation: str = "none"):
|
| 56 |
+
out = torch.empty_like(x)
|
| 57 |
+
ops.gelu(out, x, approximation)
|
| 58 |
+
return out
|
| 59 |
|
| 60 |
class MulAndSilu(nn.Module):
|
| 61 |
"""An activation function for SwiGLU.
|