← 返回首页
🧠

扩散模型

📂 llm ⏱ 3 min 557 words

---
title: "扩散模型"
description: "深入理解扩散模型的数学原理，包括DDPM、得分匹配和去噪过程的实现"
tags: ["扩散模型", "DDPM", "得分匹配", "去噪"]
category: "llm"
icon: "🧠"
---

扩散模型

扩散模型简介

扩散模型（Diffusion Models）是一类基于概率的生成模型：先通过逐步向数据添加噪声（前向过程）把数据变为近似纯高斯噪声，再学习逐步去除噪声（反向过程），从而由随机噪声生成新数据。近年来，扩散模型在图像生成、视频生成等领域取得了突破性进展。

数学原理

1. 前向过程（加噪）

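前向过程是一个马尔可夫链，按预先设定的噪声调度 $\beta_1, \dots, \beta_T$ 逐步向数据添加高斯噪声：

$$q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t;\ \sqrt{1-\beta_t}\, x_{t-1},\ \beta_t I\right)$$

记 $\alpha_t = 1-\beta_t$、$\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$，则任意时刻的 $x_t$ 都可以由 $x_0$ 一步得到（下面代码中的 `add_noise` 正是利用这一闭式解）：

$$x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon,\quad \epsilon \sim \mathcal{N}(0, I)$$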

import torch
import torch.nn as nn
import numpy as np

class ForwardProcess:
    """Forward (noising) process q(x_t | x_0) of DDPM with a linear beta schedule.

    Schedule tensors are created on the CPU; the per-sample coefficients are
    moved to the device of the data they are applied to, and reshaped to
    broadcast over inputs of any rank (not only [B, C, H, W]).
    """

    def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        self.num_timesteps = num_timesteps

        # Linear noise schedule beta_0 .. beta_{T-1}
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)

        # Cumulative quantities used by the closed-form q(x_t | x_0)
        self.alphas = 1.0 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_bar = torch.sqrt(self.alpha_bar)
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1.0 - self.alpha_bar)

    @staticmethod
    def _extract(coeffs, t, x):
        """Return ``coeffs[t]`` on ``x``'s device, shaped [B, 1, ..., 1] to broadcast over ``x``."""
        idx = torch.as_tensor(t, device=coeffs.device)
        return coeffs[idx].to(x.device).reshape(-1, *([1] * (x.dim() - 1)))

    def add_noise(self, x_0, t, noise=None):
        """Sample x_t ~ q(x_t | x_0) as sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.

        Args:
            x_0: Clean data of shape [B, ...] (e.g. images [B, C, H, W]).
            t: Integer timesteps of shape [B] (a scalar is broadcast to the batch).
            noise: Optional noise with the same shape as ``x_0``; standard
                Gaussian noise is drawn when omitted.

        Returns:
            The noised tensor x_t, same shape as ``x_0``.
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        # Reparameterized sample from q(x_t | x_0)
        sqrt_alpha_bar_t = self._extract(self.sqrt_alpha_bar, t, x_0)
        sqrt_one_minus_alpha_bar_t = self._extract(self.sqrt_one_minus_alpha_bar, t, x_0)

        return sqrt_alpha_bar_t * x_0 + sqrt_one_minus_alpha_bar_t * noise

    def sample_timesteps(self, batch_size, device=None):
        """Draw ``batch_size`` timesteps uniformly from [0, num_timesteps), optionally on ``device``."""
        return torch.randint(0, self.num_timesteps, (batch_size,), device=device)

2. 反向过程（去噪）

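反向过程学习逐步去除噪声。当 $x_0$ 已知时，后验分布有闭式解：

$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\left(x_{t-1};\ \tilde\mu_t,\ \tilde\beta_t I\right),\quad \tilde\mu_t = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\, x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\, x_t,\quad \tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t$$

采样时先用模型预测的噪声 $\epsilon_\theta(x_t, t)$ 估计 $x_0$，再代入上式得到 $x_{t-1}$：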

class ReverseProcess:
    """DDPM reverse (denoising) process driven by a noise-prediction model.

    ``model(x_t, t)`` must return the predicted noise with the same shape as
    ``x_t`` (it is subtracted from ``x_t`` in ``predict_x0``).

    Schedule tensors are created on the CPU; per-step coefficients are moved to
    the device of the sample and reshaped to broadcast over inputs of any rank.
    """

    def __init__(self, model, num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        self.model = model
        self.num_timesteps = num_timesteps

        # Linear beta schedule and its cumulative products
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, defined as 1 for the first step (t = 0)
        self.alpha_bar_prev = torch.cat([torch.tensor([1.0]), self.alpha_bar[:-1]])

        # Variance of the posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        )
        # posterior_variance[0] is exactly 0, so clamp before taking the log
        self.posterior_log_variance = torch.log(
            torch.clamp(self.posterior_variance, min=1e-20)
        )
        # Posterior mean = coef1 * x_0 + coef2 * x_t
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alpha_bar_prev) * torch.sqrt(self.alphas) / (1.0 - self.alpha_bar)
        )

    @staticmethod
    def _extract(coeffs, t, x):
        """Return ``coeffs[t]`` on ``x``'s device, shaped [B, 1, ..., 1] to broadcast over ``x``."""
        idx = torch.as_tensor(t, device=coeffs.device)
        return coeffs[idx].to(x.device).reshape(-1, *([1] * (x.dim() - 1)))

    def predict_x0(self, x_t, t, noise_pred):
        """Estimate x_0 by inverting x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
        alpha_bar_t = self._extract(self.alpha_bar, t, x_t)
        return (
            x_t - torch.sqrt(1 - alpha_bar_t) * noise_pred
        ) / torch.sqrt(alpha_bar_t)

    def q_posterior_mean_variance(self, x_0, x_t, t):
        """Return (mean, variance, log-variance) of q(x_{t-1} | x_t, x_0), each broadcastable to ``x_t``."""
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t) * x_0
            + self._extract(self.posterior_mean_coef2, t, x_t) * x_t
        )
        posterior_var = self._extract(self.posterior_variance, t, x_t)
        posterior_log_var = self._extract(self.posterior_log_variance, t, x_t)

        return posterior_mean, posterior_var, posterior_log_var

    def p_sample(self, x_t, t):
        """Sample x_{t-1} given x_t and per-sample timesteps ``t`` of shape [B]."""
        # Predict the noise contained in x_t
        with torch.no_grad():
            noise_pred = self.model(x_t, t)

        # Estimate x_0 and clip it in place to [-1, 1]
        # (NOTE(review): assumes data was scaled to [-1, 1] — confirm)
        x_0_pred = self.predict_x0(x_t, t, noise_pred)
        x_0_pred.clamp_(-1, 1)

        posterior_mean, _, posterior_log_var = self.q_posterior_mean_variance(
            x_0_pred, x_t, t
        )

        # No noise is added on the final step (t == 0): return the posterior mean
        noise = torch.randn_like(x_t)
        nonzero_mask = (
            (torch.as_tensor(t, device=x_t.device) != 0)
            .to(x_t.dtype)
            .reshape(-1, *([1] * (x_t.dim() - 1)))
        )
        return posterior_mean + nonzero_mask * torch.exp(0.5 * posterior_log_var) * noise

    @torch.no_grad()
    def sample(self, shape, device=None):
        """Run the full reverse chain from pure Gaussian noise of ``shape`` (optionally on ``device``)."""
        x = torch.randn(shape, device=device)

        for t in reversed(range(self.num_timesteps)):
            t_batch = torch.full((shape[0],), t, dtype=torch.long, device=x.device)
            x = self.p_sample(x, t_batch)

        return x

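DDPM（去噪扩散概率模型）

DDPM 训练网络 $\epsilon_\theta$ 预测前向过程中加入的噪声，使用简化的均方误差损失：$\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon}\left[\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2\right]$。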

class DDPM(nn.Module):
    """DDPM training wrapper: noises a clean batch and regresses the added noise.

    Args:
        unet: Noise-prediction network called as ``unet(x_t, t)``; its output
            is compared with the sampled noise via MSE, so it should have the
            same shape as ``x_t``.
        num_timesteps: Number of diffusion steps T.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
    """

    def __init__(self, unet, num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.unet = unet
        self.num_timesteps = num_timesteps

        betas = torch.linspace(beta_start, beta_end, num_timesteps)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)

        # Registered as buffers so they move with .to(device) and appear in state_dict
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bar', alpha_bar)

    def forward(self, x_0):
        """Return the scalar noise-prediction MSE loss for a clean batch ``x_0`` of shape [B, ...]."""
        batch_size = x_0.shape[0]

        # One uniformly sampled timestep per example
        t = torch.randint(0, self.num_timesteps, (batch_size,), device=x_0.device)

        noise = torch.randn_like(x_0)

        # Closed-form q(x_t | x_0); coefficients shaped [B, 1, ..., 1] for any input rank
        alpha_bar_t = self.alpha_bar[t].reshape(-1, *([1] * (x_0.dim() - 1)))
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise

        noise_pred = self.unet(x_t, t)

        return nn.functional.mse_loss(noise_pred, noise)

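得分匹配

得分匹配学习数据分布的得分函数 $\nabla_x \log p(x)$。去噪得分匹配（DSM）以加噪样本 $\tilde x = x + \sigma\epsilon$ 作为输入，其回归目标为 $\nabla_{\tilde x} \log q_\sigma(\tilde x \mid x) = -(\tilde x - x)/\sigma^2$。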

class ScoreMatching:
    """Denoising score matching over a geometric ladder of noise levels.

    ``score_model(x, sigma)`` is called with a tensor and a Python float noise
    level and must return a tensor shaped like ``x``.
    """

    def __init__(self, score_model, sigma_min=0.01, sigma_max=1.0, num_scales=10):
        self.score_model = score_model
        # Geometric sequence of num_scales noise levels from sigma_min to sigma_max
        self.sigmas = torch.exp(
            torch.linspace(np.log(sigma_min), np.log(sigma_max), num_scales)
        )

    def compute_score(self, x, sigma):
        """Return the model's score estimate grad_x log p_sigma(x) at noise level ``sigma``.

        The model is evaluated at ``x`` itself; perturbing ``x`` again here would
        return the score at a random nearby point instead of at ``x``.
        """
        return self.score_model(x, sigma)

    def loss_function(self, x):
        """Return the denoising score-matching loss for batch ``x`` at one random noise level."""
        # Pick one noise level for the whole batch
        sigma = self.sigmas[torch.randint(len(self.sigmas), (1,))].item()

        x_noisy = x + sigma * torch.randn_like(x)

        predicted_score = self.score_model(x_noisy, sigma)

        # Score of the Gaussian perturbation kernel N(x_noisy; x, sigma^2 I)
        target_score = -(x_noisy - x) / (sigma ** 2)

        return 0.5 * ((predicted_score - target_score) ** 2).mean()

采样加速技术

DDIM 在训练好的 DDPM 上只取时间步的一个子序列进行采样，把上千步的采样压缩到几十步；其随机性由参数 $\eta$ 控制，$\eta = 0$ 时采样过程是确定性的。

class DDIMSampler:
    """DDIM sampler: denoise along an evenly spaced subsequence of the training timesteps.

    ``model`` must expose ``num_timesteps``, the cumulative schedule
    ``alpha_bar`` (a tensor indexed by timestep) and a noise-prediction network
    ``unet(x, t)``, as the ``DDPM`` module does.

    Args:
        model: Trained DDPM-like model.
        num_inference_steps: Length of the timestep subsequence.
        eta: Stochasticity; 0.0 gives deterministic DDIM, 1.0 matches the
            DDPM-like noise level.
    """

    def __init__(self, model, num_inference_steps=50, eta=0.0):
        self.model = model
        self.num_inference_steps = num_inference_steps
        self.eta = eta

        # Ascending subsequence of training timesteps, e.g. [0, ..., T-1]
        self.timesteps = np.linspace(
            0, model.num_timesteps - 1, num_inference_steps
        ).astype(int)

    @torch.no_grad()
    def sample(self, shape):
        """Generate a batch of ``shape`` starting from Gaussian noise on the schedule's device."""
        alpha_bar = self.model.alpha_bar
        x = torch.randn(shape, device=alpha_bar.device)

        # Walk from the noisiest timestep down to the cleanest
        steps = [int(t) for t in reversed(self.timesteps)]

        for i, t in enumerate(steps):
            t_batch = torch.full((shape[0],), t, dtype=torch.long, device=x.device)
            noise_pred = self.model.unet(x, t_batch)

            alpha_bar_t = alpha_bar[t]
            # "Previous" = next smaller timestep in the subsequence; after the
            # last step we land on clean data, where alpha_bar is 1
            if i + 1 < len(steps):
                alpha_bar_prev = alpha_bar[steps[i + 1]]
            else:
                alpha_bar_prev = alpha_bar.new_tensor(1.0)

            x0_pred = (x - torch.sqrt(1 - alpha_bar_t) * noise_pred) / torch.sqrt(alpha_bar_t)

            # DDIM (Song et al. 2021, Eq. 12/16): noise scale and the direction pointing to x_t
            sigma_t = (
                self.eta
                * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t))
                * torch.sqrt(1 - alpha_bar_t / alpha_bar_prev)
            )
            dir_xt = torch.sqrt(torch.clamp(1 - alpha_bar_prev - sigma_t ** 2, min=0.0)) * noise_pred

            x = torch.sqrt(alpha_bar_prev) * x0_pred + dir_xt
            if self.eta > 0:
                x = x + sigma_t * torch.randn_like(x)

        return x

总结

扩散模型通过学习去噪过程实现了强大的生成能力。理解DDPM、得分匹配等核心概念,对于掌握现代图像生成技术至关重要。