← 返回首页
🧠

多GPU推理

📂 llm ⏱ 2 min 356 words

---
title: "多GPU推理"
description: "全面讲解多GPU推理技术,涵盖张量并行、流水线并行和NVLink高速互联,实现大模型高效分布式推理。"
tags: ["多GPU", "张量并行", "流水线并行", "NVLink"]
category: "llm"
icon: "🧠"
---

多GPU推理

为什么需要多GPU

单张GPU显存有限,无法容纳超大模型。即使模型能放入单卡,多GPU也能通过并行计算大幅提升吞吐量。选择哪种并行策略取决于模型规模、GPU互联带宽和延迟要求。

并行策略概览

张量并行(Tensor Parallelism)

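将单个层的权重矩阵切分到多张GPU上,每张卡只计算部分结果,再通过集合通信合并:按输出维度(列)切分时,各卡得到互不重叠的输出片段,需要用AllGather拼接;按输入维度(行)切分时,各卡得到完整输出的部分和,需要用AllReduce求和。下面的示例按输出维度切分: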

import torch
import torch.distributed as dist

def tensor_parallel_linear(x, weight, world_size, rank):
    """Column-parallel linear layer: shard ``weight`` by output features.

    Each rank multiplies ``x`` by its own slice of output rows of ``weight``
    and the partial outputs are concatenated across ranks with AllGather.

    Splitting along the output dimension gives every rank a *disjoint* set
    of output columns, so they must be gathered (concatenated), not summed.
    AllReduce(SUM) is the right collective only for row-parallel layers,
    where the *input* dimension is split and each rank holds a partial sum
    of the full output.

    Args:
        x: 2-D input of shape ``(batch, in_features)`` (``torch.mm`` needs 2-D).
        weight: Full weight of shape ``(out_features, in_features)``; rows are
            output features, as implied by ``x @ weight.T``.
        world_size: Number of tensor-parallel ranks.
        rank: This process's rank in the default process group.

    Returns:
        Output of shape ``(batch, out_features)``, identical on every rank.

    Raises:
        ValueError: If ``world_size < 1`` or ``out_features`` is not divisible
            by ``world_size`` (uneven shards would silently drop rows and
            break all_gather, which needs equal-sized tensors).
    """
    if world_size < 1:
        raise ValueError(f"world_size must be >= 1, got {world_size}")
    out_features = weight.shape[0]
    if out_features % world_size != 0:
        raise ValueError(
            f"out_features={out_features} is not divisible by "
            f"world_size={world_size}"
        )

    # Slice this rank's block of output rows.
    shard_size = out_features // world_size
    local_weight = weight[rank * shard_size:(rank + 1) * shard_size]

    # Each rank computes only its own slice of output columns.
    local_output = torch.mm(x, local_weight.T)

    # A single rank already holds the full output; no process group needed.
    if world_size == 1:
        return local_output

    # AllGather the disjoint column slices and stitch them back in rank order.
    gathered = [torch.empty_like(local_output) for _ in range(world_size)]
    dist.all_gather(gathered, local_output)
    return torch.cat(gathered, dim=-1)

流水线并行(Pipeline Parallelism)

将模型按层切分,不同层放在不同GPU上:

class PipelineStage(torch.nn.Module):
    """One stage of a pipeline-parallel model.

    Runs its slice of layers and exchanges activations with neighbouring
    stages through blocking point-to-point ``dist.recv`` / ``dist.send``.

    NOTE(review): ``stage_id`` is used directly as the peer rank, so this
    assumes stage ``i`` runs on global rank ``i`` — confirm for your launcher.
    """

    def __init__(self, layers, stage_id, num_stages):
        """Build the stage.

        Args:
            layers: The modules this stage runs, in order. Either a
                ``torch.nn.Module`` (e.g. ``nn.Sequential``/``nn.ModuleList``)
                or any iterable of modules.
            stage_id: Index of this stage in the pipeline (0-based).
            num_stages: Total number of pipeline stages.
        """
        super().__init__()
        # A plain Python list of modules is NOT registered as submodules: its
        # parameters would be missing from .parameters(), state_dict() and
        # .to(device). Wrapping in ModuleList registers them (and also turns a
        # one-shot iterator into something forward() can iterate every call).
        if not isinstance(layers, torch.nn.Module):
            layers = torch.nn.ModuleList(layers)
        self.layers = layers
        self.stage_id = stage_id
        self.num_stages = num_stages

    def forward(self, x):
        """Run this stage.

        Args:
            x: On the first stage, the real input. On later stages, a
                preallocated buffer whose shape/dtype match the previous
                stage's output; ``dist.recv`` overwrites it in place.

        Returns:
            This stage's output. Unless this is the last stage, the output is
            also sent to the next stage.
        """
        # Receive the activations produced by the previous stage.
        if self.stage_id > 0:
            dist.recv(x, src=self.stage_id - 1)

        for layer in self.layers:
            x = layer(x)

        # Hand the activations on to the next stage.
        if self.stage_id < self.num_stages - 1:
            dist.send(x, dst=self.stage_id + 1)

        return x

序列并行(Sequence Parallelism)

对长序列进行切分,处理超长上下文:

def sequence_parallel_attention(q, k, v, seq_world_size, rank, scale=None):
    """Sequence-parallel attention: each rank handles a slice of the queries.

    The queries are split along dim 2, which this code treats as the sequence
    dim. Every rank attends its local queries against the FULL ``k`` and
    ``v``; the per-rank outputs are then concatenated back along dim 2 with
    AllGather. Only Q is sharded, while K and V are used whole. No attention
    mask (e.g. causal) is applied.

    Args:
        q, k, v: Attention inputs. Dim 2 is sliced as the sequence dim and the
            last dim is the head dim; presumably ``(batch, heads, seq,
            head_dim)`` — TODO confirm with callers.
        seq_world_size: Number of ranks sharing the sequence.
        rank: This process's rank.
        scale: Multiplier for the attention logits. Defaults to
            ``1 / sqrt(head_dim)``, the standard scaled dot-product factor.

    Returns:
        Attention output for the whole sequence, identical on every rank.

    Raises:
        ValueError: If ``seq_world_size < 1`` or the sequence length is not
            divisible by ``seq_world_size``. Otherwise trailing tokens would be
            silently dropped, and all_gather needs equal-sized shards anyway.
    """
    if seq_world_size < 1:
        raise ValueError(f"seq_world_size must be >= 1, got {seq_world_size}")
    seq_len = q.shape[2]
    if seq_len % seq_world_size != 0:
        raise ValueError(
            f"sequence length {seq_len} is not divisible by "
            f"seq_world_size={seq_world_size}"
        )
    local_seq_len = seq_len // seq_world_size
    if scale is None:
        scale = q.shape[-1] ** -0.5

    # Only the queries are sharded; K/V stay complete on every rank.
    start = rank * local_seq_len
    local_q = q[:, :, start:start + local_seq_len]

    # Scaled dot-product attention for the local queries.
    scores = torch.matmul(local_q, k.transpose(-2, -1)) * scale
    attn = torch.softmax(scores, dim=-1)
    output = torch.matmul(attn, v)

    # A single rank already holds the full output; no process group needed.
    if seq_world_size == 1:
        return output

    # AllGather the per-rank query slices back into the full sequence.
    gathered = [torch.empty_like(output) for _ in range(seq_world_size)]
    dist.all_gather(gathered, output)
    return torch.cat(gathered, dim=2)

使用vLLM多GPU推理

from vllm import LLM, SamplingParams

# vLLM applies tensor parallelism automatically; tensor_parallel_size
# selects how many GPUs the model is sharded across.
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    max_model_len=4096,
    gpu_memory_utilization=0.9,
    tensor_parallel_size=4,  # shard over 4 GPUs
)

# Batched inference: all prompts are submitted in a single generate() call.
prompts = [
    "解释量子计算",
    "Python异步编程",
    "机器学习基础",
]
params = SamplingParams(max_tokens=512, temperature=0.7)
outputs = llm.generate(prompts, params)

# Print the first completion of every request.
for output in outputs:
    first_completion = output.outputs[0]
    print(first_completion.text)

使用DeepSpeed推理

import deepspeed
import torch

# NOTE: `model` (e.g. a Hugging Face causal LM) and `tokenizer` are assumed to
# have been loaded beforehand; this snippet does not define them.

# DeepSpeed inference config.
ds_config = {
    # Set the tensor-parallel degree here only. The legacy `mp_size` kwarg is a
    # deprecated alias of tensor_parallel.tp_size, and DeepSpeed refuses to
    # accept both together.
    "tensor_parallel": {"tp_size": 2},
    "dtype": "fp16",
    # DeepSpeed inference does not support CUDA graphs together with tensor
    # parallelism (tp_size > 1), so they must stay disabled here.
    "enable_cuda_graph": False,
    "replace_with_kernel_inject": True,
}

model = deepspeed.init_inference(
    model,
    config=ds_config,
)

# Inference. Move the inputs to this process's GPU so they are on the same
# device as the injected model; CPU inputs cause a device-mismatch error.
input_ids = tokenizer.encode("你好,请介绍一下自己", return_tensors="pt")
input_ids = input_ids.to(torch.cuda.current_device())
output = model.generate(input_ids, max_new_tokens=256)

GPU互联与通信

NVLink vs PCIe

def check_gpu_topology():
    """Print the GPU interconnect topology matrix from ``nvidia-smi topo -m``.

    Returns:
        The command's stdout on success. Returns ``None`` if ``nvidia-smi`` is
        not installed or exits with a nonzero status; in those cases a
        diagnostic is printed instead of an empty result.
    """
    import subprocess

    try:
        result = subprocess.run(
            ['nvidia-smi', 'topo', '-m'],
            capture_output=True, text=True
        )
    except FileNotFoundError:
        print("nvidia-smi not found: NVIDIA driver tools are not installed or not on PATH")
        return None

    if result.returncode != 0:
        # Surface the failure reason; stdout is typically empty in this case.
        print(result.stderr)
        return None

    print(result.stdout)
    return result.stdout

def estimate_bandwidth(pcie_bw=64, nvlink_bw=600):
    """Print a PCIe vs NVLink bandwidth comparison and return the ratio.

    Both defaults are *bidirectional* aggregate figures from the same GPU
    generation, so the ratio compares like with like:

    - PCIe 4.0 x16: ~32 GB/s per direction, ~64 GB/s bidirectional.
    - NVLink 3.0 (A100): 600 GB/s total. NVLink 4.0 (H100) is ~900 GB/s.

    Args:
        pcie_bw: PCIe bandwidth in GB/s.
        nvlink_bw: NVLink bandwidth in GB/s.

    Returns:
        The speed-up factor ``nvlink_bw / pcie_bw``.
    """
    ratio = nvlink_bw / pcie_bw
    print(f"PCIe带宽: {pcie_bw}GB/s")
    print(f"NVLink带宽: {nvlink_bw}GB/s")
    print(f"NVLink是PCIe的 {ratio:.1f} 倍")
    return ratio


estimate_bandwidth()

通信优化

def optimize_communication(ddp_model=None, matrix_approximation_rank=1,
                           start_powerSGD_iter=1000):
    """Register PowerSGD gradient compression on a DDP-wrapped model.

    PowerSGD compresses the gradients that DistributedDataParallel
    all-reduces during ``backward()``. It therefore reduces communication only
    in *training*; the collectives of an inference forward pass
    (tensor/pipeline parallel) are not affected.

    Args:
        ddp_model: A ``torch.nn.parallel.DistributedDataParallel`` instance.
            Defaults to the module-level ``model`` for backward compatibility.
        matrix_approximation_rank: Rank of the low-rank gradient
            approximation. Lower means stronger compression and more
            approximation error.
        start_powerSGD_iter: Number of initial iterations that use plain
            all-reduce before compression kicks in.

    Returns:
        The ``PowerSGDState`` registered together with the hook.
    """
    from torch.distributed.algorithms.ddp_comm_hooks import (
        powerSGD_hook as powerSGD,
    )

    # Fall back to the global `model` only when no model is passed in.
    target = model if ddp_model is None else ddp_model

    # powerSGD_hook reads its configuration and per-bucket buffers from this
    # state object, so state=None (valid for the default allreduce hook)
    # fails for PowerSGD.
    state = powerSGD.PowerSGDState(
        process_group=None,  # None -> use the default process group
        matrix_approximation_rank=matrix_approximation_rank,
        start_powerSGD_iter=start_powerSGD_iter,
    )
    target.register_comm_hook(state=state, hook=powerSGD.powerSGD_hook)
    return state

性能调优

def profile_multi_gpu(run_inference=None, configs=None):
    """Time inference for several tensor-parallel configs and print results.

    Args:
        run_inference: Optional callable ``run_inference(config) -> int`` that
            runs one inference for ``config`` and returns the number of tokens
            produced. When omitted nothing is run (placeholder). In that case
            ``config["batch_size"]`` is used as the token count, so throughput
            is ``batch_size / latency``.
        configs: Iterable of dicts with ``"tp_size"`` and ``"batch_size"``
            keys. Defaults to TP = 1, 2, 4 with batch size 1.

    Returns:
        A list with one dict per config: the config's keys plus ``"latency"``
        (seconds) and ``"throughput"`` (tokens/s).
    """
    import time

    if configs is None:
        configs = [
            {"tp_size": 1, "batch_size": 1},
            {"tp_size": 2, "batch_size": 1},
            {"tp_size": 4, "batch_size": 1},
        ]

    def _wait_for_gpu():
        # CUDA kernels launch asynchronously; without a sync the timer would
        # only measure kernel launches, not the GPU work itself.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    results = []
    for config in configs:
        if run_inference is not None:
            _wait_for_gpu()
        start = time.perf_counter()  # monotonic, high-resolution interval timer
        if run_inference is None:
            num_tokens = config["batch_size"]  # placeholder: nothing is run
        else:
            num_tokens = run_inference(config)
            _wait_for_gpu()
        latency = time.perf_counter() - start
        # An empty/placeholder run can measure 0 s; avoid ZeroDivisionError.
        throughput = num_tokens / latency if latency > 0 else float("inf")
        print(f"TP={config['tp_size']}: 延迟{latency:.3f}s, "
              f"吞吐{throughput:.1f} tokens/s")
        results.append({**config, "latency": latency, "throughput": throughput})
    return results

选择并行策略

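| 策略 | 适用场景 | 通信量 | 显存节省 |
| --- | --- | --- | --- |
| 张量并行 | 层内切分,需要高带宽互联(如NVLink) | 高(每层都需AllReduce/AllGather) | 中等 |
| 流水线并行 | 层间切分,适合互联带宽有限的场景(如跨节点);流水线气泡会增加延迟 | 低(仅在stage边界传递激活值) | 均匀 |
| 序列并行 | 超长序列 | 中等 | 序列维度 |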

最佳实践