显存管理:LLM推理和训练中的显存优化技术
显存基础概念
显存组成分析
import gc
import subprocess
import time
from typing import Optional

import torch
class GPUMemoryAnalyzer:
    """Report memory statistics for one CUDA device.

    All sizes returned by this class are expressed in GiB (bytes / 1024**3).
    """

    def __init__(self, device_id: int = 0):
        """Remember which CUDA device index subsequent queries refer to."""
        self.device_id = device_id

    def get_memory_info(self) -> Optional[dict]:
        """Return the allocator statistics for ``self.device_id``.

        Returns:
            A dict with ``allocated_gb``, ``reserved_gb``, ``max_allocated_gb``
            and ``total_gb``, or ``None`` when CUDA is not available.
        """
        if not torch.cuda.is_available():
            return None

        gib = 1024**3
        return {
            'allocated_gb': torch.cuda.memory_allocated(self.device_id) / gib,
            'reserved_gb': torch.cuda.memory_reserved(self.device_id) / gib,
            'max_allocated_gb': torch.cuda.max_memory_allocated(self.device_id) / gib,
            'total_gb': self.get_total_memory(),
        }

    def get_total_memory(self) -> float:
        """Return the total memory of ``self.device_id`` in GiB.

        When CUDA is usable the value comes from PyTorch's device properties,
        so it always describes the same device the other statistics refer to.
        Otherwise ``nvidia-smi`` is queried as a best-effort fallback and
        ``0.0`` is returned if the tool is missing, fails, or prints something
        that cannot be parsed.
        """
        if torch.cuda.is_available():
            props = torch.cuda.get_device_properties(self.device_id)
            return props.total_memory / 1024**3

        try:
            result = subprocess.run(
                ['nvidia-smi', '--query-gpu=memory.total',
                 '--format=csv,noheader,nounits'],
                capture_output=True, text=True,
            )
        except OSError:
            # nvidia-smi is not installed or not executable.
            return 0.0
        if result.returncode != 0:
            return 0.0
        # nvidia-smi reports MiB; convert to GiB.
        return self._parse_total_memory_mib(result.stdout, self.device_id) / 1024

    @staticmethod
    def _parse_total_memory_mib(stdout: str, device_id: int) -> float:
        """Extract the ``device_id``-th memory value (MiB) from nvidia-smi CSV.

        Accepts output with or without the CSV header row and with or without
        a unit suffix. Returns ``0.0`` when the row is missing or not numeric.
        """
        rows = [line.strip() for line in stdout.splitlines() if line.strip()]
        # A header row such as "memory.total [MiB]" does not start with a digit.
        if rows and not rows[0][0].isdigit():
            rows = rows[1:]
        if not 0 <= device_id < len(rows):
            return 0.0
        try:
            return float(rows[device_id].split()[0])
        except ValueError:
            return 0.0

    def profile_model(self, model, input_shape: tuple, vocab_size: int = 1000):
        """Run one no-grad forward pass and report the memory it needed.

        Args:
            model: Callable taking a tensor of integer token ids.
            input_shape: Shape of the random integer input to generate.
            vocab_size: Exclusive upper bound for the random token ids.

        Returns:
            A dict with ``model_size_gb`` (bytes of all parameters),
            ``forward_pass_gb`` (memory allocated on the device right after the
            forward pass, which includes everything already resident there) and
            ``peak_gb`` (peak allocation observed since the stats were reset).
        """
        torch.cuda.reset_peak_memory_stats(self.device_id)

        # Build the dummy input and move it to the profiled device.
        dummy_input = torch.randint(0, vocab_size, input_shape).to(f'cuda:{self.device_id}')

        with torch.no_grad():
            output = model(dummy_input)

        # Read both counters while the output is still alive.
        memory_after_forward = torch.cuda.memory_allocated(self.device_id)
        peak_memory = torch.cuda.max_memory_allocated(self.device_id)

        # Drop the temporaries and hand cached blocks back to the driver.
        del output, dummy_input
        torch.cuda.empty_cache()

        gib = 1024**3
        param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
        return {
            'model_size_gb': param_bytes / gib,
            'forward_pass_gb': memory_after_forward / gib,
            'peak_gb': peak_memory / gib,
        }
