| import functools |
| import gc |
| import inspect |
| import os |
| import re |
| import time |
| from collections import defaultdict |
| from typing import List |
|
|
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| from bitsandbytes.functional import dequantize_nf4 |
| from bitsandbytes.nn import Params4bit |
| from datasets import load_dataset |
| from hqq.core.quantize import Quantizer as hQuant |
| from tqdm import tqdm |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from transformers.models.llama.modeling_llama import LlamaDecoderLayer |
|
|
|
|
| |
def quant_nf4(tensor, nbits=4, group_size=64):
    """Round-trip ``tensor`` through bitsandbytes NF4 quantization.

    The tensor is wrapped in a ``Params4bit`` (``quant_type="nf4"``), moved to
    the tensor's own device and then handed straight back to
    ``dequantize_nf4``.

    Args:
        tensor: Weight tensor to quantize.
        nbits: Ignored by this function. It is accepted so the function can be
            called positionally as ``(tensor, nbits, group_size)``.
        group_size: Forwarded to ``Params4bit`` as ``blocksize``.

    Returns:
        Whatever ``dequantize_nf4`` returns for the packed data and its
        quantization state.
    """
    packed = Params4bit(tensor, blocksize=group_size, quant_type="nf4")
    packed = packed.to(tensor.device)
    return dequantize_nf4(packed.data, packed.quant_state)
|
|
|
|
def quant_hqq(tensor, nbits, group_size=64, optimize=True):
    """Round-trip ``tensor`` through the HQQ quantizer.

    Args:
        tensor: Weight tensor to quantize.
        nbits: Bit width forwarded to ``Quantizer.quantize``.
        group_size: Group size forwarded to ``Quantizer.quantize``.
        optimize: Forwarded to ``Quantizer.quantize`` unchanged.

    Returns:
        The result of ``Quantizer.dequantize`` applied to the quantized
        weights and their metadata.
    """
    quantize_options = {
        "nbits": nbits,
        "group_size": group_size,
        "optimize": optimize,
    }
    packed_weight, packed_meta = hQuant.quantize(tensor, **quantize_options)
    return hQuant.dequantize(packed_weight, packed_meta)
|
|
|
|
def get_named_linears(module):
    """Collect every ``nn.Linear`` found by ``module.named_modules()``.

    Args:
        module: The ``nn.Module`` to search (the module itself is included in
            the search, under the empty name).

    Returns:
        A dict mapping each qualified module name to the matching
        ``nn.Linear`` instance, in ``named_modules()`` order.
    """
    linears = {}
    for qualified_name, submodule in module.named_modules():
        if isinstance(submodule, nn.Linear):
            linears[qualified_name] = submodule
    return linears
|
|
|
|
def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs):
    """Describe the groups of linear layers of one decoder layer.

    Each returned dict carries the keys ``part`` (a label), ``prev_op``,
    ``layers`` (the linear modules of the group) and ``inp`` (the cached input
    for the group). Groups that must be evaluated through a larger module
    also carry ``module2inspect``; only the ``attn_in`` group carries
    ``kwargs``.

    Args:
        module: The decoder layer whose projections are grouped.
        input_feat: Mapping from linear-layer name (e.g. ``"mlp.down_proj"``)
            to its cached input.
        module_kwargs: Keyword arguments attached to the ``attn_in`` group.

    Returns:
        A list of group dicts in the order ``attn_in``, optionally
        ``attn_out``, ``mlp_gate``, ``mlp_down``.
    """
    attn = module.self_attn

    groups = [
        {
            "part": "attn_in",
            "prev_op": module.input_layernorm,
            "layers": [attn.q_proj, attn.k_proj, attn.v_proj],
            "inp": input_feat["self_attn.q_proj"],
            "module2inspect": attn,
            "kwargs": module_kwargs,
        }
    ]

    # The output projection is only included when v_proj and o_proj weights
    # have identical shapes.
    # NOTE(review): presumably this skips attention variants where v_proj is
    # narrower than o_proj - confirm.
    if attn.v_proj.weight.shape == attn.o_proj.weight.shape:
        groups.append(
            {
                "part": "attn_out",
                "prev_op": attn.v_proj,
                "layers": [attn.o_proj],
                "inp": input_feat["self_attn.o_proj"],
            }
        )

    mlp = module.mlp
    groups.append(
        {
            "part": "mlp_gate",
            "prev_op": module.post_attention_layernorm,
            "layers": [mlp.gate_proj, mlp.up_proj],
            "inp": input_feat["mlp.gate_proj"],
            "module2inspect": mlp,
        }
    )
    groups.append(
        {
            "part": "mlp_down",
            "prev_op": mlp.up_proj,
            "layers": [mlp.down_proj],
            "inp": input_feat["mlp.down_proj"],
        }
    )

    return groups
