chen459664's picture
Add files using upload-large-folder tool
64ddf8d verified
Raw
History Blame Contribute Delete
14.9 kB
import functools
import gc
import inspect
import os
import re
import time
from collections import defaultdict
from typing import List
import pandas as pd
import torch
import torch.nn as nn
from bitsandbytes.functional import dequantize_nf4
from bitsandbytes.nn import Params4bit
from datasets import load_dataset
from hqq.core.quantize import Quantizer as hQuant
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
# nbits is defined as placeholder to be consistent with other quant methods
def quant_nf4(tensor, nbits=4, group_size=64):
    """Round-trip ``tensor`` through bitsandbytes NF4 quantization.

    The tensor is wrapped in a ``Params4bit`` (NF4, block size ``group_size``),
    moved to the tensor's own device, and then dequantized again.

    Args:
        tensor: Weight tensor to quantize and dequantize.
        nbits: Ignored; present only so the signature lines up with the other
            quantization functions in this file.
        group_size: Forwarded to ``Params4bit`` as ``blocksize``.

    Returns:
        Whatever ``dequantize_nf4`` produces for the packed data and its
        quantization state.
    """
    packed = Params4bit(tensor, blocksize=group_size, quant_type="nf4")
    # NOTE(review): the quant state read below is presumably populated by this
    # device move -- confirm that this also holds when the tensor lives on CPU.
    packed = packed.to(tensor.device)
    return dequantize_nf4(packed.data, packed.quant_state)
def quant_hqq(tensor, nbits, group_size=64, optimize=True):
    """Quantize ``tensor`` with the HQQ ``Quantizer`` and dequantize it again.

    Args:
        tensor: Weight tensor to quantize and dequantize.
        nbits: Bit width forwarded to ``Quantizer.quantize``.
        group_size: Group size forwarded to ``Quantizer.quantize``.
        optimize: Forwarded to ``Quantizer.quantize`` unchanged.

    Returns:
        The result of ``Quantizer.dequantize`` on the quantized tensor and its
        metadata.
    """
    settings = {"nbits": nbits, "group_size": group_size, "optimize": optimize}
    packed, meta = hQuant.quantize(tensor, **settings)
    return hQuant.dequantize(packed, meta)
def get_named_linears(module):
    """Collect every ``nn.Linear`` reachable from ``module``.

    Args:
        module: Any ``nn.Module``; it is walked with ``named_modules()``, so
            the module itself is included under the empty name if it is linear.

    Returns:
        A dict mapping each qualified submodule name to its ``nn.Linear``,
        in the order ``named_modules()`` yields them.
    """
    linears = {}
    for qualified_name, child in module.named_modules():
        if isinstance(child, nn.Linear):
            linears[qualified_name] = child
    return linears
def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs):
    """Describe the groups of linear layers to measure inside one decoder layer.

    Each group is a dict with the keys ``part`` (label), ``prev_op`` (the
    operation feeding the group), ``layers`` (the linears in the group) and
    ``inp`` (the cached input feature for the group). Groups that must be
    evaluated through a larger parent module also carry ``module2inspect``,
    and the attention-input group additionally carries ``kwargs``.

    Args:
        module: Decoder layer exposing ``input_layernorm``,
            ``post_attention_layernorm``, ``self_attn`` and ``mlp``.
        input_feat: Mapping from linear-layer name (e.g. ``"self_attn.q_proj"``)
            to its cached input.
        module_kwargs: Extra forward kwargs attached to the attention-input
            group.

    Returns:
        List of group dicts in the order attn_in, attn_out (only when
        applicable), mlp_gate, mlp_down.
    """
    attn = module.self_attn
    mlp = module.mlp

    # Attention input: q/k/v share the same input and are inspected through
    # the whole attention module.
    groups = [
        {
            "part": "attn_in",
            "prev_op": module.input_layernorm,
            "layers": [attn.q_proj, attn.k_proj, attn.v_proj],
            "inp": input_feat["self_attn.q_proj"],
            "module2inspect": attn,
            "kwargs": module_kwargs,
        }
    ]

    # Attention output is only included when v_proj and o_proj have the same
    # weight shape.
    # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
    if attn.v_proj.weight.shape == attn.o_proj.weight.shape:
        groups.append(
            {
                "part": "attn_out",
                "prev_op": attn.v_proj,
                "layers": [attn.o_proj],
                "inp": input_feat["self_attn.o_proj"],
            }
        )

    # First MLP stage: gate and up projections, inspected through the full MLP.
    groups.append(
        {
            "part": "mlp_gate",
            "prev_op": module.post_attention_layernorm,
            "layers": [mlp.gate_proj, mlp.up_proj],
            "inp": input_feat["mlp.gate_proj"],
            "module2inspect": mlp,
        }
    )

    # Second MLP stage: the down projection on its own.
    groups.append(
        {
            "part": "mlp_down",
            "prev_op": mlp.up_proj,
            "layers": [mlp.down_proj],
            "inp": input_feat["mlp.down_proj"],
        }
    )
    return groups