显存占用分析
class MemoryBreakdown:
    """Estimate how much memory each component of an LLM occupies.

    ``model_config`` must expose the attributes ``num_params``,
    ``num_layers``, ``num_heads``, ``head_dim`` and ``hidden_size``.
    """

    def __init__(self, model_config):
        """Store the model configuration the estimates are derived from."""
        self.config = model_config

    def calculate_memory_usage(self, batch_size: int = 1, seq_length: int = 2048) -> dict:
        """Return a per-component memory estimate in GiB.

        Args:
            batch_size: Number of sequences processed together.
            seq_length: Number of tokens per sequence.

        Returns:
            A dict with ``parameters_gb``, ``gradients_gb``, ``kv_cache_gb``,
            ``activations_gb``, ``optimizer_gb``, ``total_training_gb`` and
            ``total_inference_gb``.
        """
        cfg = self.config
        fp16_bytes = 2
        fp32_bytes = 4
        gib = 1024**3

        # Model weights and their gradients, both held in FP16.
        param_memory = cfg.num_params * fp16_bytes
        gradient_memory = cfg.num_params * fp16_bytes

        # KV cache: one K and one V tensor per layer, stored in FP16.
        kv_cache_memory = (
            2
            * cfg.num_layers
            * cfg.num_heads
            * cfg.head_dim
            * batch_size
            * seq_length
            * fp16_bytes
        )

        # Rough activation estimate: one FP32 hidden state per token per layer
        # kept around for the backward pass.
        activation_memory = (
            batch_size * seq_length * cfg.hidden_size * cfg.num_layers * fp32_bytes
        )

        # Adam keeps two FP32 moment tensors per parameter, so the size must be
        # derived from the parameter count, not from the FP16 weight bytes.
        optimizer_memory = cfg.num_params * fp32_bytes * 2

        # NOTE: the KV cache term is kept in the training total as in the
        # previous estimate, which makes it a conservative upper bound.
        total_training = (
            param_memory
            + gradient_memory
            + kv_cache_memory
            + activation_memory
            + optimizer_memory
        )
        total_inference = param_memory + kv_cache_memory

        return {
            'parameters_gb': param_memory / gib,
            'gradients_gb': gradient_memory / gib,
            'kv_cache_gb': kv_cache_memory / gib,
            'activations_gb': activation_memory / gib,
            'optimizer_gb': optimizer_memory / gib,
            'total_training_gb': total_training / gib,
            'total_inference_gb': total_inference / gib,
        }
推理显存优化
量化减少显存
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
class InferenceOptimizer:
    """Load a causal LM at different precisions and compare weight sizes."""

    def __init__(self, model_name: str):
        """Store the model identifier passed to ``from_pretrained``."""
        self.model_name = model_name

    def load_int8(self):
        """Load the model with 8-bit bitsandbytes quantization."""
        config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=6.0,
        )
        return AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=config,
            device_map="auto",
        )

    def load_int4(self):
        """Load the model with 4-bit NF4 quantization and double quantization."""
        config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        return AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=config,
            device_map="auto",
        )

    def compare_memory_usage(self):
        """Measure the parameter footprint (GiB) at FP16, INT8 and INT4.

        Each variant is loaded, measured and fully released before the next
        one is loaded, so at most one copy of the model is resident at a time.

        Returns:
            A dict with the keys ``fp16``, ``int8`` and ``int4``.
        """
        loaders = {
            'fp16': lambda: AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
            ),
            'int8': self.load_int8,
            'int4': self.load_int4,
        }
        return {
            precision: self._measure_and_release(loader)
            for precision, loader in loaders.items()
        }

    def _measure_and_release(self, loader) -> float:
        """Load a model via ``loader``, measure it, then free its memory."""
        model = loader()
        try:
            return self.get_model_memory(model)
        finally:
            # Dropping the reference alone may leave memory held by reference
            # cycles and by the CUDA caching allocator.
            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def get_model_memory(self, model) -> float:
        """Return the total size of ``model``'s parameters in GiB."""
        total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
        return total_bytes / 1024**3