|
|
|
|
def clear_memory(weight=None):
    """Run the garbage collector and empty the CUDA caching allocator.

    Args:
        weight: Optional object whose local reference is dropped before
            collecting. Only this function's own reference is released;
            references held by the caller keep the object alive.
    """
    # Release the local reference (a no-op when nothing was passed in).
    weight = None
    gc.collect()
    torch.cuda.empty_cache()
|
|
|
|
def get_best_device(idx=None):
    """Pick the device string to run a layer on.

    Order of preference:

    1. ``"cpu"`` when the environment variable ``USE_CPU_FOR_SENSITIVITY``
       is exactly ``"1"``.
    2. A CUDA device when CUDA is available: ``"cuda:0"`` if ``idx`` is
       ``None``, otherwise ``idx`` wrapped around the number of CUDA devices.
    3. ``"mps"`` when the MPS backend is available.
    4. ``"cpu"`` otherwise.

    Args:
        idx: Optional index used to spread work across CUDA devices.

    Returns:
        A device string such as ``"cpu"``, ``"mps"`` or ``"cuda:1"``.
    """
    cpu_forced = os.environ.get("USE_CPU_FOR_SENSITIVITY", None) == "1"
    if cpu_forced:
        return "cpu"

    if torch.cuda.is_available():
        if idx is None:
            return "cuda:0"
        return f"cuda:{idx % torch.cuda.device_count()}"

    if torch.backends.mps.is_available():
        return "mps"

    return "cpu"
|
|
|
|
def get_calib_dataset(
    data: str = "pileval",
    tokenizer=None,
    n_samples=512,
    block_size=512,
    split="train",
    text_column="text",
):
    """Build fixed-size blocks of token ids for calibration.

    The dataset is loaded, shuffled with a fixed seed (42), and up to
    ``n_samples`` records are tokenized. Records longer than ``block_size``
    tokens and records that encode to nothing are skipped. The kept samples
    are concatenated along the sequence dimension and cut into consecutive
    blocks of ``block_size`` tokens; a trailing remainder is discarded.

    Args:
        data: One of the built-in aliases ``"pileval"``, ``"wikitext"``,
            ``"bos"`` or ``"c4"``; any other value is passed to
            ``load_dataset`` together with ``split``.
        tokenizer: Object with an ``encode(str)`` method; required whenever a
            record is not already a list of token ids.
        n_samples: Maximum number of records to keep.
        block_size: Length of every returned block, and the maximum length of
            a record that is kept.
        split: Dataset split, used only for non-alias datasets.
        text_column: Name of the record field that holds the text.

    Returns:
        A list of tensors of shape ``(1, block_size)``.

    Raises:
        ValueError: If no record survives the filtering.
    """
    if data == "pileval":
        dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
    elif data == "wikitext":
        dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
    elif data == "bos":
        dataset = load_dataset("schnell18/branch-of-science", split="train")
    elif data == "c4":
        dataset = load_dataset(
            "allenai/c4",
            data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
            split="validation",
            download_mode="reuse_dataset_if_exists",
        )
    else:
        dataset = load_dataset(data, split=split)

    dataset = dataset.shuffle(seed=42)

    samples = []
    # ``record`` (not ``data``) so the loop does not clobber the parameter.
    for record in dataset:
        if isinstance(record, list):
            # Already a list of token ids.
            line_encoded = record
        else:
            line_encoded = tokenizer.encode(record[text_column].strip())

        # Filter on the requested block size rather than a hard-coded 512 so
        # the limit follows the caller's ``block_size``.
        if len(line_encoded) > block_size:
            continue

        sample = torch.tensor([line_encoded])
        if sample.numel() == 0:
            continue

        samples.append(sample)
        if len(samples) == n_samples:
            break

    if not samples:
        raise ValueError(
            f"No usable calibration samples found in dataset {data!r} "
            f"(every record was empty or longer than {block_size} tokens)"
        )

    # Concatenate along the sequence dimension, then cut into whole blocks.
    cat_samples = torch.cat(samples, dim=1)
    n_split = cat_samples.shape[1] // block_size
    return [
        cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split)
    ]
