← 返回首页
🧠

显存管理:LLM推理和训练中的显存优化技术

📂 llm ⏱ 5 min 835 words

显存管理:LLM推理和训练中的显存优化技术

显存基础概念

显存组成分析

import torch
import subprocess

class GPUMemoryAnalyzer:
    """Report and profile GPU memory usage for a single CUDA device.

    All sizes are returned in GiB (bytes / 1024**3).
    """

    def __init__(self, device_id: int = 0):
        self.device_id = device_id

    def get_memory_info(self) -> "dict | None":
        """Return the current allocator statistics for ``self.device_id``.

        Returns:
            A dict with ``allocated_gb``, ``reserved_gb``, ``max_allocated_gb``
            and ``total_gb``, or ``None`` when CUDA is not available.
        """
        if not torch.cuda.is_available():
            return None

        gib = 1024**3
        return {
            'allocated_gb': torch.cuda.memory_allocated(self.device_id) / gib,
            'reserved_gb': torch.cuda.memory_reserved(self.device_id) / gib,
            'max_allocated_gb': torch.cuda.max_memory_allocated(self.device_id) / gib,
            'total_gb': self.get_total_memory(),
        }

    def get_total_memory(self) -> float:
        """Return the total memory of ``self.device_id`` in GiB, or 0.0 if unknown.

        When PyTorch can see the GPU, the CUDA device properties are used, so
        the index matches the one used by ``torch.cuda.memory_*``. Otherwise
        this falls back to ``nvidia-smi``, restricted to this device via
        ``--id``. Without ``--id``, nvidia-smi lists every GPU and the first
        data line is always GPU 0.
        """
        if torch.cuda.is_available():
            props = torch.cuda.get_device_properties(self.device_id)
            return props.total_memory / 1024**3

        try:
            result = subprocess.run(
                ['nvidia-smi', f'--id={self.device_id}',
                 '--query-gpu=memory.total', '--format=csv,noheader,nounits'],
                capture_output=True, text=True,
            )
        except FileNotFoundError:
            # nvidia-smi is not installed on this machine.
            return 0.0

        return self._parse_nvidia_smi_mib(result.stdout)

    @staticmethod
    def _parse_nvidia_smi_mib(stdout: str) -> float:
        """Convert the first numeric line of nvidia-smi CSV output (MiB) to GiB.

        Accepts both ``nounits`` output (``"24576"``) and output with a header
        and units (``"memory.total [MiB]"`` / ``"24576 MiB"``). Returns 0.0 if
        no numeric line is found, e.g. on an nvidia-smi error message.
        """
        for line in stdout.splitlines():
            fields = line.split()
            if not fields:
                continue
            try:
                return float(fields[0].rstrip(',')) / 1024  # MiB -> GiB
            except ValueError:
                continue  # header or error text
        return 0.0

    def profile_model(self, model, input_shape: tuple, vocab_size: int = 1000) -> dict:
        """Run one no-grad forward pass on random token ids and measure memory.

        Args:
            model: module to profile; presumably already placed on
                ``cuda:{device_id}`` (TODO confirm with callers).
            input_shape: shape of the integer token-id tensor fed to ``model``.
            vocab_size: exclusive upper bound for the random token ids.

        Returns:
            ``model_size_gb``: parameter bytes.
            ``forward_pass_gb``: bytes allocated right after the forward pass,
            while the output is still alive.
            ``peak_forward_gb``: peak allocation since the stats were reset.
        """
        gib = 1024**3
        device = torch.device('cuda', self.device_id)
        torch.cuda.reset_peak_memory_stats(self.device_id)

        # Create the input directly on the device to avoid a host-side copy.
        dummy_input = torch.randint(0, vocab_size, input_shape, device=device)

        with torch.no_grad():
            output = model(dummy_input)

        memory_after_forward = torch.cuda.memory_allocated(self.device_id)
        peak_memory = torch.cuda.max_memory_allocated(self.device_id)

        del output, dummy_input
        torch.cuda.empty_cache()

        model_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
        return {
            'model_size_gb': model_bytes / gib,
            'forward_pass_gb': memory_after_forward / gib,
            'peak_forward_gb': peak_memory / gib,
        }