KV Cache优化
class KVCacheManager:
    """Pre-allocated key/value cache with per-slot usage tracking.

    The cache tensor has shape
    ``(2, num_layers, max_batch_size, num_heads, max_seq_length, head_dim)``,
    where index 0 of the first axis holds keys and index 1 holds values.
    """

    def __init__(self, max_batch_size: int, max_seq_length: int,
                 num_layers: int, num_heads: int, head_dim: int,
                 device: str = 'cuda', dtype: torch.dtype = torch.float16):
        """Store the cache geometry and allocate the cache immediately.

        Args:
            max_batch_size: Number of sequence slots in the cache.
            max_seq_length: Number of token positions per slot.
            num_layers: Number of layers that cache K/V.
            num_heads: Number of attention heads per layer.
            head_dim: Size of each head.
            device: Device the cache tensor is allocated on.
            dtype: Element type of the cache tensor.
        """
        self.max_batch_size = max_batch_size
        self.max_seq_length = max_seq_length
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype
        self.cache = None
        self._filled = None
        self.allocate_cache()

    def allocate_cache(self):
        """Allocate a zero-filled cache and reset the usage bookkeeping."""
        cache_shape = (
            2,  # keys and values
            self.num_layers,
            self.max_batch_size,
            self.num_heads,
            self.max_seq_length,
            self.head_dim,
        )
        self.cache = torch.zeros(cache_shape, dtype=self.dtype, device=self.device)
        # Small host-side mask of which (slot, position) pairs hold data, so
        # usage can be reported without scanning the whole cache tensor.
        self._filled = torch.zeros(
            (self.max_batch_size, self.max_seq_length), dtype=torch.bool
        )
        cache_bytes = self.cache.numel() * self.cache.element_size()
        print(f"KV Cache allocated: {cache_bytes / 1024**3:.2f} GB")

    def update_cache(self, batch_idx: int, seq_idx: int,
                     key_states: torch.Tensor, value_states: torch.Tensor):
        """Write keys and values for one position of one slot.

        ``key_states`` and ``value_states`` are assigned to the
        ``(num_layers, num_heads, head_dim)`` slice selected by ``batch_idx``
        and ``seq_idx``.
        """
        self.cache[0, :, batch_idx, :, seq_idx, :] = key_states
        self.cache[1, :, batch_idx, :, seq_idx, :] = value_states
        # Record only after both writes succeeded.
        self._filled[batch_idx, seq_idx] = True

    def trim_cache(self, batch_idx: int, new_length: int):
        """Zero every position from ``new_length`` onward for one slot."""
        self.cache[:, :, batch_idx, :, new_length:, :] = 0
        self._filled[batch_idx, new_length:] = False

    def get_memory_usage(self) -> dict:
        """Return total bytes, bytes holding cached entries, and their ratio.

        Usage is derived from the positions written through ``update_cache``
        and not yet removed by ``trim_cache``; it does not depend on the
        cached values themselves, so entries that happen to be zero still
        count as used.
        """
        element_size = self.cache.element_size()
        total = self.cache.numel() * element_size
        elements_per_position = 2 * self.num_layers * self.num_heads * self.head_dim
        used = int(self._filled.sum().item()) * elements_per_position * element_size
        return {
            'total_bytes': total,
            'used_bytes': used,
            'utilization': used / total if total > 0 else 0,
        }
训练显存优化
梯度检查点
from torch.utils.checkpoint import checkpoint
class GradientCheckpointing:
    """Apply activation checkpointing to every N-th layer of a model.

    The model must expose an indexable, iterable ``layers`` attribute.
    """

    def __init__(self, model, checkpoint_interval: int = 4):
        """Store the model and how often a layer is checkpointed.

        Args:
            model: Object with a ``layers`` sequence.
            checkpoint_interval: Checkpoint layers whose index is a multiple
                of this value.

        Raises:
            ValueError: If ``checkpoint_interval`` is smaller than 1.
        """
        if checkpoint_interval < 1:
            raise ValueError("checkpoint_interval must be >= 1")
        self.model = model
        self.checkpoint_interval = checkpoint_interval

    def apply_checkpointing(self):
        """Wrap every ``checkpoint_interval``-th layer with checkpointing."""
        for index, layer in enumerate(self.model.layers):
            if index % self.checkpoint_interval == 0:
                self.model.layers[index] = self.wrap_with_checkpoint(layer)

    def wrap_with_checkpoint(self, layer):
        """Return a checkpointed version of ``layer``.

        For an ``nn.Module`` the module itself is returned with its
        ``forward`` routed through ``checkpoint``. This keeps it a valid
        member of containers such as ``nn.ModuleList`` and keeps its
        parameters registered on the model. Any other callable is wrapped in
        a plain function. Layers that were already wrapped are returned
        unchanged.
        """
        if getattr(layer, '_checkpoint_wrapped', False):
            return layer

        if isinstance(layer, torch.nn.Module):
            # Capture the bound forward before shadowing it on the instance.
            original_forward = layer.forward

            def checkpointed_forward(*args, **kwargs):
                return checkpoint(original_forward, *args, use_reentrant=False, **kwargs)

            layer.forward = checkpointed_forward
            layer._checkpoint_wrapped = True
            return layer

        def wrapped(*args, **kwargs):
            return checkpoint(layer, *args, use_reentrant=False, **kwargs)

        wrapped._checkpoint_wrapped = True
        return wrapped

    def estimate_savings(self, original_memory: float) -> dict:
        """Return a rule-of-thumb estimate of the checkpointing trade-off.

        Assumes checkpointing saves 65% of ``original_memory`` (activation
        memory) at the cost of 30% extra compute.
        """
        activation_savings = original_memory * 0.65
        compute_overhead = 0.3
        return {
            'memory_saved_gb': activation_savings,
            'compute_overhead': compute_overhead,
        }
