#include <torch/extension.h>
#include <vector>
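
// Forward declarations of the CUDA implementations; these are provided
// by the CUDA translation unit when the extension is built with CUDA.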
#ifdef WITH_CUDA
torch::Tensor bitlinear_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
);

torch::Tensor multi_ternary_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
);
#endif
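
// CPU reference implementation of the BitLinear forward pass:
//   y = (x @ W_ternary^T) * gamma (+ bias)
// W_ternary holds values in {-1, 0, +1} with shape [out_features, in_features];
// gamma is the per-output-channel scale. Used when the input lives on the CPU.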
torch::Tensor bitlinear_cpu_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
) {
    // Collapse all leading dimensions into a single batch dimension.
    auto x_shape = x.sizes().vec();
    int64_t batch_size = 1;
    for (size_t i = 0; i < x_shape.size() - 1; i++) {
        batch_size *= x_shape[i];
    }
    int64_t in_features = x_shape.back();
    int64_t out_features = W_ternary.size(0);

    // reshape() rather than view() so non-contiguous inputs also work.
    auto x_2d = x.reshape({batch_size, in_features});

    // Ternary matmul against the transposed weight matrix.
    auto output = torch::matmul(x_2d, W_ternary.t());

    // Apply the per-output-channel scale.
    output = output * gamma.unsqueeze(0);

    if (bias.has_value() && bias.value().defined()) {
        output = output + bias.value().unsqueeze(0);
    }

    // Restore the original leading dimensions.
    std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
    out_shape.push_back(out_features);
    output = output.view(out_shape);

    return output;
}
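
// CPU reference implementation of the multi-ternary forward pass:
//   y = sum_i (x @ W_ternary[i]^T) * gammas[i] (+ bias)
// W_ternary is [k, out_features, in_features] and gammas is [k, out_features]:
// k ternary components, each with its own per-channel scale.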
torch::Tensor multi_ternary_cpu_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
) {
    int64_t k = W_ternary.size(0);
    int64_t out_features = W_ternary.size(1);
    int64_t in_features = W_ternary.size(2);

    // Collapse all leading dimensions into a single batch dimension.
    auto x_shape = x.sizes().vec();
    int64_t batch_size = 1;
    for (size_t i = 0; i < x_shape.size() - 1; i++) {
        batch_size *= x_shape[i];
    }

    // reshape() rather than view() so non-contiguous inputs also work.
    auto x_2d = x.reshape({batch_size, in_features});

    auto output = torch::zeros({batch_size, out_features}, x.options());

    // Accumulate the k scaled ternary components.
    for (int64_t i = 0; i < k; i++) {
        auto W_i = W_ternary[i];
        auto gamma_i = gammas[i];

        auto component = torch::matmul(x_2d, W_i.t());
        component = component * gamma_i.unsqueeze(0);

        output = output + component;
    }

    if (bias.has_value() && bias.value().defined()) {
        output = output + bias.value().unsqueeze(0);
    }

    // Restore the original leading dimensions.
    std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
    out_shape.push_back(out_features);
    output = output.view(out_shape);

    return output;
}
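
// Device dispatcher: validates shapes, then routes to the CUDA kernel when
// the input is on a GPU (and the extension was built with CUDA), otherwise
// to the CPU reference path.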
torch::Tensor bitlinear_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
) {
    TORCH_CHECK(x.dim() >= 2, "Input must have at least 2 dimensions");
    TORCH_CHECK(W_ternary.dim() == 2, "W_ternary must be 2D");
    TORCH_CHECK(gamma.dim() == 1 || gamma.dim() == 2, "gamma must be 1D or 2D");
    TORCH_CHECK(x.size(-1) == W_ternary.size(1),
                "Input feature dimension must match W_ternary.size(1)");

    if (x.is_cuda()) {
#ifdef WITH_CUDA
        return bitlinear_cuda_forward(x, W_ternary, gamma, bias);
#else
        AT_ERROR("BitLinear CUDA kernels not compiled. Rebuild with CUDA support.");
#endif
    } else {
        return bitlinear_cpu_forward(x, W_ternary, gamma, bias);
    }
}
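
// Device dispatcher for the multi-ternary forward pass.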
torch::Tensor multi_ternary_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
) {
    TORCH_CHECK(x.dim() >= 2, "Input must have at least 2 dimensions");
    TORCH_CHECK(W_ternary.dim() == 3, "W_ternary must be 3D [k, out_features, in_features]");
    TORCH_CHECK(gammas.dim() == 2, "gammas must be 2D [k, out_features]");
    TORCH_CHECK(gammas.size(0) == W_ternary.size(0),
                "gammas and W_ternary must agree on k");

    if (x.is_cuda()) {
#ifdef WITH_CUDA
        return multi_ternary_cuda_forward(x, W_ternary, gammas, bias);
#else
        AT_ERROR("Multi-ternary CUDA kernels not compiled. Rebuild with CUDA support.");
#endif
    } else {
        return multi_ternary_cpu_forward(x, W_ternary, gammas, bias);
    }
}
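
// Pack a ternary weight tensor into base-3 bytes: each weight in {-1, 0, +1}
// is mapped to a trit in {0, 1, 2}, and 5 trits are packed per byte
// (3^5 = 243 fits in one byte), i.e. 5 weights per byte. Trailing positions
// in the last byte are padded with trit 1 (weight 0).
//
// Example: trits (-1, 0, +1, +1, -1) map to (0, 1, 2, 2, 0) and pack to
//   0*1 + 1*3 + 2*9 + 2*27 + 0*81 = 75.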
torch::Tensor pack_ternary_base3_cpp(torch::Tensor W_ternary) {
    auto flat = W_ternary.flatten().to(torch::kCPU).to(torch::kInt8);
    int64_t numel = flat.numel();

    // Map {-1, 0, +1} -> {0, 1, 2}.
    auto mapped = (flat + 1).to(torch::kUInt8);

    // Ceiling division: 5 trits per byte.
    int64_t packed_size = (numel + 4) / 5;
    auto packed = torch::zeros({packed_size}, torch::dtype(torch::kUInt8).device(torch::kCPU));

    auto mapped_ptr = mapped.data_ptr<uint8_t>();
    auto packed_ptr = packed.data_ptr<uint8_t>();

    // Base-3 digit weights for the 5 trit positions within a byte.
    const uint8_t powers[5] = {1, 3, 9, 27, 81};

    for (int64_t i = 0; i < packed_size; i++) {
        int64_t base_idx = i * 5;
        uint8_t packed_val = 0;

        for (int j = 0; j < 5; j++) {
            int64_t idx = base_idx + j;
            if (idx < numel) {
                packed_val += mapped_ptr[idx] * powers[j];
            } else {
                // Pad the final byte with trit 1 (i.e. weight 0).
                packed_val += 1 * powers[j];
            }
        }
        packed_ptr[i] = packed_val;
    }

    return packed;
}
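
// Inverse of pack_ternary_base3_cpp: decode 5 trits from each byte by
// repeated modulo/division by 3, map {0, 1, 2} back to {-1, 0, +1}, and
// reshape to the original weight shape as float32.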
torch::Tensor unpack_ternary_base3_cpp(
    torch::Tensor packed,
    std::vector<int64_t> original_shape
) {
    int64_t numel = 1;
    for (auto dim : original_shape) {
        numel *= dim;
    }

    auto packed_flat = packed.flatten().to(torch::kCPU).to(torch::kUInt8);
    int64_t packed_size = packed_flat.numel();

    auto unpacked = torch::zeros({numel}, torch::dtype(torch::kInt8).device(torch::kCPU));

    auto packed_ptr = packed_flat.data_ptr<uint8_t>();
    auto unpacked_ptr = unpacked.data_ptr<int8_t>();

    int64_t out_idx = 0;
    for (int64_t i = 0; i < packed_size && out_idx < numel; i++) {
        uint8_t packed_val = packed_ptr[i];

        // Extract trits least-significant first, matching the packing order.
        for (int j = 0; j < 5 && out_idx < numel; j++) {
            uint8_t val = packed_val % 3;
            packed_val /= 3;

            // Map {0, 1, 2} back to {-1, 0, +1}.
            unpacked_ptr[out_idx] = static_cast<int8_t>(val) - 1;
            out_idx++;
        }
    }

    return unpacked.view(original_shape).to(torch::kFloat32);
}
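
// Python bindings. `bias` defaults to None on the Python side.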
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &bitlinear_forward, "BitLinear forward (CPU/CUDA)",
          py::arg("x"),
          py::arg("W_ternary"),
          py::arg("gamma"),
          py::arg("bias") = py::none());

    m.def("multi_ternary_forward", &multi_ternary_forward,
          "Multi-ternary linear forward (CPU/CUDA)",
          py::arg("x"),
          py::arg("W_ternary"),
          py::arg("gammas"),
          py::arg("bias") = py::none());

    m.def("pack_ternary_base3", &pack_ternary_base3_cpp,
          "Pack ternary weights to base-3 (CPU)",
          py::arg("W_ternary"));

    m.def("unpack_ternary_base3", &unpack_ternary_base3_cpp,
          "Unpack base-3 ternary weights (CPU)",
          py::arg("packed"),
          py::arg("original_shape"));
}
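
// Sketch of usage from Python (hypothetical module name `bitlinear_cpp`;
// the actual name is set by the build script):
//
//   import torch
//   import bitlinear_cpp
//
//   W = torch.randint(-1, 2, (64, 128)).float()   # ternary weights in {-1, 0, 1}
//   gamma = torch.ones(64)                        # per-channel scales
//   y = bitlinear_cpp.forward(torch.randn(8, 128), W, gamma)
//
//   packed = bitlinear_cpp.pack_ternary_base3(W)
//   W_restored = bitlinear_cpp.unpack_ternary_base3(packed, list(W.shape))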