quantization / LSAQ_CoreCode /lsaq_quant.py
chen459664's picture
Add files using upload-large-folder tool
02bc7b8 verified
Raw
History Blame Contribute Delete
5.29 kB
import os
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import tqdm
import json
import math
import torch.nn.functional as F
from datasets import load_dataset
@torch.no_grad()
def quantize_weight_per_channel_absmax(w, n_bits=8):
    """Fake-quantize a weight matrix in place using one scale per row.

    Every row (last dimension reduced) is scaled so that its largest
    magnitude maps onto the top signed integer level for ``n_bits``, rounded
    to the nearest level and scaled back. The tensor keeps its dtype but
    afterwards only holds values that lie on the quantization grid.

    Args:
        w: Floating-point weight tensor, expected as
            (out_features, in_features). It is modified in place.
        n_bits: Signed bit width of the quantization grid; must be >= 2.

    Returns:
        The very same tensor object ``w``, now holding the quantized values.

    Raises:
        ValueError: If ``n_bits`` is smaller than 2.
    """
    if n_bits < 2:
        # q_max would be zero or negative here, which turns the scale into
        # inf / a negative number and silently corrupts the weights.
        raise ValueError(f"n_bits must be >= 2, got {n_bits}")
    q_max = 2 ** (n_bits - 1) - 1
    # Per-row absolute maximum, kept as a column so it broadcasts over w.
    scales = w.abs().max(dim=-1, keepdim=True)[0]
    # The lower clamp protects all-zero rows from a division by zero.
    scales.clamp_(min=1e-5).div_(q_max)
    w.div_(scales).round_().mul_(scales)
    return w
@torch.no_grad()
def quantize_weight_per_tensor_absmax(w, n_bits=8):
    """Fake-quantize a weight tensor in place using a single global scale.

    The whole tensor is scaled so that its largest magnitude maps onto the
    top signed integer level for ``n_bits``, rounded to the nearest level and
    scaled back. The tensor keeps its dtype but afterwards only holds values
    that lie on the quantization grid.

    Args:
        w: Floating-point weight tensor, expected as
            (out_features, in_features). It is modified in place.
        n_bits: Signed bit width of the quantization grid; must be >= 2.

    Returns:
        The very same tensor object ``w``, now holding the quantized values.

    Raises:
        ValueError: If ``n_bits`` is smaller than 2.
    """
    if n_bits < 2:
        # q_max would be zero or negative here, which turns the scale into
        # inf / a negative number and silently corrupts the weights.
        raise ValueError(f"n_bits must be >= 2, got {n_bits}")
    q_max = 2 ** (n_bits - 1) - 1
    # One scalar scale shared by every element of the tensor.
    scales = w.abs().max()
    # The lower clamp protects an all-zero tensor from a division by zero.
    scales.clamp_(min=1e-5).div_(q_max)
    w.div_(scales).round_().mul_(scales)
    return w
class W8A16Linear(nn.Module):
    """Drop-in replacement for ``nn.Linear`` that stores fake-quantized weights.

    The layer performs an ordinary linear transform in ``forward``; the
    quantization itself happens once, in :meth:`from_float`, which rounds the
    source layer's weights onto a low-bit grid before attaching them here.
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        quantize_output=False,
    ):
        """Create the layer with placeholder tensors.

        Args:
            in_features: Size of the last dimension of the input.
            out_features: Size of the last dimension of the output.
            bias: Whether to allocate a bias buffer.
            quantize_output: Accepted for interface compatibility; it is not
                used anywhere in this class.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Initialised here so that __repr__ also works on instances that were
        # not produced by from_float (it used to raise AttributeError).
        self.weight_quant_name = None
        # Placeholder weight; from_float replaces it with the real values.
        self.register_buffer(
            "weight",
            torch.randn(
                self.out_features,
                self.in_features,
                dtype=torch.float16,
                requires_grad=False,
            ),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros(
                    (1, self.out_features), dtype=torch.float16, requires_grad=False
                ),
            )
        else:
            self.register_buffer("bias", None)

    def to(self, *args, **kwargs):
        """Move/cast the module and explicitly re-bind weight and bias.

        Returns:
            ``self``, to allow call chaining.
        """
        super(W8A16Linear, self).to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        if self.bias is not None:
            self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        """Apply the linear transform to ``x`` with gradients disabled."""
        return nn.functional.linear(x, self.weight, self.bias)

    @staticmethod
    def from_float(
        bit, module, weight_quant="per_channel", quantize_output=False
    ):
        """Build a quantized layer from an existing ``torch.nn.Linear``.

        The quantization helpers work in place, so ``module.weight`` itself is
        overwritten with the quantized values and the returned layer refers to
        that same tensor rather than to a copy.

        Args:
            bit: Bit width handed to the weight quantizer.
            module: The ``torch.nn.Linear`` layer to convert.
            weight_quant: ``"per_channel"`` or ``"per_tensor"``.
            quantize_output: Forwarded to the constructor (currently unused).

        Returns:
            A new ``W8A16Linear`` carrying the quantized weight and the
            source layer's bias (if any).

        Raises:
            ValueError: If ``weight_quant`` is not a supported mode.
        """
        assert isinstance(module, torch.nn.Linear)
        new_module = W8A16Linear(
            module.in_features,
            module.out_features,
            module.bias is not None,
            quantize_output=quantize_output,
        )
        if weight_quant == "per_channel":
            new_module.weight = quantize_weight_per_channel_absmax(module.weight, bit)
        elif weight_quant == "per_tensor":
            new_module.weight = quantize_weight_per_tensor_absmax(module.weight, bit)
        else:
            raise ValueError(f"Invalid weight_quant: {weight_quant}")
        new_module.weight_quant_name = weight_quant
        if module.bias is not None:
            # The bias is taken over as-is; only the weight is quantized.
            new_module.bias = module.bias
        return new_module

    def __repr__(self):
        return f"W8A16Linear({self.in_features}, {self.out_features}, bias={self.bias is not None}, weight_quant={self.weight_quant_name})"