def clear_memory(weight=None):
    """Run a garbage-collection pass and release cached CUDA memory.

    Args:
        weight: Optional object whose local reference is dropped first. Only
            this function's own reference goes away; references held by the
            caller are unaffected.
    """
    # Letting go of the local name is all that can be done from in here.
    weight = None
    gc.collect()
    torch.cuda.empty_cache()
def get_best_device(idx=None):
    """Pick the device string to run a layer on.

    Resolution order:
      1. ``"cpu"`` when the environment variable ``USE_CPU_FOR_SENSITIVITY``
         is exactly ``"1"``.
      2. A CUDA device when CUDA is available: ``"cuda:0"`` if ``idx`` is
         ``None``, otherwise ``idx`` wrapped around the number of CUDA devices.
      3. ``"mps"`` when the MPS backend is available.
      4. ``"cpu"`` otherwise.

    Args:
        idx: Optional integer used to spread work across CUDA devices.

    Returns:
        The chosen device as a string.
    """
    forced_cpu = os.environ.get("USE_CPU_FOR_SENSITIVITY") == "1"
    if forced_cpu:
        return "cpu"

    if torch.cuda.is_available():
        ordinal = 0 if idx is None else idx % torch.cuda.device_count()
        return f"cuda:{ordinal}"

    if torch.backends.mps.is_available():
        return "mps"

    return "cpu"
def get_calib_dataset(
    data: str = "pileval",
    tokenizer=None,
    n_samples=512,
    block_size=512,
    split="train",
    text_column="text",
):
    """Build fixed-width blocks of calibration token ids.

    A dataset is loaded and shuffled with seed 42. Rows are tokenized (or used
    as-is when a row is already a list), rows longer than 512 tokens are
    skipped, and at most ``n_samples`` rows are kept. The kept rows are
    concatenated along the token axis and cut into consecutive blocks of
    ``block_size`` tokens; a trailing partial block is dropped.

    Args:
        data: One of the built-in names ``"pileval"``, ``"wikitext"``,
            ``"bos"``, ``"c4"``, or any other string, which is handed to
            ``load_dataset`` together with ``split``.
        tokenizer: Object with an ``encode`` method; required whenever rows
            are not already lists of token ids.
        n_samples: Maximum number of rows to keep.
        block_size: Width of each returned block.
        split: Split name, used only for datasets outside the built-in names.
        text_column: Column holding the raw text of each row.

    Returns:
        A list of tensors, each of shape ``(1, block_size)``.
    """
    if data == "pileval":
        dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
    elif data == "wikitext":
        dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
    elif data == "bos":
        dataset = load_dataset("schnell18/branch-of-science", split="train")
    elif data == "c4":
        dataset = load_dataset(
            "allenai/c4",
            data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
            split="validation",
            download_mode="reuse_dataset_if_exists",
        )
    else:
        dataset = load_dataset(data, split=split)
    dataset = dataset.shuffle(seed=42)

    samples = []
    for record in dataset:
        if isinstance(record, list):
            token_ids = record
        else:
            token_ids = tokenizer.encode(record[text_column].strip())

        # The length cutoff is a fixed 512 tokens, independent of block_size.
        if len(token_ids) > 512:
            continue

        sample = torch.tensor([token_ids])
        if sample.numel() == 0:
            continue

        samples.append(sample)
        if len(samples) == n_samples:
            break

    # Join everything along the token axis, then slice into whole blocks.
    cat_samples = torch.cat(samples, dim=1)
    n_split = cat_samples.shape[1] // block_size
    return [
        cat_samples[:, start : start + block_size]
        for start in range(0, n_split * block_size, block_size)
    ]