|
|
|
|
class SensitiveLayerFinder:
    """Measure how strongly quantizing groups of linear layers changes outputs.

    For every decoder layer of ``model.model.layers`` and every group of
    linear layers returned by ``get_layers_for_scaling``, the mean squared
    error between the inspected module's output before and after the group's
    weights are replaced by their quantize/dequantize round trip is recorded.

    Note that the weights are overwritten in place and are not restored, so
    the model is left in its quantized-then-dequantized state after
    ``measure``.
    """

    # Names accepted for ``quant_method``.
    _QUANT_METHODS = ("rtn", "hqq", "bnb")

    def __init__(
        self,
        model,
        model_name,
        w_bit,
        group_size,
        tokenizer,
        calib_data="pileval",
        split="train",
        text_column="text",
        quant_method="hqq",
    ) -> None:
        """Store the configuration and capture the calibration inputs.

        Args:
            model: Causal LM exposing ``model.layers`` and
                ``model.embed_tokens``.
            model_name: Label written into every result record.
            w_bit: Bit width passed to the quantization function.
            group_size: Group size passed to the quantization function.
            tokenizer: Tokenizer used to encode the calibration text.
            calib_data: Calibration dataset name, see ``get_calib_dataset``.
            split: Dataset split for non-alias datasets.
            text_column: Record field holding the text.
            quant_method: ``"rtn"`` (HQQ without optimization), ``"hqq"``
                (HQQ with optimization) or ``"bnb"`` (bitsandbytes NF4).

        Raises:
            ValueError: If ``quant_method`` is not one of the supported names.
        """
        # Fail before the expensive calibration pass; otherwise ``quant_func``
        # would stay unset and the error would only surface inside
        # ``measure`` as an AttributeError.
        if quant_method not in self._QUANT_METHODS:
            raise ValueError(
                f"Unsupported quant_method {quant_method!r}; "
                f"expected one of {self._QUANT_METHODS}"
            )

        self.model = model
        self.model_name = model_name
        self.tokenizer = tokenizer
        self.w_bit = w_bit
        self.group_size = group_size
        self.calib_data = calib_data
        self.split = split
        self.text_column = text_column
        self.quant_method = quant_method
        self.layers, self.module_kwargs, self.inps = self.init_quant()

        if quant_method == "rtn":
            self.quant_func = functools.partial(quant_hqq, optimize=False)
        elif quant_method == "hqq":
            self.quant_func = functools.partial(quant_hqq, optimize=True)
        else:  # "bnb"
            self.quant_func = quant_nf4

    @torch.no_grad()
    def measure(self, csv_fp):
        """Measure the sensitivity of every layer group.

        Args:
            csv_fp: Not used by this method; kept for interface
                compatibility.

        Returns:
            A list of dicts with the keys ``dataset``, ``part``, ``model``,
            ``nbits``, ``group_size``, ``layer`` and ``sensitivity``.
        """
        records = []
        cfg = f"b{self.w_bit}g{self.group_size}"
        for i in tqdm(
            range(len(self.layers)), desc=f"{self.model_name}-{cfg}-{self.calib_data}"
        ):
            # Layers still on the CPU are moved to the preferred device.
            common_device = next(self.layers[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
                best_device = get_best_device(i)
                self.layers[i] = self.layers[i].to(best_device)
                common_device = next(self.layers[i].parameters()).device

            # Keep the tensor-valued forward kwargs on the layer's device.
            for key in ("position_ids", "attention_mask"):
                if self.module_kwargs.get(key) is not None:
                    self.module_kwargs[key] = self.module_kwargs[key].to(
                        common_device
                    )

            self.inps = self.inps.to(common_device)
            named_linears = get_named_linears(self.layers[i])
            # Runs the layer once with its current weights; this also
            # advances ``self.inps`` to the layer's output.
            input_feat = self._get_input_feat(self.layers[i], named_linears)
            clear_memory()

            module_config = get_layers_for_scaling(
                self.layers[i], input_feat, self.module_kwargs
            )

            for layer in module_config:
                part = layer.pop("part", "Unknown")
                mse = self._measure_layer_sensitivity(self.layers[i], **layer)
                records.append(
                    {
                        "dataset": self.calib_data,
                        "part": part,
                        "model": self.model_name,
                        "nbits": self.w_bit,
                        "group_size": self.group_size,
                        "layer": i,
                        "sensitivity": mse,
                    }
                )

            del module_config
            del input_feat
            clear_memory()
        return records

    def _measure_layer_sensitivity(
        self,
        module,
        prev_op,
        layers: List[nn.Linear],
        inp: torch.Tensor,
        module2inspect=None,
        kwargs=None,
    ):
        """Return the output MSE caused by quantizing ``layers``.

        Args:
            module: The enclosing decoder layer (not used here).
            prev_op: The operation feeding ``layers`` (not used here).
            layers: Linear layers whose weights are quantized in place.
            inp: Cached input for the inspected module.
            module2inspect: Module whose output is compared; defaults to the
                single entry of ``layers``.
            kwargs: Extra forward kwargs; entries the inspected module's
                ``forward`` does not accept are dropped.

        Returns:
            The mean squared difference (a Python float) between the outputs
            before and after quantization.
        """
        if kwargs is None:
            kwargs = {}

        if module2inspect is None:
            assert len(layers) == 1, "module2inspect is required for multiple layers"
            module2inspect = layers[0]

        inp = inp.to(next(module2inspect.parameters()).device)
        module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)

        # Reference output with the current weights.
        fp16_output = module2inspect(inp, **module_kwargs)
        if isinstance(fp16_output, tuple):
            fp16_output = fp16_output[0]

        # Overwrite each weight with its quantize -> dequantize round trip.
        for fc in layers:
            fc.weight.data = self.quant_func(
                fc.weight.data, self.w_bit, self.group_size
            )

        int_w_output = module2inspect(inp, **module_kwargs)
        if isinstance(int_w_output, tuple):
            int_w_output = int_w_output[0]

        mse = (fp16_output - int_w_output).float().pow(2).mean().item()
        del fp16_output
        del int_w_output
        clear_memory()
        return mse

    def init_quant(self, n_samples=128, seqlen=512):
        """Capture the inputs of the first decoder layer on calibration data.

        The first decoder layer is temporarily wrapped so that its first
        call records the incoming hidden states and keyword arguments and
        then aborts the forward pass.

        Args:
            n_samples: Number of calibration records to request.
            seqlen: Block size of the calibration samples.

        Returns:
            A tuple ``(modules, layer_kwargs, inps)``: the decoder layers, the
            keyword arguments for a layer forward pass, and the captured
            hidden states.
        """
        modules = self.model.model.layers
        samples = get_calib_dataset(
            data=self.calib_data,
            tokenizer=self.tokenizer,
            n_samples=n_samples,
            block_size=seqlen,
            split=self.split,
            text_column=self.text_column,
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        best_device = get_best_device()
        modules[0] = modules[0].to(best_device)
        self.model.model.embed_tokens = self.model.model.embed_tokens.to(best_device)

        class _CaptureComplete(Exception):
            """Private signal used to abort the forward pass after capture."""

        class Catcher(nn.Module):
            """Stand-in for the first layer that records its call arguments."""

            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, *args, **kwargs):
                # The first positional argument (or, failing that, the first
                # keyword argument) is taken to be the hidden states.
                if len(args) > 0:
                    hidden_states = args[0]
                    del args
                else:
                    first_key = list(kwargs.keys())[0]
                    hidden_states = kwargs.pop(first_key)

                inps.append(hidden_states)
                layer_kwargs.update(kwargs)
                # A dedicated exception type so that genuine errors raised
                # by the model are not mistaken for the early exit.
                raise _CaptureComplete

        modules[0] = Catcher(modules[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except _CaptureComplete:
            pass
        finally:
            # Always unwrap, even if the forward pass failed for real.
            modules[0] = modules[0].module

        # Let the model derive its remaining forward kwargs; the token ids
        # themselves are not needed by the individual layers.
        layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
        layer_kwargs.pop("input_ids")

        del samples
        inps = inps[0]

        modules[0] = modules[0].cpu()
        self.model.model.embed_tokens = self.model.model.embed_tokens.to("cpu")

        clear_memory()

        if layer_kwargs.get("attention_mask") is not None:
            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
                best_device
            )

        return modules, layer_kwargs, inps

    def _get_input_feat(self, layer, named_linears):
        """Run ``layer`` once and cache the input of every linear in it.

        As a side effect ``self.inps`` is replaced by the layer's output so
        that it can feed the next layer.

        Args:
            layer: The module to run.
            named_linears: Mapping from name to linear module to hook.

        Returns:
            A dict mapping each hooked name to its inputs (moved to the CPU
            and concatenated along dimension 0).
        """

        def cache_input_hook(m, x, y, name, feat_dict):
            # ``x`` is the tuple of positional inputs; keep the first one.
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = [
            linear.register_forward_hook(
                functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
            )
            for name, linear in named_linears.items()
        ]

        try:
            self.inps = self.inps.to(next(layer.parameters()).device)
            module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)
            output = layer(self.inps, **module_kwargs)
        finally:
            # Never leave hooks behind, even if the forward pass fails.
            for handle in handles:
                handle.remove()

        # Layers may return either a tuple whose first entry is the hidden
        # states or the hidden states directly; indexing a bare tensor with
        # ``[0]`` would silently drop the batch dimension.
        self.inps = output[0] if isinstance(output, tuple) else output

        return {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

    def _sanitize_kwargs(self, inputs_kwargs, module):
        """
        Remove the arguments that are not supported in the module's
        forward pass to avoid breaking behaviour between different versions
        of transformers.

        Args:
            inputs_kwargs (`dict`):
                The input dictionary to pass to the model layer
            module (`torch.nn.Module`):
                Target module to quantize.

        Returns:
            A new dict holding only the entries whose key is a parameter name
            of ``module.forward``.
        """
        module_signature = inspect.signature(module.forward).parameters
        return {k: v for k, v in inputs_kwargs.items() if k in module_signature}
|
|
|
|
def measure_sensitivity(models, quant_method, cfgs, calib_datasets, csv_fp):
    """Measure layer sensitivity for every dataset/config/model combination.

    Each config string is matched against ``b(\\d)g(\\d+)`` (a single-digit
    bit width followed by a group size, anchored at the start of the
    string). For every combination the model and tokenizer are loaded, a
    ``SensitiveLayerFinder`` is run, and all records are finally written to
    ``csv_fp`` as a CSV file without an index column.

    Args:
        models: Iterable of model paths or hub ids; the part after the last
            ``/`` is used as the model's short name.
        quant_method: Quantization method forwarded to
            ``SensitiveLayerFinder``.
        cfgs: Iterable of config strings such as ``"b4g64"``.
        calib_datasets: Iterable of calibration dataset names.
        csv_fp: Destination path of the CSV file.
    """
    import warnings

    pat = re.compile(r"b(\d)g(\d+)")
    bgs = []
    for cfg in cfgs:
        m = pat.match(cfg)
        if m:
            bgs.append((int(m.group(1)), int(m.group(2))))
        else:
            # Make a mistyped config visible instead of silently skipping it.
            warnings.warn(
                f"Ignoring quantization config {cfg!r}: "
                "expected the form 'b<bits>g<group_size>'"
            )

    dikts = []
    for ds in calib_datasets:
        for w_bit, group_size in bgs:
            for model_path in models:
                short_name = model_path.split("/")[-1]
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    offload_state_dict=False,
                    max_memory={0: "18GiB", "cpu": "60GiB"},
                )
                tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False)
                finder = SensitiveLayerFinder(
                    model,
                    short_name,
                    w_bit,
                    group_size,
                    tokenizer,
                    ds,
                    quant_method=quant_method,
                )
                dikts.extend(finder.measure(csv_fp))

                # Drop every reference to the model before collecting;
                # otherwise nothing can be freed here and the old model stays
                # alive while the next one is being loaded.
                del finder, model, tokenizer
                clear_memory()
                time.sleep(2)

    df = pd.DataFrame(dikts)
    df.to_csv(csv_fp, index=False)
|
|
|
|
| |
| |
| |
| |
| |
|
|