| """ |
| Mixed-Precision Quantization Script for Small Language Models |
| Supports selective quantization of different model components with configurable bitwidths. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig |
| import argparse |
| import os |
| import json |
| from pathlib import Path |
| from typing import Dict, Optional, Tuple |
| import time |
|
|
class MixedPrecisionQuantizer:
    """
    Quantizes model components with different precision levels.
    Supports more aggressive quantization for attention layers while
    preserving higher precision for FFN layers.

    Quantization is simulated ("fake" quantization): weights are rounded to
    a symmetric integer grid of the requested bit width and immediately
    dequantized back to float32. The saved checkpoint therefore remains a
    standard float model; the purpose is to measure the accuracy impact of
    the chosen per-component bit widths.
    """

    def __init__(
        self,
        model_name: str,
        attention_bits: int = 4,
        ffn_bits: int = 8,
        embedding_bits: int = 8,
        output_dir: str = "./quantized_models",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """
        Args:
            model_name: HuggingFace model name or local path.
            attention_bits: Bit width for attention projection layers.
            ffn_bits: Bit width for feed-forward (MLP) layers; also the
                fallback for linear layers that cannot be classified.
            embedding_bits: Bit width for embedding / lm_head layers.
            output_dir: Directory quantized checkpoints are written under.
            device: Target device string. NOTE(review): currently only
                reported; the model is loaded and quantized on CPU — confirm
                whether moving to this device is intended.
        """
        self.model_name = model_name
        self.attention_bits = attention_bits
        self.ffn_bits = ffn_bits
        self.embedding_bits = embedding_bits
        self.output_dir = Path(output_dir)
        self.device = device

        self.output_dir.mkdir(parents=True, exist_ok=True)

        print(f"Initializing quantizer for {model_name}")
        print(f"Attention layers: {attention_bits}-bit")
        print(f"FFN layers: {ffn_bits}-bit")
        print(f"Embeddings: {embedding_bits}-bit")
        print(f"Device: {device}")

    def load_model(self) -> Tuple[nn.Module, "AutoTokenizer"]:
        """Load the pretrained model and tokenizer.

        Returns:
            (model, tokenizer), with the model in float32.
        """
        print(f"\nLoading model: {self.model_name}")
        start_time = time.time()

        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )

        tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True
        )

        load_time = time.time() - start_time
        print(f"Model loaded in {load_time:.2f} seconds")

        # Report parameter count and in-memory footprint before quantization.
        param_count = sum(p.numel() for p in model.parameters())
        param_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 ** 2)
        print(f"Parameters: {param_count:,} ({param_size_mb:.2f} MB)")

        return model, tokenizer

    def quantize_linear_layer(self, layer: nn.Linear, bits: int) -> nn.Linear:
        """
        Quantize a linear layer in place to the given bit width using
        per-output-channel symmetric quantization, then dequantize so the
        layer keeps operating as a normal float32 ``nn.Linear``.

        Args:
            layer: Layer to quantize (modified in place).
            bits: Target integer bit width; 32 is a no-op.

        Returns:
            The same layer object, with ``weight_scale``, ``quantized`` and
            ``bits`` attributes attached for later inspection.
        """
        if bits == 32:
            return layer

        weight = layer.weight.data.clone()

        # Symmetric signed integer range, e.g. [-128, 127] for 8 bits.
        qmin = -(2 ** (bits - 1))
        qmax = 2 ** (bits - 1) - 1

        # Per-row (output channel) absolute max; clamp avoids a zero scale
        # (and division by zero) for all-zero rows.
        max_val = torch.max(torch.abs(weight), dim=1, keepdim=True)[0]
        max_val = torch.clamp(max_val, min=1e-5)
        scale = max_val / qmax

        # Quantize then immediately dequantize (simulated quantization).
        weight_q = torch.clamp(torch.round(weight / scale), qmin, qmax)
        weight_dq = weight_q * scale

        layer.weight.data = weight_dq.contiguous()

        # Bookkeeping for downstream tooling / inspection.
        layer.weight_scale = scale
        layer.quantized = True
        layer.bits = bits

        return layer

    def identify_layer_type(self, name: str, module: nn.Module) -> str:
        """
        Identify if a layer is part of attention, FFN, embedding, or other
        components via substring matches against the dotted module name.

        When markers of different categories both appear in the name
        (e.g. GPT-2's ``mlp.c_proj`` contains the FFN marker ``mlp`` *and*
        the attention marker ``c_proj``), the earliest match in the name
        wins, because leading segments name the enclosing sub-module. The
        previous attention-first check misclassified GPT-2 MLP projections
        as attention layers.

        Returns:
            One of 'attention', 'ffn', 'embedding' or 'other'.
        """
        name_lower = name.lower()

        attention_patterns = [
            'attn', 'attention', 'q_proj', 'k_proj', 'v_proj',
            'qkv', 'query', 'key', 'value', 'o_proj', 'out_proj',
            'c_attn', 'c_proj'
        ]

        ffn_patterns = [
            'mlp', 'ffn', 'fc', 'dense', 'intermediate',
            'gate_proj', 'up_proj', 'down_proj', 'w1', 'w2', 'w3'
        ]

        embedding_patterns = ['embed', 'wte', 'wpe', 'lm_head']

        sentinel = len(name_lower) + 1  # "no match": past any real position

        def earliest(patterns) -> int:
            """Position of the earliest matching pattern, or the sentinel."""
            hits = [name_lower.find(p) for p in patterns]
            hits = [h for h in hits if h != -1]
            return min(hits) if hits else sentinel

        best_pos = sentinel
        layer_type = 'other'
        # Iteration order preserves the original tie-break priority:
        # attention > ffn > embedding.
        for patterns, category in (
            (attention_patterns, 'attention'),
            (ffn_patterns, 'ffn'),
            (embedding_patterns, 'embedding'),
        ):
            pos = earliest(patterns)
            if pos < best_pos:
                best_pos = pos
                layer_type = category

        return layer_type

    def quantize_model(self, model: nn.Module) -> Tuple[nn.Module, Dict]:
        """
        Apply mixed-precision quantization to every ``nn.Linear`` in the
        model, choosing the bit width by layer type.

        Args:
            model: Model to quantize in place.

        Returns:
            (model, stats) where stats counts quantized layers per category.
        """
        print("\nApplying mixed-precision quantization...")
        start_time = time.time()

        stats = {
            'attention_layers': 0,
            'ffn_layers': 0,
            'embedding_layers': 0,
            'other_layers': 0,
            'total_quantized': 0
        }

        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                layer_type = self.identify_layer_type(name, module)

                if layer_type == 'attention':
                    bits = self.attention_bits
                    stats['attention_layers'] += 1
                elif layer_type == 'ffn':
                    bits = self.ffn_bits
                    stats['ffn_layers'] += 1
                elif layer_type == 'embedding':
                    bits = self.embedding_bits
                    stats['embedding_layers'] += 1
                else:
                    # Unclassified linear layers fall back to the (less
                    # aggressive) FFN bit width.
                    bits = self.ffn_bits
                    stats['other_layers'] += 1

                self.quantize_linear_layer(module, bits)
                stats['total_quantized'] += 1

        quant_time = time.time() - start_time
        print(f"\nQuantization completed in {quant_time:.2f} seconds")
        print(f"Quantized layers breakdown:")
        print(f" - Attention: {stats['attention_layers']} layers ({self.attention_bits}-bit)")
        print(f" - FFN: {stats['ffn_layers']} layers ({self.ffn_bits}-bit)")
        print(f" - Embedding: {stats['embedding_layers']} layers ({self.embedding_bits}-bit)")
        print(f" - Other: {stats['other_layers']} layers ({self.ffn_bits}-bit)")
        print(f" - Total quantized: {stats['total_quantized']} layers")

        return model, stats

    def save_quantized_model(
        self,
        model: nn.Module,
        tokenizer: "AutoTokenizer",
        stats: Dict
    ) -> str:
        """Save the quantized model, tokenizer, and metadata.

        Returns:
            The checkpoint directory as a string.
        """
        # Directory name encodes the quantization configuration.
        model_short_name = self.model_name.split('/')[-1]
        quant_config = f"attn{self.attention_bits}_ffn{self.ffn_bits}_emb{self.embedding_bits}"
        save_dir = self.output_dir / f"{model_short_name}_{quant_config}"
        save_dir.mkdir(parents=True, exist_ok=True)

        print(f"\nSaving quantized model to: {save_dir}")

        model.save_pretrained(save_dir)

        tokenizer.save_pretrained(save_dir)

        # NOTE: weights are stored dequantized in float32, so this equals the
        # original checkpoint size; real savings need an integer storage
        # format.
        quantized_size_mb = sum(
            p.numel() * p.element_size() for p in model.parameters()
        ) / (1024 ** 2)

        metadata = {
            'original_model': self.model_name,
            'quantization_config': {
                'attention_bits': self.attention_bits,
                'ffn_bits': self.ffn_bits,
                'embedding_bits': self.embedding_bits
            },
            'layer_stats': stats,
            'model_size_mb': quantized_size_mb,
            'quantization_timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
        }

        with open(save_dir / 'quantization_metadata.json', 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"Quantized model size: {quantized_size_mb:.2f} MB")
        print(f"Metadata saved to: {save_dir / 'quantization_metadata.json'}")

        return str(save_dir)

    def run(self) -> str:
        """Execute the full quantization pipeline: load, quantize, save."""
        print("=" * 80)
        print("MIXED-PRECISION QUANTIZATION PIPELINE")
        print("=" * 80)

        model, tokenizer = self.load_model()

        quantized_model, stats = self.quantize_model(model)

        save_path = self.save_quantized_model(quantized_model, tokenizer, stats)

        print("\n" + "=" * 80)
        print("QUANTIZATION COMPLETE")
        print("=" * 80)
        print(f"Saved to: {save_path}")

        return save_path
|
|
|
|
def main():
    """CLI entry point: parse arguments and run the quantization pipeline."""
    parser = argparse.ArgumentParser(
        description="Mixed-Precision Quantization for Small Language Models"
    )

    # Flag spec table: (flag, add_argument keyword options).
    cli_options = (
        ('--model_name', dict(
            type=str,
            required=True,
            help='HuggingFace model name or path',
        )),
        ('--attention_bits', dict(
            type=int,
            default=4,
            help='Bit width for attention layers (default: 4)',
        )),
        ('--ffn_bits', dict(
            type=int,
            default=8,
            help='Bit width for FFN layers (default: 8)',
        )),
        ('--embedding_bits', dict(
            type=int,
            default=8,
            help='Bit width for embedding layers (default: 8)',
        )),
        ('--output_dir', dict(
            type=str,
            default='./quantized_models',
            help='Output directory for quantized models',
        )),
        ('--device', dict(
            type=str,
            default='cuda' if torch.cuda.is_available() else 'cpu',
            help='Device to use (cuda/cpu)',
        )),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    parsed = parser.parse_args()

    # Argument names match the quantizer's constructor parameters exactly,
    # so the parsed namespace can be splatted straight through.
    quantizer = MixedPrecisionQuantizer(**vars(parsed))

    quantizer.run()
|
|
|
|
| if __name__ == "__main__": |
| main() |