混合精度训练
class MixedPrecisionTrainer:
    """Run training steps under CUDA automatic mixed precision."""

    def __init__(self, model, optimizer):
        """Store the model and optimizer and create the gradient scaler."""
        self.model = model
        self.optimizer = optimizer
        self.scaler = torch.cuda.amp.GradScaler()

    def train_step(self, batch):
        """Run one optimization step and return the loss as a Python float.

        Args:
            batch: Mapping with ``'input_ids'`` (model input) and ``'labels'``.
        """
        self.optimizer.zero_grad()

        # Forward pass and loss under autocast.
        with torch.cuda.amp.autocast():
            outputs = self.model(batch['input_ids'])
            loss = self.compute_loss(outputs, batch['labels'])

        # Scale the loss so small gradients do not underflow in half precision.
        self.scaler.scale(loss).backward()

        # The scaler unscales the gradients before stepping the optimizer and
        # then adjusts the scale factor for the next step.
        self.scaler.step(self.optimizer)
        self.scaler.update()
        return loss.item()

    def compute_loss(self, outputs, labels):
        """Return the cross-entropy loss between ``outputs`` and ``labels``.

        This is the default loss used by ``train_step``; subclasses may
        override it. ``outputs`` may be a logits tensor or any object with a
        ``logits`` attribute. The last logits dimension is treated as the
        class dimension and all leading dimensions are flattened.
        """
        logits = getattr(outputs, 'logits', outputs)
        return torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)).float(),
            labels.reshape(-1),
        )

    def estimate_memory_savings(self) -> dict:
        """Return nominal FP16-vs-FP32 savings as percentage strings.

        The figures assume parameters and gradients are stored in FP16 (half
        the size of FP32) while optimizer state stays in FP32 and therefore
        contributes no savings.
        """
        param_savings = 0.5
        gradient_savings = 0.5
        return {
            'parameter_savings': f"{param_savings * 100}%",
            'gradient_savings': f"{gradient_savings * 100}%",
            'overall_savings': f"{(param_savings + gradient_savings) / 2 * 100}%"
        }