显存占用分析

class MemoryBreakdown:
    """Estimate the GPU memory taken by each component of an LLM.

    Assumes FP16 weights and gradients, and a mixed-precision Adam optimizer.
    ``model_config`` must expose ``num_params``, ``num_layers``, ``num_heads``,
    ``head_dim`` and ``hidden_size``.
    """

    def __init__(self, model_config):
        self.config = model_config

    def calculate_memory_usage(self, batch_size: int = 1, seq_length: int = 2048) -> dict:
        """Return a per-component memory estimate in GiB.

        Training total = weights + gradients + optimizer states + activations.
        Inference total = weights + KV cache. The KV cache is kept only during
        autoregressive inference, so it is not part of the training total.
        """
        cfg = self.config
        gib = 1024**3

        # Model weights in FP16: 2 bytes per parameter.
        param_memory = cfg.num_params * 2

        # KV cache: K and V, for every layer/head/position, in FP16.
        kv_cache_memory = (
            2 *  # K and V
            cfg.num_layers *
            cfg.num_heads *
            cfg.head_dim *
            batch_size *
            seq_length *
            2  # FP16
        )

        # Activations (rough estimate): one hidden state per token per layer,
        # kept in FP32 for the backward pass.
        activation_memory = (
            batch_size *
            seq_length *
            cfg.hidden_size *
            cfg.num_layers *
            4
        )

        # Gradients share the weights' FP16 dtype: 2 bytes per parameter.
        gradient_memory = cfg.num_params * 2

        # Mixed-precision Adam keeps an FP32 master copy of the weights plus
        # FP32 momentum and variance: 4 + 4 + 4 = 12 bytes per parameter.
        optimizer_memory = cfg.num_params * 12

        training_total = param_memory + gradient_memory + optimizer_memory + activation_memory
        inference_total = param_memory + kv_cache_memory

        return {
            'parameters_gb': param_memory / gib,
            'kv_cache_gb': kv_cache_memory / gib,
            'activations_gb': activation_memory / gib,
            'gradients_gb': gradient_memory / gib,
            'optimizer_gb': optimizer_memory / gib,
            'total_training_gb': training_total / gib,
            'total_inference_gb': inference_total / gib,
        }

推理显存优化

量化减少显存

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

class InferenceOptimizer:
    """Load a causal LM at different precisions and compare weight memory."""

    def __init__(self, model_name: str):
        self.model_name = model_name

    def load_int8(self):
        """Load the model with bitsandbytes INT8 quantization (LLM.int8)."""
        config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=6.0
        )

        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=config,
            device_map="auto"
        )

        return model

    def load_int4(self):
        """Load the model with bitsandbytes 4-bit NF4 quantization.

        Uses double quantization and bfloat16 compute.
        """
        config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )

        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=config,
            device_map="auto"
        )

        return model

    def compare_memory_usage(self):
        """Return the parameter memory (GiB) for the FP16, INT8 and INT4 variants.

        The models are loaded one at a time. Each is released before the next
        load: ``del`` alone leaves the memory in PyTorch's CUDA cache, so
        without ``gc.collect()`` and ``torch.cuda.empty_cache()`` a later,
        larger load could run out of memory.
        """
        import gc

        loaders = {
            # Loaded without device_map, i.e. wherever from_pretrained places it
            # by default; only parameter bytes are measured.
            'fp16': lambda: AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16
            ),
            'int8': self.load_int8,
            'int4': self.load_int4,
        }

        results = {}
        for precision, load in loaders.items():
            model = load()
            try:
                results[precision] = self.get_model_memory(model)
            finally:
                del model
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        return results

    def get_model_memory(self, model) -> float:
        """Return the total size of ``model``'s parameters in GiB (buffers excluded)."""
        return sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**3

KV Cache优化

