EC2 Default User
committed on
Commit
·
538355f
1
Parent(s):
a63f4de
updates
Browse files
- build.toml +2 -0
- torch-ext/torch_binding.cpp +2 -2
- torch-ext/torch_binding.h +1 -1
build.toml
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
[general]
|
| 2 |
name = "rmsnorm_kernel"
|
|
|
|
| 3 |
|
| 4 |
[torch]
|
| 5 |
src = [
|
|
@@ -8,6 +9,7 @@ src = [
|
|
| 8 |
]
|
| 9 |
|
| 10 |
[kernel.rmsnorm_kernel]
|
|
|
|
| 11 |
src = [
|
| 12 |
"rmsnorm_kernel/rmsnorm.cu",
|
| 13 |
]
|
|
|
|
| 1 |
[general]
|
| 2 |
name = "rmsnorm_kernel"
|
| 3 |
+
universal = false
|
| 4 |
|
| 5 |
[torch]
|
| 6 |
src = [
|
|
|
|
| 9 |
]
|
| 10 |
|
| 11 |
[kernel.rmsnorm_kernel]
|
| 12 |
+
backend = "cuda"
|
| 13 |
src = [
|
| 14 |
"rmsnorm_kernel/rmsnorm.cu",
|
| 15 |
]
|
torch-ext/torch_binding.cpp
CHANGED
|
@@ -4,8 +4,8 @@
|
|
| 4 |
#include "torch_binding.h"
|
| 5 |
|
| 6 |
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
| 7 |
-
ops.def("rmsnorm_forward(Tensor input, Tensor gamma) ->
|
| 8 |
ops.impl("rmsnorm_forward", torch::kCUDA, &rmsnorm_forward);
|
| 9 |
}
|
| 10 |
|
| 11 |
-
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|
|
| 4 |
#include "torch_binding.h"
|
| 5 |
|
| 6 |
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
| 7 |
+
ops.def("rmsnorm_forward(Tensor input, Tensor gamma) -> Tensor");
|
| 8 |
ops.impl("rmsnorm_forward", torch::kCUDA, &rmsnorm_forward);
|
| 9 |
}
|
| 10 |
|
| 11 |
+
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
torch-ext/torch_binding.h
CHANGED
|
@@ -2,4 +2,4 @@
|
|
| 2 |
|
| 3 |
#include <torch/torch.h>
|
| 4 |
|
| 5 |
-
|
|
|
|
| 2 |
|
| 3 |
#include <torch/torch.h>
|
| 4 |
|
| 5 |
+
torch::Tensor rmsnorm_forward(torch::Tensor const &input, torch::Tensor const &gamma);
|