EC2 Default User commited on
Commit
538355f
·
1 Parent(s): a63f4de
build.toml CHANGED
@@ -1,5 +1,6 @@
1
  [general]
2
  name = "rmsnorm_kernel"
 
3
 
4
  [torch]
5
  src = [
@@ -8,6 +9,7 @@ src = [
8
  ]
9
 
10
  [kernel.rmsnorm_kernel]
 
11
  src = [
12
  "rmsnorm_kernel/rmsnorm.cu",
13
  ]
 
1
  [general]
2
  name = "rmsnorm_kernel"
3
+ universal = false
4
 
5
  [torch]
6
  src = [
 
9
  ]
10
 
11
  [kernel.rmsnorm_kernel]
12
+ backend = "cuda"
13
  src = [
14
  "rmsnorm_kernel/rmsnorm.cu",
15
  ]
torch-ext/torch_binding.cpp CHANGED
@@ -4,8 +4,8 @@
4
  #include "torch_binding.h"
5
 
6
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7
- ops.def("rmsnorm_forward(Tensor input, Tensor gamma) -> ()");
8
  ops.impl("rmsnorm_forward", torch::kCUDA, &rmsnorm_forward);
9
  }
10
 
11
- REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
 
4
  #include "torch_binding.h"
5
 
6
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7
+ ops.def("rmsnorm_forward(Tensor input, Tensor gamma) -> Tensor");
8
  ops.impl("rmsnorm_forward", torch::kCUDA, &rmsnorm_forward);
9
  }
10
 
11
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h CHANGED
@@ -2,4 +2,4 @@
2
 
3
  #include <torch/torch.h>
4
 
5
- void rmsnorm_forward(torch::Tensor const &input, torch::Tensor const &gamma);
 
2
 
3
  #include <torch/torch.h>
4
 
5
+ torch::Tensor rmsnorm_forward(torch::Tensor const &input, torch::Tensor const &gamma);