class KVCacheManager:
    """Pre-allocated, fixed-size KV cache with per-slot length tracking.

    The cache is one tensor of shape
    ``(2, num_layers, max_batch_size, num_heads, max_seq_length, head_dim)``.
    Index 0 of the first axis holds keys and index 1 holds values.
    """

    def __init__(self, max_batch_size: int, max_seq_length: int,
                 num_layers: int, num_heads: int, head_dim: int,
                 dtype: torch.dtype = torch.float16, device='cuda'):
        self.max_batch_size = max_batch_size
        self.max_seq_length = max_seq_length
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dtype = dtype
        self.device = device

        self.cache = None
        # High-water mark of positions written for each batch slot.
        self.seq_lengths = [0] * max_batch_size
        self.allocate_cache()

    def allocate_cache(self):
        """Allocate a zero-filled cache tensor and reset all slot lengths."""
        cache_shape = (
            2,  # K and V
            self.num_layers,
            self.max_batch_size,
            self.num_heads,
            self.max_seq_length,
            self.head_dim
        )

        self.cache = torch.zeros(
            cache_shape,
            dtype=self.dtype,
            device=self.device
        )
        self.seq_lengths = [0] * self.max_batch_size

        size_bytes = self.cache.numel() * self.cache.element_size()
        print(f"KV Cache allocated: {size_bytes / 1024**3:.2f} GB")

    def _bytes_per_token(self) -> int:
        """Return the bytes used by one position of one slot (K and V, all layers and heads)."""
        return 2 * self.num_layers * self.num_heads * self.head_dim * self.cache.element_size()

    def update_cache(self, batch_idx: int, seq_idx: int,
                     key_states: torch.Tensor, value_states: torch.Tensor):
        """Write K/V for a single position of one batch slot.

        ``key_states``/``value_states`` must be broadcastable to
        ``(num_layers, num_heads, head_dim)``, which is the shape of the
        indexed slice.
        """
        self.cache[0, :, batch_idx, :, seq_idx, :] = key_states
        self.cache[1, :, batch_idx, :, seq_idx, :] = value_states

        # Track the occupied length. A negative seq_idx counts from the end,
        # matching tensor indexing.
        position = seq_idx if seq_idx >= 0 else self.max_seq_length + seq_idx
        self.seq_lengths[batch_idx] = max(self.seq_lengths[batch_idx], position + 1)

    def trim_cache(self, batch_idx: int, new_length: int):
        """Zero every position from ``new_length`` onward for one batch slot.

        ``new_length`` is expected to lie in ``[0, max_seq_length]``.
        """
        self.cache[:, :, batch_idx, :, new_length:, :] = 0
        self.seq_lengths[batch_idx] = max(0, min(self.seq_lengths[batch_idx], new_length))

    def get_memory_usage(self) -> dict:
        """Return the allocated bytes, the bytes of written positions, and their ratio.

        Usage comes from the tracked slot lengths rather than from scanning
        for non-zero values. Scanning would allocate a temporary as large as
        the cache and would miss positions whose K/V happen to be zero.
        """
        total = self.cache.numel() * self.cache.element_size()
        used = sum(self.seq_lengths) * self._bytes_per_token()

        return {
            'total_bytes': total,
            'used_bytes': used,
            'utilization': used / total if total > 0 else 0
        }

训练显存优化

梯度检查点

from torch.utils.checkpoint import checkpoint

class GradientCheckpointing:
    """Apply activation (gradient) checkpointing to every N-th entry of ``model.layers``."""

    def __init__(self, model, checkpoint_interval: int = 4):
        self.model = model
        # Checkpoint layers whose index is a multiple of this (default: every 4th).
        self.checkpoint_interval = checkpoint_interval

    def apply_checkpointing(self):
        """Wrap every ``checkpoint_interval``-th layer so its forward is checkpointed.

        ``nn.Module`` layers are patched in place by replacing their instance
        ``forward``. ``nn.ModuleList`` rejects plain functions as entries, and
        patching keeps the module, its parameters and its ``state_dict`` keys
        registered. Non-module callables are replaced by a wrapping function.
        Calling this twice wraps the same layers twice.
        """
        for i, layer in enumerate(self.model.layers):
            if i % self.checkpoint_interval != 0:
                continue
            if isinstance(layer, torch.nn.Module):
                layer.forward = self.wrap_with_checkpoint(layer.forward)
            else:
                self.model.layers[i] = self.wrap_with_checkpoint(layer)

    def wrap_with_checkpoint(self, layer):
        """Return a function that runs ``layer`` under non-reentrant checkpointing."""
        def wrapped(*args, **kwargs):
            return checkpoint(layer, *args, use_reentrant=False, **kwargs)
        return wrapped

    def estimate_savings(self, original_memory: float) -> dict:
        """Return a rough, fixed-ratio estimate of memory saved and compute overhead.

        The ratios are rules of thumb and do not depend on
        ``checkpoint_interval``.
        """
        # Assume checkpointing saves about 60-70% of activation memory...
        activation_savings = original_memory * 0.65

        # ...at the cost of roughly 30% extra compute for recomputation.
        compute_overhead = 0.3

        return {
            'memory_saved_gb': activation_savings,
            'compute_overhead': compute_overhead
        }