def _resolve_bit_width(name, selected, low_bit):
    """Return the bit width for module ``name``, or ``None`` to leave it as is."""
    if low_bit == 0:
        # 0 disables quantization altogether.
        return None
    if name in selected:
        return low_bit
    # Modules outside the selection get the next-higher precision tier.
    if low_bit == 4:
        return 8
    if low_bit == 8:
        return None
    # Previously this case fell through with `bit` unbound or left over from
    # an earlier module, quantizing at an unintended width.
    raise ValueError(
        f"Unsupported low_bit={low_bit!r} for module {name!r} that is not in "
        "the low-bit selection; expected 0, 4 or 8"
    )


def quantize_llama_like(
    model, mlp_quant, self_attn_quant, low_bit, weight_quant="per_channel", quantize_bmm_input=False
):
    """Replace the linear projections of a LLaMA-style model by quantized ones.

    Walks ``model.model.named_modules()`` and, for every ``LlamaMLP`` and
    ``LlamaAttention`` block, swaps its projection layers for
    ``W8A16Linear`` layers built with ``W8A16Linear.from_float``.

    Bit-width policy per block:
      * ``low_bit == 0``: nothing is quantized.
      * block name listed in ``mlp_quant`` / ``self_attn_quant``: ``low_bit``.
      * block not listed and ``low_bit == 4``: 8 bit.
      * block not listed and ``low_bit == 8``: left untouched.

    Args:
        model: Model exposing the decoder stack as ``model.model``.
        mlp_quant: Container of MLP module names (as produced by
            ``model.model.named_modules()``) to quantize at ``low_bit``.
        self_attn_quant: Same as ``mlp_quant``, for attention modules.
        low_bit: Bit width for the selected modules (0 disables quantization).
        weight_quant: Weight quantization mode passed to ``from_float``.
        quantize_bmm_input: Forwarded as ``quantize_output`` for the q/k/v
            projections.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: If a non-selected block is encountered while ``low_bit``
            is neither 0, 4 nor 8.
    """
    from transformers.models.llama.modeling_llama import (
        LlamaAttention,
        LlamaMLP,
    )

    for name, m in model.model.named_modules():
        if isinstance(m, LlamaMLP):
            bit = _resolve_bit_width(name, mlp_quant, low_bit)
            if bit is None:
                continue
            print(f'{name} {bit} bit quant ')
            for proj in ("gate_proj", "up_proj", "down_proj"):
                setattr(
                    m,
                    proj,
                    W8A16Linear.from_float(
                        bit, getattr(m, proj), weight_quant=weight_quant
                    ),
                )
        elif isinstance(m, LlamaAttention):
            bit = _resolve_bit_width(name, self_attn_quant, low_bit)
            if bit is None:
                continue
            # q/k/v feed the attention matmuls, hence the extra flag.
            for proj in ("q_proj", "k_proj", "v_proj"):
                setattr(
                    m,
                    proj,
                    W8A16Linear.from_float(
                        bit,
                        getattr(m, proj),
                        weight_quant=weight_quant,
                        quantize_output=quantize_bmm_input,
                    ),
                )
            m.o_proj = W8A16Linear.from_float(
                bit, m.o_proj, weight_quant=weight_quant
            )
    return model