← 返回首页
🧠

量化完全指南

📂 llm ⏱ 4 min 638 words

---
title: "量化完全指南"
description: "全面介绍大模型量化技术,包括PTQ、QAT、GPTQ、AWQ等主流方法,实现模型压缩与推理加速。"
tags: ["量化", "PTQ", "QAT", "INT8", "INT4", "GPTQ", "AWQ"]
category: "llm"
icon: "🧠"
---

量化完全指南

什么是模型量化

量化是将模型权重和激活值从高精度(FP32/FP16)转换为低精度(INT8/INT4)表示的技术。通过量化可以显著减少模型大小、降低显存占用、提升推理速度,同时保持可接受的精度。

量化基础

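均匀量化公式

$$s = \frac{x_{\max} - x_{\min}}{q_{\max} - q_{\min}},\qquad z = q_{\min} - \frac{x_{\min}}{s},\qquad q = \operatorname{clamp}\!\left(\operatorname{round}\!\left(\frac{x}{s} + z\right),\ q_{\min},\ q_{\max}\right),\qquad \hat{x} = s\,(q - z)$$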

import torch
import numpy as np

class UniformQuantizer:
    """Per-tensor uniform (affine) quantizer onto a signed integer grid.

    The integer range is ``[-(2 ** (bits - 1)), 2 ** (bits - 1) - 1]``. Scale
    and zero point are recomputed from the min/max of every tensor passed to
    :meth:`quantize`.
    """

    def __init__(self, bits: int = 8):
        """Store the bit width and derive the signed integer range.

        Args:
            bits: Number of bits of the target integer grid (>= 1).

        Raises:
            ValueError: If ``bits`` is smaller than 1.
        """
        if bits < 1:
            raise ValueError(f"bits must be >= 1, got {bits}")
        self.bits = bits
        self.qmin = -(2 ** (bits - 1))
        self.qmax = 2 ** (bits - 1) - 1

    def _storage_dtype(self) -> torch.dtype:
        """Return the smallest signed integer dtype that can hold ``bits`` bits."""
        if self.bits <= 8:
            return torch.int8
        if self.bits <= 16:
            return torch.int16
        if self.bits <= 32:
            return torch.int32
        return torch.int64

    def compute_scale_zero_point(self, tensor: torch.Tensor):
        """Compute the quantization parameters for ``tensor``.

        Returns:
            ``(scale, zero_point)`` as 0-dim tensors. ``zero_point`` is kept as
            a float (it is not rounded to an integer).
        """
        min_val = tensor.min()
        max_val = tensor.max()

        scale = (max_val - min_val) / (self.qmax - self.qmin)
        # A constant tensor has max == min, which would give scale == 0 and
        # NaNs downstream; any positive scale reproduces it exactly.
        scale = torch.where(scale > 0, scale, torch.ones_like(scale))
        zero_point = self.qmin - min_val / scale

        return scale, zero_point

    def quantize(self, tensor: torch.Tensor) -> tuple:
        """Quantize ``tensor``.

        Returns:
            ``(quantized, scale, zero_point)`` where ``quantized`` is an integer
            tensor (``int8`` for up to 8 bits, wider dtypes for larger bit
            widths so values are not wrapped).
        """
        scale, zero_point = self.compute_scale_zero_point(tensor)
        quantized = torch.round(tensor / scale + zero_point)
        quantized = torch.clamp(quantized, self.qmin, self.qmax)
        return quantized.to(self._storage_dtype()), scale, zero_point

    def dequantize(self, quantized: torch.Tensor, scale, zero_point) -> torch.Tensor:
        """Map integer codes back to floats using ``(q - zero_point) * scale``."""
        return (quantized.float() - zero_point) * scale

# Usage example: quantize a random 1024x1024 weight matrix with an 8-bit
# quantizer and report the storage size of the integer tensor.
weights = torch.randn(1024, 1024)
quantizer = UniformQuantizer(8)
q_weights, scale, zp = quantizer.quantize(weights)
print(f"量化后大小: {q_weights.element_size() * q_weights.nelement() / 1024 / 1024:.2f}MB")

PTQ(训练后量化)

简单PTQ

import torch
from torch.quantization import quantize_dynamic

def ptq_int8(model):
    """Return a dynamically quantized (qint8) version of ``model``.

    Only ``torch.nn.Linear`` modules are selected for quantization.
    """
    target_layers = {torch.nn.Linear}
    return quantize_dynamic(model, qconfig_spec=target_layers, dtype=torch.qint8)

def ptq_int4(model, calibration_data=None):
    """Static post-training quantization with a calibration pass.

    NOTE: despite its name this function does not produce INT4 weights. It
    uses PyTorch's default ``'x86'`` eager-mode qconfig, which is an 8-bit
    configuration; the name is kept only for backward compatibility.

    Side effects: ``model`` is switched to eval mode and gets a ``qconfig``
    attribute assigned.

    Args:
        model: Float model to quantize.
        calibration_data: Optional iterable of input tensors that are fed
            through the prepared model to collect activation statistics.
            When omitted, 100 random tensors of shape ``(1, 128)`` are used,
            which is only meaningful for a model that accepts such input.

    Returns:
        The converted (quantized) model.
    """
    from torch.ao.quantization import get_default_qconfig

    model.eval()
    model.qconfig = get_default_qconfig('x86')

    # Insert observers that record activation ranges.
    prepared_model = torch.ao.quantization.prepare(model)

    # Calibrate: run sample inputs through the observed model.
    if calibration_data is None:
        calibration_data = [torch.randn(1, 128) for _ in range(100)]
    with torch.no_grad():
        for data in calibration_data:
            prepared_model(data)

    # Replace observed modules by their quantized counterparts.
    quantized_model = torch.ao.quantization.convert(prepared_model)
    return quantized_model

SmoothQuant

class SmoothQuant:
    """SmoothQuant-style smoothing: shift quantization difficulty from
    activations to weights with a per-input-channel scaling factor.

    Weights are expected in ``torch.nn.Linear`` layout, i.e.
    ``(out_features, in_features)``; activations have ``in_features`` as their
    last dimension.
    """

    def __init__(self, alpha: float = 0.5):
        """Store the migration strength ``alpha`` used as the exponent split."""
        self.alpha = alpha

    def smooth_activation(self, activation: torch.Tensor,
                          weight: torch.Tensor) -> tuple:
        """Rescale activations and weights so that their product is unchanged.

        For every input channel ``j`` the factor
        ``s_j = act_scale_j ** alpha / weight_scale_j ** (1 - alpha)`` is
        computed; activations are divided by it and the matching weight
        column is multiplied by it, so
        ``smoothed_activation @ smoothed_weight.T == activation @ weight.T``
        up to rounding.

        Args:
            activation: Tensor of shape ``(..., in_features)``.
            weight: Tensor of shape ``(out_features, in_features)``.

        Returns:
            ``(smoothed_activation, smoothed_weight)`` with the same shapes as
            the inputs.

        Raises:
            ValueError: If the channel dimensions do not match.
        """
        in_features = activation.shape[-1]
        if weight.dim() != 2 or weight.shape[1] != in_features:
            raise ValueError(
                "weight must have shape (out_features, in_features) with "
                f"in_features == {in_features}, got {tuple(weight.shape)}"
            )

        eps = 1e-5  # keeps the factor finite for all-zero channels

        # Per-input-channel statistics. Leading activation dims are flattened
        # so batched (batch, seq, in_features) input reduces over all tokens.
        # NOTE(review): the mean is kept from the original code; the
        # SmoothQuant paper uses the per-channel max — confirm intent.
        act_scale = activation.reshape(-1, in_features).abs().mean(dim=0)
        weight_scale = weight.abs().max(dim=0)[0].clamp(min=eps)

        smooth_factor = (
            torch.pow(act_scale, self.alpha)
            / torch.pow(weight_scale, 1 - self.alpha)
        ).clamp(min=eps)

        # The factor indexes input channels, which is the LAST dim of both
        # tensors, so plain broadcasting applies it to the right axis.
        smoothed_activation = activation / smooth_factor
        smoothed_weight = weight * smooth_factor

        return smoothed_activation, smoothed_weight

QAT(量化感知训练)

class QuantizationAwareTraining:
    """Quantization-aware training helper for the Linear layers of a model.

    :meth:`prepare` attaches a ``FakeQuantize`` module to the weight of every
    ``torch.nn.Linear`` (as a weight parametrization), so every forward pass
    sees weights rounded to the ``bits``-bit grid while gradients still reach
    the underlying float weights. :meth:`convert` bakes the fake-quantized
    values into the weights and removes the parametrizations.
    """

    def __init__(self, model, bits=8):
        """Keep a reference to ``model`` and the target bit width."""
        self.model = model
        self.bits = bits
        # Maps Linear module name -> attached FakeQuantize instance.
        self.fake_quantizers = {}

    def prepare(self):
        """Attach weight fake-quantizers to every Linear layer of the model.

        Calling it again is a no-op for layers that are already prepared.
        """
        from torch.ao.quantization import FakeQuantize, MinMaxObserver
        from torch.nn.utils import parametrize

        quant_min = -(2 ** (self.bits - 1))
        quant_max = 2 ** (self.bits - 1) - 1
        # A signed range needs a signed quantized dtype; the observer default
        # (quint8) rejects negative quant_min.
        dtype = torch.qint8 if self.bits <= 8 else torch.qint32

        # Snapshot the module list first: registering a parametrization adds
        # sub-modules, which must not happen while named_modules() iterates.
        linear_layers = [
            (name, module)
            for name, module in self.model.named_modules()
            if isinstance(module, torch.nn.Linear)
        ]
        for name, module in linear_layers:
            if parametrize.is_parametrized(module, "weight"):
                continue
            fake_quant = FakeQuantize(
                observer=MinMaxObserver,
                quant_min=quant_min,
                quant_max=quant_max,
                dtype=dtype,
                qscheme=torch.per_tensor_symmetric,
            ).to(module.weight.device)
            parametrize.register_parametrization(module, "weight", fake_quant)
            self.fake_quantizers[name] = fake_quant

    def train_step(self, input_ids, labels):
        """Run one forward pass and return the cross-entropy loss.

        The model output must expose a ``logits`` attribute. Fake quantization
        is applied only if :meth:`prepare` was called beforehand. The caller
        is responsible for ``backward()`` and the optimizer step.
        """
        output = self.model(input_ids)

        # Flatten logits to (N, num_classes) and labels to (N,).
        loss = torch.nn.functional.cross_entropy(
            output.logits.view(-1, output.logits.size(-1)),
            labels.view(-1)
        )

        return loss

    def convert(self):
        """Bake fake-quantized weights into the model and return it.

        The returned model still stores float tensors; their values lie on
        the ``bits``-bit grid learned during training.
        """
        from torch.nn.utils import parametrize

        print("转换为量化模型")
        for name, module in list(self.model.named_modules()):
            if name in self.fake_quantizers and parametrize.is_parametrized(module, "weight"):
                parametrize.remove_parametrizations(
                    module, "weight", leave_parametrized=True
                )
        return self.model

GPTQ量化

class GPTQQuantizer:
    """GPTQ-style quantizer: quantize a layer column by column and push each
    column's rounding error onto the not-yet-quantized columns.

    Weights are expected in ``torch.nn.Linear`` layout
    ``(out_features, in_features)``. ``group_size`` is stored but not used by
    this simplified implementation (one symmetric scale per input column).
    """

    def __init__(self, bits: int = 4, group_size: int = 128):
        """Store the configuration.

        Raises:
            ValueError: If ``bits`` < 2 (the symmetric grid would have no
                positive level and the scale would be undefined).
        """
        if bits < 2:
            raise ValueError(f"bits must be >= 2, got {bits}")
        self.bits = bits
        self.group_size = group_size

    def quantize_layer(self, weight: torch.Tensor,
                       calibration_data: torch.Tensor) -> tuple:
        """Quantize one weight matrix using the layer's input activations.

        Args:
            weight: ``(out_features, in_features)`` tensor. It is NOT modified.
            calibration_data: Layer inputs of shape ``(..., in_features)``.

        Returns:
            ``(quantized_weight, scales, zero_points)``: the dequantized
            (float) quantized weights with the shape/dtype of ``weight``, the
            per-column scales, and the per-column zero points (always 0).

        Raises:
            ValueError: If the calibration feature size does not match.
        """
        qmax = 2 ** (self.bits - 1) - 1
        qmin = -(2 ** (self.bits - 1))

        # Work on a float32 copy so the caller's tensor is left untouched and
        # the factorizations below are numerically usable.
        work = weight.detach().to(torch.float32).clone()
        n_cols = work.shape[1]

        inputs = calibration_data.detach().reshape(-1, calibration_data.shape[-1])
        inputs = inputs.to(device=work.device, dtype=work.dtype)
        if inputs.shape[1] != n_cols:
            raise ValueError(
                f"calibration data has {inputs.shape[1]} features, "
                f"weight expects {n_cols}"
            )

        # Hessian of the layer-wise squared error, damped relative to its
        # diagonal so it stays positive definite.
        hessian = inputs.T @ inputs
        damp = max(0.01 * hessian.diagonal().mean().item(), 1e-4)
        hessian = hessian + damp * torch.eye(
            n_cols, device=work.device, dtype=work.dtype
        )

        # Upper Cholesky factor of H^-1: its rows hold the update directions
        # with already-processed columns eliminated.
        hessian_inv = torch.cholesky_inverse(torch.linalg.cholesky(hessian))
        hinv_upper = torch.linalg.cholesky(hessian_inv, upper=True)

        quantized_weight = torch.zeros_like(work)
        scales = []
        zero_points = []

        for col in range(n_cols):
            w_col = work[:, col]

            # Symmetric scale for this column; clamp avoids 0/0 on an
            # all-zero column.
            scale = (w_col.abs().max() / qmax).clamp(min=1e-8)
            zero_point = 0
            scales.append(scale)
            zero_points.append(zero_point)

            q_col = torch.clamp(torch.round(w_col / scale), qmin, qmax)
            dequantized = q_col * scale
            quantized_weight[:, col] = dequantized

            # Compensate: spread the normalized error over remaining columns.
            error = (w_col - dequantized) / hinv_upper[col, col]
            work[:, col + 1:] -= error.unsqueeze(1) * hinv_upper[col, col + 1:].unsqueeze(0)

        return quantized_weight.to(weight.dtype), scales, zero_points

    def quantize_model(self, model, calibration_loader):
        """Quantize every Linear layer of ``model`` in place and return it.

        The first batch of ``calibration_loader`` is run through the model
        once (in eval mode, without gradients) and the inputs seen by each
        Linear layer are used as that layer's calibration data. Inputs come
        from the unquantized model. A Linear layer that is not reached by the
        forward pass falls back to the raw batch.

        Raises:
            ValueError: If ``calibration_loader`` yields no batch.
        """
        print("开始GPTQ量化...")

        first_batch = next(iter(calibration_loader), None)
        if first_batch is None:
            raise ValueError("calibration_loader yielded no data")

        linear_layers = [
            (name, module)
            for name, module in model.named_modules()
            if isinstance(module, torch.nn.Linear)
        ]

        captured = {}

        def make_hook(key):
            def hook(_module, args):
                x = args[0].detach()
                captured.setdefault(key, []).append(x.reshape(-1, x.shape[-1]))
            return hook

        handles = [
            module.register_forward_pre_hook(make_hook(name))
            for name, module in linear_layers
        ]
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                model(first_batch)
        finally:
            for handle in handles:
                handle.remove()
            model.train(was_training)

        for name, module in linear_layers:
            print(f"量化层: {name}")
            if name in captured:
                calibration_data = torch.cat(captured[name], dim=0)
            else:
                calibration_data = first_batch

            module.weight.data, scales, zps = self.quantize_layer(
                module.weight.data, calibration_data
            )

        return model

AWQ量化

class AWQQuantizer:
    """AWQ-style (activation-aware) weight quantizer.

    Input channels are rescaled according to activation magnitude before the
    weights are rounded, and the scaling is undone afterwards. Weights are
    expected in ``torch.nn.Linear`` layout ``(out_features, in_features)``.
    ``group_size`` is stored but not used: quantization is per-tensor.
    """

    def __init__(self, bits: int = 4, group_size: int = 128):
        """Store the configuration.

        Raises:
            ValueError: If ``bits`` < 2 (the symmetric grid would have no
                positive level and the scale would be undefined).
        """
        if bits < 2:
            raise ValueError(f"bits must be >= 2, got {bits}")
        self.bits = bits
        self.group_size = group_size

    def find_best_scale(self, weight: torch.Tensor,
                        activation: torch.Tensor) -> torch.Tensor:
        """Search the per-input-channel scale minimizing the output error.

        Candidates are ``mean(|activation|) ** ratio`` for 20 ratios in
        ``[0, 1)``, normalized so that ``max * min == 1``. Ratio 0 is plain
        round-to-nearest, so the result is never worse than no scaling on
        the given activations.

        Args:
            weight: ``(out_features, in_features)`` tensor.
            activation: Layer inputs of shape ``(..., in_features)``.

        Returns:
            Scale tensor of shape ``(in_features,)``.

        Raises:
            ValueError: If the channel dimensions do not match.
        """
        in_features = weight.shape[1]
        inputs = activation.detach().reshape(-1, activation.shape[-1])
        inputs = inputs.to(device=weight.device, dtype=weight.dtype)
        if inputs.shape[1] != in_features:
            raise ValueError(
                f"activation has {inputs.shape[1]} features, "
                f"weight expects {in_features}"
            )

        act_scale = inputs.abs().mean(dim=0)
        reference = inputs @ weight.T

        best_loss = float('inf')
        # Identity scaling is the fallback if no candidate yields a finite loss.
        best_scale = torch.ones(in_features, dtype=weight.dtype, device=weight.device)

        n_grid = 20
        for step in range(n_grid):
            ratio = step / n_grid
            scale = act_scale.pow(ratio).clamp(min=1e-4)
            scale = scale / (scale.max() * scale.min()).sqrt()

            q_weight = self._quantize(weight * scale) / scale
            loss = ((inputs @ q_weight.T - reference) ** 2).mean().item()

            if loss < best_loss:
                best_loss = loss
                best_scale = scale

        return best_scale

    def quantize_layer(self, weight: torch.Tensor,
                       activation: torch.Tensor) -> torch.Tensor:
        """Return the fake-quantized weight in the original scale.

        The weight columns are multiplied by the best scale, quantized, and
        divided by the same scale again, so the result is a drop-in
        replacement for ``weight``.
        """
        scale = self.find_best_scale(weight, activation)

        scaled_weight = weight * scale
        quantized = self._quantize(scaled_weight)

        return quantized / scale

    def _quantize(self, tensor: torch.Tensor) -> torch.Tensor:
        """Symmetric per-tensor round-to-nearest onto the ``bits``-bit grid.

        Returns the dequantized float tensor.
        """
        qmax = 2 ** (self.bits - 1) - 1
        qmin = -(2 ** (self.bits - 1))
        # Clamp keeps an all-zero tensor from producing 0/0 = NaN.
        scale = (tensor.abs().max() / qmax).clamp(min=1e-8)
        return torch.round(tensor / scale).clamp(qmin, qmax) * scale

量化效果对比

# Hard-coded comparison figures per quantization method: model size in GB,
# perplexity, and speed-up factor.
quantization_results = dict(
    FP16=dict(size_gb=14.0, perplexity=5.28, speedup=1.0),
    INT8_PTQ=dict(size_gb=7.0, perplexity=5.31, speedup=1.5),
    INT4_GPTQ=dict(size_gb=3.5, perplexity=5.45, speedup=2.0),
    INT4_AWQ=dict(size_gb=3.5, perplexity=5.40, speedup=2.1),
    INT4_QAT=dict(size_gb=3.5, perplexity=5.35, speedup=1.9),
)

# Print a left-aligned table: title, header row, then one row per method.
print("7B模型量化对比:")
print(f"{'方法':<12} {'大小(GB)':<10} {'困惑度':<10} {'加速比':<10}")
for method in quantization_results:
    result = quantization_results[method]
    print(f"{method:<12} {result['size_gb']:<10} {result['perplexity']:<10} {result['speedup']:<10}")

最佳实践