混合精度训练

class MixedPrecisionTrainer:
    """Train a model with CUDA automatic mixed precision (autocast + GradScaler)."""

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.scaler = torch.cuda.amp.GradScaler()

    def compute_loss(self, outputs, labels):
        """Default loss: token-level cross-entropy over flattened logits.

        ``outputs`` may be a logits tensor or an object with a ``.logits``
        attribute. ``labels`` are assumed to already be aligned with the logits
        (e.g. shifted by the caller for causal LM). Positions labelled -100 are
        ignored. Override this method for any other objective.
        """
        logits = getattr(outputs, 'logits', outputs)
        return torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)).float(),
            labels.reshape(-1),
        )

    def train_step(self, batch):
        """Run one mixed-precision optimization step and return the loss value.

        ``batch`` must contain ``'input_ids'`` and ``'labels'``.
        """
        self.optimizer.zero_grad()

        # Eligible CUDA ops run in reduced precision inside this context.
        with torch.autocast(device_type='cuda'):
            outputs = self.model(batch['input_ids'])
            loss = self.compute_loss(outputs, batch['labels'])

        # Scale the loss so that small FP16 gradients do not underflow.
        self.scaler.scale(loss).backward()

        # step() unscales the gradients and skips the update on inf/NaN;
        # update() adjusts the scale factor for the next iteration.
        self.scaler.step(self.optimizer)
        self.scaler.update()

        return loss.item()

    def estimate_memory_savings(self) -> dict:
        """Return nominal savings figures for storing parameters/gradients in FP16.

        NOTE(review): these figures assume parameters and gradients are stored
        in FP16. With ``torch.autocast`` as used in ``train_step``, weights and
        their gradients stay FP32, and the saving comes mainly from activations.
        """
        # An FP16 parameter takes half the bytes of an FP32 one.
        param_savings = 0.5

        # Likewise for gradients stored in FP16.
        gradient_savings = 0.5

        return {
            'parameter_savings': f"{param_savings * 100}%",
            'gradient_savings': f"{gradient_savings * 100}%",
            'overall_savings': f"{(param_savings + gradient_savings) / 2 * 100}%"
        }

显存分片(ZeRO)