class SensitiveLayerFinder:
    """Measure how strongly each decoder-layer part reacts to weight quantization.

    For every decoder layer in ``model.model.layers`` and every group of
    linears returned by ``get_layers_for_scaling``, the group's parent module
    is run once with its current weights and once after the group's weights
    have been passed through the selected quantize/dequantize function. The
    mean squared difference of the two outputs is reported as the sensitivity.

    Side effect: the quantized weights are written into the model in place and
    are not restored afterwards.
    """

    def __init__(
        self,
        model,
        model_name,
        w_bit,
        group_size,
        tokenizer,
        calib_data="pileval",
        split="train",
        text_column="text",
        quant_method="hqq",
    ) -> None:
        """Store the configuration and run the calibration pass.

        Args:
            model: Causal LM exposing ``model.layers`` and
                ``model.embed_tokens`` under its ``model`` attribute.
            model_name: Label written into every result row.
            w_bit: Bit width passed to the quantization function.
            group_size: Group size passed to the quantization function.
            tokenizer: Tokenizer used to build the calibration set.
            calib_data: Calibration dataset name for ``get_calib_dataset``.
            split: Dataset split for ``get_calib_dataset``.
            text_column: Text column for ``get_calib_dataset``.
            quant_method: ``"rtn"`` (HQQ without optimization), ``"hqq"``
                (HQQ with optimization) or ``"bnb"`` (bitsandbytes NF4).

        Raises:
            ValueError: If ``quant_method`` is not one of the supported names.
        """
        # Fail before the expensive calibration pass rather than much later
        # with a missing ``quant_func`` attribute.
        if quant_method not in ("rtn", "hqq", "bnb"):
            raise ValueError(
                f"Unsupported quant_method {quant_method!r}; "
                "expected one of 'rtn', 'hqq', 'bnb'"
            )

        self.model = model
        self.model_name = model_name
        self.tokenizer = tokenizer
        self.w_bit = w_bit
        self.group_size = group_size
        self.calib_data = calib_data
        self.split = split
        self.text_column = text_column
        self.quant_method = quant_method
        self.layers, self.module_kwargs, self.inps = self.init_quant()

        if self.quant_method == "rtn":
            self.quant_func = functools.partial(quant_hqq, optimize=False)
        elif self.quant_method == "hqq":
            self.quant_func = functools.partial(quant_hqq, optimize=True)
        else:  # "bnb"
            self.quant_func = quant_nf4

    @torch.no_grad()
    def measure(self, csv_fp):
        """Compute the sensitivity of every part of every decoder layer.

        Args:
            csv_fp: Not used by this method; kept for interface compatibility.

        Returns:
            A list of dicts with the keys ``dataset``, ``part``, ``model``,
            ``nbits``, ``group_size``, ``layer`` and ``sensitivity``.
        """
        dikts = []
        cfg = f"b{self.w_bit}g{self.group_size}"
        progress = tqdm(
            range(len(self.layers)), desc=f"{self.model_name}-{cfg}-{self.calib_data}"
        )
        for i in progress:
            # Layers still on CPU are moved to the best device for this index.
            common_device = next(self.layers[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
                self.layers[i] = self.layers[i].to(get_best_device(i))
                common_device = next(self.layers[i].parameters()).device

            # Tensor kwargs and the running hidden states follow the layer.
            for key in ("position_ids", "attention_mask"):
                if self.module_kwargs.get(key) is not None:
                    self.module_kwargs[key] = self.module_kwargs[key].to(
                        common_device
                    )
            self.inps = self.inps.to(common_device)

            # This forward pass also advances self.inps to the layer's output,
            # and it happens before any weight of this layer is quantized.
            named_linears = get_named_linears(self.layers[i])
            input_feat = self._get_input_feat(self.layers[i], named_linears)
            clear_memory()

            module_config = get_layers_for_scaling(
                self.layers[i], input_feat, self.module_kwargs
            )
            for group in module_config:
                part = group.pop("part", "Unknown")
                mse = self._measure_layer_sensitivity(self.layers[i], **group)
                dikts.append(
                    {
                        "dataset": self.calib_data,
                        "part": part,
                        "model": self.model_name,
                        "nbits": self.w_bit,
                        "group_size": self.group_size,
                        "layer": i,
                        "sensitivity": mse,
                    }
                )
            del module_config
            del input_feat
            clear_memory()
        return dikts

    def _measure_layer_sensitivity(
        self,
        module,
        prev_op,
        layers: List[nn.Linear],
        inp: torch.Tensor,
        module2inspect=None,
        kwargs=None,
    ):
        """Return the output MSE caused by quantizing ``layers``.

        Args:
            module: The enclosing decoder layer (not used in the computation).
            prev_op: The operation feeding the group (not used in the
                computation).
            layers: Linears whose weights are replaced by their
                quantize/dequantize round trip. The replacement is permanent.
            inp: Cached input for the inspected module.
            module2inspect: Module whose output is compared before and after
                quantization. Defaults to the single entry of ``layers``.
            kwargs: Extra forward kwargs; entries the inspected module's
                ``forward`` does not accept are dropped. ``None`` means none.

        Returns:
            The mean squared difference between the two outputs as a float.
        """
        if kwargs is None:
            kwargs = {}
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        # Put the input on the same device as the inspected module.
        inp = inp.to(next(module2inspect.parameters()).device)
        module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)

        # Reference output with the weights as they currently are.
        fp16_output = module2inspect(inp, **module_kwargs)
        if isinstance(fp16_output, tuple):
            fp16_output = fp16_output[0]

        # Overwrite the weights with their quantized-then-dequantized version.
        for fc in layers:
            fc.weight.data = self.quant_func(
                fc.weight.data, self.w_bit, self.group_size
            )

        # Output with the quantized weights.
        int_w_output = module2inspect(inp, **module_kwargs)
        if isinstance(int_w_output, tuple):
            int_w_output = int_w_output[0]

        # Mean squared error between the two outputs, accumulated in float32.
        mse = (fp16_output - int_w_output).float().pow(2).mean().item()
        del fp16_output
        del int_w_output
        clear_memory()
        return mse

    def init_quant(self, n_samples=128, seqlen=512):
        """Capture the inputs and forward kwargs that reach decoder layer 0.

        The calibration blocks are run through the model with layer 0 wrapped
        in a module that records its input and kwargs and then aborts the
        forward pass.

        Args:
            n_samples: Number of calibration rows requested from
                ``get_calib_dataset``.
            seqlen: Block size requested from ``get_calib_dataset``.

        Returns:
            A tuple ``(layers, layer_kwargs, inps)``: the decoder layers, the
            kwargs for a layer forward call, and the captured layer-0 input.

        Raises:
            RuntimeError: If the forward pass ended without layer 0 being
                reached, so nothing was captured.
        """
        modules = self.model.model.layers
        samples = get_calib_dataset(
            data=self.calib_data,
            tokenizer=self.tokenizer,
            n_samples=n_samples,
            block_size=seqlen,
            split=self.split,
            text_column=self.text_column,
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        best_device = get_best_device()
        modules[0] = modules[0].to(best_device)
        self.model.model.embed_tokens = self.model.model.embed_tokens.to(best_device)

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, *args, **kwargs):
                # assume first input to forward is hidden states
                if len(args) > 0:
                    hidden_states = args[0]
                else:
                    first_key = next(iter(kwargs))
                    hidden_states = kwargs.pop(first_key)

                inps.append(hidden_states)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # Patch layer 0 to catch input and kwargs; always undo the patch, even
        # when the forward pass fails for an unrelated reason.
        modules[0] = Catcher(modules[0])
        early_exit = None
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError as err:  # work with early exit
            early_exit = err
        finally:
            modules[0] = modules[0].module  # restore

        # A ValueError raised before layer 0 would otherwise be mistaken for
        # the early exit and only show up later as an IndexError.
        if not inps:
            raise RuntimeError(
                "Calibration forward pass never reached decoder layer 0; "
                "no inputs were captured"
            ) from early_exit

        # Update the layer kwargs with `prepare_inputs_for_generation` method
        # that takes care of everything to avoid unexpected errors.
        layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
        # Pop the input_ids as they are not needed at all.
        layer_kwargs.pop("input_ids")

        del samples
        inps = inps[0]

        modules[0] = modules[0].cpu()
        self.model.model.embed_tokens = self.model.model.embed_tokens.to("cpu")

        clear_memory()

        if layer_kwargs.get("attention_mask") is not None:
            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
                best_device
            )

        return modules, layer_kwargs, inps

    def _get_input_feat(self, layer, named_linears):
        """Run ``layer`` once and record the input of each named linear.

        ``self.inps`` is replaced by the first element of the layer's output,
        so it becomes the input for the next layer.

        Args:
            layer: Decoder layer to run.
            named_linears: Mapping from name to linear module inside ``layer``.

        Returns:
            A dict mapping each linear's name to its recorded inputs (moved to
            CPU) concatenated along dimension 0.
        """

        # Forward hook: stash the first positional input of the linear.
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = [
            linear.register_forward_hook(
                functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
            )
            for name, linear in named_linears.items()
        ]

        try:
            self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
            # Sanitize the kwargs in case we use transformers version that
            # contains kwargs that are not handled by the module.
            # Useful for trust_remote_code models.
            module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)
            # The layer output becomes the next layer's input.
            self.inps = layer(self.inps, **module_kwargs)[0]
        finally:
            # Never leave the hooks attached, even if the forward pass failed.
            for h in handles:
                h.remove()

        return {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

    def _sanitize_kwargs(self, inputs_kwargs, module):
        """
        Remove the arguments that are not supported in the module's
        forward pass to avoid breaking behaviour between different versions
        of transformers.

        Args:
            inputs_kwargs (`dict`):
                The input dictionary to pass to the model layer
            module (`torch.nn.Module`):
                Target module to quantize.

        Returns:
            A new dict holding only the entries whose key is a parameter name
            of ``module.forward``.
        """
        module_signature = inspect.signature(module.forward).parameters
        return {k: v for k, v in inputs_kwargs.items() if k in module_signature}
def measure_sensitivity(models, quant_method, cfgs, calib_datasets, csv_fp):
    """Measure sensitivities for every dataset/config/model combination.

    Each model is loaded fresh for every combination, measured with a
    ``SensitiveLayerFinder``, and released again. All rows are collected in
    memory and written to ``csv_fp`` once, after the last combination.

    Args:
        models: Iterable of model paths or hub ids; the part after the last
            ``/`` is used as the model's short name in the results.
        quant_method: Quantization method name handed to
            ``SensitiveLayerFinder``.
        cfgs: Iterable of config strings that start with
            ``b<digit>g<digits>`` (bit width, group size), e.g. ``"b4g64"``.
            Strings that do not match are skipped silently.
        calib_datasets: Iterable of calibration dataset names.
        csv_fp: Destination path of the CSV file.
    """
    pat = re.compile(r"b(\d)g(\d+)")
    bgs = []
    for cfg in cfgs:
        m = pat.match(cfg)
        if m:
            bgs.append((int(m.group(1)), int(m.group(2))))

    dikts = []
    for ds in calib_datasets:
        for w_bit, group_size in bgs:
            for model_path in models:
                short_name = model_path.split("/")[-1]
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    offload_state_dict=False,
                    max_memory={0: "18GiB", "cpu": "60GiB"},
                )
                tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False)
                finder = SensitiveLayerFinder(
                    model,
                    short_name,
                    w_bit,
                    group_size,
                    tokenizer,
                    ds,
                    quant_method=quant_method,
                )
                dikts.extend(finder.measure(csv_fp))

                # Drop every reference to the model before collecting;
                # otherwise nothing can be freed and the next from_pretrained
                # call would load alongside the previous model.
                del finder, model, tokenizer
                clear_memory()
                time.sleep(2)

    df = pd.DataFrame(dikts)
    df.to_csv(csv_fp, index=False)
# with profile(
# activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True
# ) as prof:
# finder.identify(csv_fp)
# print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))