显存分片(ZeRO)
class ZeROPartitioner:
    """Estimate per-GPU memory under the ZeRO sharding stages.

    Parameter counts are reported as numbers of elements. Optimizer and
    gradient figures are reported in bytes.
    """

    # FP16 weights and gradients take 2 bytes per parameter.
    _FP16_BYTES = 2
    # Mixed-precision Adam keeps an FP32 weight copy, momentum and variance:
    # 4 + 4 + 4 bytes per parameter.
    _OPTIMIZER_BYTES_PER_PARAM = 12

    def __init__(self, model, num_gpus: int):
        """Store the model and the number of GPUs to shard across.

        Raises:
            ValueError: If ``num_gpus`` is smaller than 1.
        """
        if num_gpus < 1:
            raise ValueError("num_gpus must be >= 1")
        self.model = model
        self.num_gpus = num_gpus

    def _total_params(self) -> int:
        """Return the number of parameter elements in the model."""
        return sum(p.numel() for p in self.model.parameters())

    def _shard(self, count: int) -> int:
        """Return the largest per-GPU share of ``count`` (ceiling division)."""
        return -(-count // self.num_gpus)

    def partition_parameters(self) -> dict:
        """Return the parameter count per GPU and its FP16 size in GiB.

        The per-GPU count is rounded up so that it describes the most loaded
        GPU when the total is not evenly divisible.
        """
        total_params = self._total_params()
        params_per_gpu = self._shard(total_params)
        return {
            'total_parameters': total_params,
            'parameters_per_gpu': params_per_gpu,
            'memory_per_gpu_gb': params_per_gpu * self._FP16_BYTES / 1024**3,
        }

    def get_zeRO_stages(self) -> dict:
        """Return the per-GPU footprint for ZeRO stages 0 through 3.

        Every stage entry has the keys ``name``, ``params_per_gpu`` (element
        count), ``optimizer_per_gpu`` (bytes) and ``gradients_per_gpu``
        (bytes).
        """
        total_params = self._total_params()
        shard_params = self._shard(total_params)

        full_optimizer = total_params * self._OPTIMIZER_BYTES_PER_PARAM
        shard_optimizer = shard_params * self._OPTIMIZER_BYTES_PER_PARAM
        full_gradients = total_params * self._FP16_BYTES
        shard_gradients = shard_params * self._FP16_BYTES

        return {
            'stage_0': {
                'name': '无分片',
                'params_per_gpu': total_params,
                'optimizer_per_gpu': full_optimizer,
                'gradients_per_gpu': full_gradients,
            },
            'stage_1': {
                'name': '优化器状态分片',
                'params_per_gpu': total_params,
                'optimizer_per_gpu': shard_optimizer,
                'gradients_per_gpu': full_gradients,
            },
            'stage_2': {
                'name': '优化器+梯度分片',
                'params_per_gpu': total_params,
                'optimizer_per_gpu': shard_optimizer,
                'gradients_per_gpu': shard_gradients,
            },
            'stage_3': {
                'name': '全部分片',
                'params_per_gpu': shard_params,
                'optimizer_per_gpu': shard_optimizer,
                'gradients_per_gpu': shard_gradients,
            },
        }
显存监控工具
class MemoryMonitor:
    """Record GPU memory snapshots and look for growth between them."""

    def __init__(self):
        """Start with an empty snapshot history."""
        self.snapshots = []

    def take_snapshot(self, label: str = ""):
        """Record and return the current memory statistics in GiB.

        On a host without CUDA the memory fields are recorded as ``0.0`` so
        the monitor can stay in place in code that also runs on CPU.

        Args:
            label: Free-form tag stored with the snapshot.
        """
        gib = 1024**3
        if torch.cuda.is_available():
            # Wait for queued kernels so the numbers reflect finished work.
            torch.cuda.synchronize()
            allocated = torch.cuda.memory_allocated() / gib
            reserved = torch.cuda.memory_reserved() / gib
            max_allocated = torch.cuda.max_memory_allocated() / gib
        else:
            allocated = reserved = max_allocated = 0.0

        snapshot = {
            'label': label,
            'timestamp': time.time(),
            'allocated': allocated,
            'reserved': reserved,
            'max_allocated': max_allocated,
        }
        self.snapshots.append(snapshot)
        return snapshot

    def compare_snapshots(self, idx1: int, idx2: int) -> dict:
        """Return how snapshot ``idx2`` differs from snapshot ``idx1``.

        Raises:
            IndexError: If either index is outside the recorded history.
        """
        earlier = self.snapshots[idx1]
        later = self.snapshots[idx2]
        return {
            'allocated_diff': later['allocated'] - earlier['allocated'],
            'reserved_diff': later['reserved'] - earlier['reserved'],
            'time_diff': later['timestamp'] - earlier['timestamp'],
        }

    def detect_leaks(self, threshold_gb: float = 0.1) -> list:
        """Return consecutive snapshot pairs whose allocation grew too much.

        Args:
            threshold_gb: Growth in allocated memory (GiB) that must be
                exceeded for a pair to be reported.

        Returns:
            One dict per flagged pair; ``between`` is the index of the earlier
            snapshot of the pair.
        """
        leaks = []
        for later_idx in range(1, len(self.snapshots)):
            growth = self.compare_snapshots(later_idx - 1, later_idx)['allocated_diff']
            if growth > threshold_gb:
                leaks.append({
                    'between': later_idx - 1,
                    'growth_gb': growth,
                    'possible_cause': '显存泄漏',
                })
        return leaks
最佳实践
- 提前规划:计算好显存需求再开始训练
- 渐进加载:大模型使用 `device_map="auto"` 自动分配
- 及时释放:不再需要的张量及时删除
- 批量分析:使用 `nvidia-smi` 和 PyTorch 工具监控
- 分片训练:多GPU使用ZeRO分片
- 量化推理:推理时使用INT8/INT4量化