class ZeROPartitioner:
    """Estimate per-GPU memory under the ZeRO sharding stages.

    Assumes FP16 parameters/gradients (2 bytes each) and mixed-precision Adam,
    whose optimizer state is an FP32 parameter copy + momentum + variance
    (4 + 4 + 4 = 12 bytes per parameter).
    """

    def __init__(self, model, num_gpus: int):
        self.model = model
        self.num_gpus = num_gpus

    @staticmethod
    def _ceil_div(numerator: int, denominator: int) -> int:
        """Integer ceiling division: the largest shard when splitting evenly."""
        return -(-numerator // denominator)

    def partition_parameters(self) -> dict:
        """Split the parameters evenly across GPUs.

        Per-GPU figures use the largest shard (ceiling division), so remainders
        are not dropped.
        """
        total_params = sum(p.numel() for p in self.model.parameters())
        params_per_gpu = self._ceil_div(total_params, self.num_gpus)

        return {
            'total_parameters': total_params,
            'parameters_per_gpu': params_per_gpu,
            'memory_per_gpu_gb': params_per_gpu * 2 / 1024**3  # FP16
        }

    def get_zeRO_stages(self) -> dict:
        """Return the per-GPU footprint of each ZeRO stage.

        ``params_per_gpu`` is a parameter COUNT. ``optimizer_per_gpu`` and
        ``gradients_per_gpu`` are BYTES.
        """
        total_params = sum(p.numel() for p in self.model.parameters())
        n = self.num_gpus

        optimizer_bytes = total_params * 12  # FP32 copy + momentum + variance
        gradient_bytes = total_params * 2  # FP16 gradients

        optimizer_sharded = self._ceil_div(optimizer_bytes, n)
        gradient_sharded = self._ceil_div(gradient_bytes, n)

        return {
            'stage_0': {
                'name': '无分片',
                'params_per_gpu': total_params,
                'optimizer_per_gpu': optimizer_bytes,
                'gradients_per_gpu': gradient_bytes
            },
            'stage_1': {
                'name': '优化器状态分片',
                'params_per_gpu': total_params,
                'optimizer_per_gpu': optimizer_sharded,
                'gradients_per_gpu': gradient_bytes
            },
            'stage_2': {
                'name': '优化器+梯度分片',
                'params_per_gpu': total_params,
                'optimizer_per_gpu': optimizer_sharded,
                'gradients_per_gpu': gradient_sharded
            },
            'stage_3': {
                'name': '全部分片',
                'params_per_gpu': self._ceil_div(total_params, n),
                'optimizer_per_gpu': optimizer_sharded,
                'gradients_per_gpu': gradient_sharded
            }
        }

显存监控工具

class MemoryMonitor:
    """Record CUDA memory snapshots over time and flag suspicious growth.

    All memory figures are in GiB.
    """

    def __init__(self):
        self.snapshots = []

    def take_snapshot(self, label: str = ""):
        """Record and return the current allocated/reserved/peak memory.

        Pending CUDA work is synchronized first, when CUDA is available, so the
        figures reflect completed kernels.
        """
        import time  # local import keeps this snippet self-contained

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        gib = 1024**3
        snapshot = {
            'label': label,
            'timestamp': time.time(),
            'allocated': torch.cuda.memory_allocated() / gib,
            'reserved': torch.cuda.memory_reserved() / gib,
            'max_allocated': torch.cuda.max_memory_allocated() / gib
        }

        self.snapshots.append(snapshot)
        return snapshot

    def compare_snapshots(self, idx1: int, idx2: int) -> dict:
        """Return the allocated, reserved and time differences (snapshot idx2 minus idx1)."""
        s1 = self.snapshots[idx1]
        s2 = self.snapshots[idx2]

        return {
            'allocated_diff': s2['allocated'] - s1['allocated'],
            'reserved_diff': s2['reserved'] - s1['reserved'],
            'time_diff': s2['timestamp'] - s1['timestamp']
        }

    def detect_leaks(self, threshold_gb: float = 0.1) -> list:
        """Flag consecutive snapshots whose allocated memory grew by more than ``threshold_gb``.

        Each entry's ``between`` is the index of the earlier snapshot of the pair.
        """
        leaks = []

        for i in range(1, len(self.snapshots)):
            diff = self.compare_snapshots(i - 1, i)

            if diff['allocated_diff'] > threshold_gb:  # default: more than ~100 MB
                leaks.append({
                    'between': i - 1,
                    'growth_gb': diff['allocated_diff'],
                    'possible_cause': '显存泄漏'
                })

        return leaks

最佳实践

  1. 提前规划:计算好显存需求再开始训练
  2. 渐进加载:大模型使用device_map="auto"自动分配
  3. 及时释放:及时删除不再需要的张量(del);如需将缓存显存归还给驱动或其他进程,再调用torch.cuda.empty_cache()
  4. 持续监控:使用nvidia-smi和PyTorch工具(如torch.cuda.memory_allocated、max_memory_allocated)监控显存变化
  5. 分片训练:多GPU使用ZeRO分片
  6. 量化推理:推理时使用INT8/INT4量化