MekkCyber committed
Commit 81d263b · 1 Parent(s): 2fafa6a

adding some activations

activation/activation_kernels.cu CHANGED
@@ -44,7 +44,7 @@ __device__ __forceinline__ T gelu_kernel(const T& x) {
   // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
   const float f = (float)x;
   constexpr float ALPHA = M_SQRT1_2;
-  return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
+  return (T)(f * 0.5f * (1.0f + erf(f * ALPHA)));
 }
 
 template <typename T>
@@ -183,6 +183,7 @@ __global__ void activation_kernel(
 
 namespace vllm {
 
+
 template <typename T>
 __device__ __forceinline__ T gelu_new_kernel(const T& x) {
   const float x3 = (float)(x * x * x);
@@ -223,3 +224,22 @@ void gelu_quick(torch::Tensor& out, // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
 }
+
+void gelu(torch::Tensor& out,  // [..., d]
+          torch::Tensor& input,
+          std::string approximation)  // [..., d]
+{
+  if (approximation == "none") {
+    LAUNCH_ACTIVATION_KERNEL(vllm::gelu_kernel);
+  } else if (approximation == "tanh") {
+    LAUNCH_ACTIVATION_KERNEL(vllm::gelu_tanh_kernel);
+  } else {
+    throw std::invalid_argument("Invalid approximation");
+  }
+}
+
+void silu(torch::Tensor& out,    // [..., d]
+          torch::Tensor& input)  // [..., d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::silu_kernel);
+}
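For reference, the two branches of the new `gelu` entry point correspond to the exact erf-based GELU (the `gelu_kernel` shown above) and the tanh approximation that `gelu_tanh_kernel` presumably implements, while `silu` computes x * sigmoid(x). Below is a minimal PyTorch sketch of the expected numerics; it is an illustration only, not part of the commit, and only the kernel names are taken from the diff.

```python
import math
import torch

def gelu_reference(x: torch.Tensor, approximation: str = "none") -> torch.Tensor:
    # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))), matching the erf-based kernel above.
    if approximation == "none":
        return 0.5 * x * (1.0 + torch.erf(x * (1.0 / math.sqrt(2.0))))
    # Standard tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    if approximation == "tanh":
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))
    raise ValueError("Invalid approximation")

def silu_reference(x: torch.Tensor) -> torch.Tensor:
    # SiLU: x * sigmoid(x).
    return x * torch.sigmoid(x)
```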
torch-ext/activation/__init__.py CHANGED
@@ -30,6 +30,15 @@ def fatrelu_and_mul(out: torch.Tensor, x: torch.Tensor, threshold: float = 0.0)
     return out
 
 
+def gelu(out: torch.Tensor, x: torch.Tensor, approximation: str = "none") -> None:
+    ops.gelu(out, x, approximation)
+    return out
+
+def silu(out: torch.Tensor, x: torch.Tensor) -> None:
+    ops.silu(out, x)
+    return out
+
+
 def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
     ops.gelu_fast(out, x)
     return out
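The new wrappers follow the same out-parameter convention as the existing helpers such as `gelu_fast`: the caller allocates `out` and the op writes into it (the wrapper also returns `out` for convenience). A hedged usage sketch, assuming the built extension is importable as `activation` and a CUDA device is available:

```python
import torch
import activation  # assumed import name for the built torch-ext package

x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)

# GELU with the tanh approximation, written in place into `out`.
out = torch.empty_like(x)
activation.gelu(out, x, approximation="tanh")

# SiLU.
out2 = torch.empty_like(x)
activation.silu(out2, x)

# Sanity check against PyTorch's reference implementations (loose fp16 tolerances).
torch.testing.assert_close(out, torch.nn.functional.gelu(x, approximate="tanh"), rtol=1e-3, atol=1e-3)
torch.testing.assert_close(out2, torch.nn.functional.silu(x), rtol=1e-3, atol=1e-3)
```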
torch-ext/activation/layers.py CHANGED
@@ -23,6 +23,39 @@ class SiluAndMul(nn.Module):
         ops.silu_and_mul(out, x)
         return out
 
+class Silu(nn.Module):
+    """An activation function for SiLU.
+
+    The function computes x -> silu(x).
+
+    Shapes:
+        x: (num_tokens, d) or (batch_size, seq_len, d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
+    can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor):
+        out = torch.empty_like(x)
+        ops.silu(out, x)
+        return out
+
+class Gelu(nn.Module):
+    """An activation function for GELU.
+
+    The function computes x -> gelu(x).
+
+    Shapes:
+        x: (num_tokens, d) or (batch_size, seq_len, d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
+    can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor, approximation: str = "none"):
+        out = torch.empty_like(x)
+        ops.gelu(out, x, approximation)
+        return out
 
 class MulAndSilu(nn.Module):
     """An activation function for SwiGLU.