扩散模型
---
title: "扩散模型"
description: "深入理解扩散模型的数学原理,包括DDPM、得分匹配和去噪过程的实现"
tags: ["扩散模型", "DDPM", "得分匹配", "去噪"]
category: "llm"
icon: "🧠"
---
扩散模型
扩散模型简介
扩散模型(Diffusion Models)是一类概率生成模型:先通过前向过程逐步向数据添加噪声,再学习反向过程逐步去除噪声,从而由纯噪声生成数据。近年来,扩散模型在图像生成、视频生成等领域取得了突破性进展。
数学原理
1. 前向过程(加噪)
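前向过程是一个马尔可夫链,逐步向数据添加高斯噪声。每一步的转移分布为 $q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\big)$;记 $\alpha_t = 1-\beta_t$、$\bar\alpha_t = \prod_{s \le t} \alpha_s$,则可以直接由 $x_0$ 得到任意时刻的样本:$x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$,其中 $\epsilon \sim \mathcal{N}(0, I)$: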
import torch
import torch.nn as nn
import numpy as np
class ForwardProcess:
    """Forward (noising) process q(x_t | x_0) with a linear beta schedule.

    The schedule tensors are created on the default (CPU) device in
    ``__init__``. ``add_noise`` gathers the per-sample coefficients and moves
    them to the device of ``x_0``, so the input may live on any device and
    may have any number of trailing dimensions.
    """

    def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        """Precompute the noise schedule.

        Args:
            num_timesteps: Number of diffusion steps T.
            beta_start: First value of the linear beta schedule.
            beta_end: Last value of the linear beta schedule.
        """
        self.num_timesteps = num_timesteps
        # Linearly spaced noise levels beta_0 ... beta_{T-1}.
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        # alpha_t = 1 - beta_t and its running product alpha_bar_t.
        self.alphas = 1.0 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)
        # Coefficients of x_0 and of the noise in the closed form of q(x_t | x_0).
        self.sqrt_alpha_bar = torch.sqrt(self.alpha_bar)
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1.0 - self.alpha_bar)

    @staticmethod
    def _gather(values, t, x):
        """Pick ``values[t]`` and reshape it to broadcast against ``x``.

        The index is moved to the device of ``values`` before indexing and the
        result is moved to the device of ``x``. The result has shape
        ``[B, 1, ..., 1]`` with as many trailing ones as ``x`` has non-batch
        dimensions (``[B, 1, 1, 1]`` for image batches).
        """
        if torch.is_tensor(t):
            t = t.to(values.device)
        picked = values[t].to(x.device)
        return picked.reshape(-1, *([1] * (x.dim() - 1)))

    def add_noise(self, x_0, t, noise=None):
        """Sample x_t from q(x_t | x_0) via the reparameterisation trick.

        Args:
            x_0: Clean data of shape ``[B, ...]`` (e.g. ``[B, C, H, W]``).
            t: Integer timesteps of shape ``[B]``, indexing the schedule.
            noise: Optional noise tensor shaped like ``x_0``; standard normal
                noise is drawn when omitted.

        Returns:
            ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        signal_scale = self._gather(self.sqrt_alpha_bar, t, x_0)
        noise_scale = self._gather(self.sqrt_one_minus_alpha_bar, t, x_0)
        return signal_scale * x_0 + noise_scale * noise

    def sample_timesteps(self, batch_size, device=None):
        """Draw ``batch_size`` timesteps uniformly from ``[0, num_timesteps)``.

        Args:
            batch_size: Number of timesteps to draw.
            device: Optional device for the returned tensor (default device
                when omitted).
        """
        return torch.randint(0, self.num_timesteps, (batch_size,), device=device)
2. 反向过程(去噪)
反向过程学习逐步去除噪声:模型预测噪声 $\epsilon_\theta(x_t, t)$,据此估计 $x_0$,再从后验分布 $q(x_{t-1} \mid x_t, x_0)$ 中采样得到 $x_{t-1}$:
class ReverseProcess:
    """Reverse (denoising) process driven by a noise-prediction model.

    ``model`` is called as ``model(x_t, t)`` and its output is used as the
    predicted noise. Schedule tensors are created on the default (CPU)
    device; per-sample coefficients are moved to the device of the data when
    they are gathered, and data may have any number of trailing dimensions.
    """

    def __init__(self, model, num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        """Precompute the schedule and the posterior q(x_{t-1} | x_t, x_0) terms.

        Args:
            model: Callable ``model(x_t, t)`` returning the predicted noise.
            num_timesteps: Number of diffusion steps T.
            beta_start: First value of the linear beta schedule.
            beta_end: Last value of the linear beta schedule.
        """
        self.model = model
        self.num_timesteps = num_timesteps
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)
        # alpha_bar shifted right by one step, with 1.0 standing in before t = 0.
        self.alpha_bar_prev = torch.cat([torch.tensor([1.0]), self.alpha_bar[:-1]])
        # Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
        self.posterior_variance = (
            self.betas * (1.0 - self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        )
        # The variance is exactly 0 at t = 0 (alpha_bar_prev[0] == 1), so it is
        # clamped before taking the log to avoid -inf.
        self.posterior_log_variance = torch.log(
            torch.clamp(self.posterior_variance, min=1e-20)
        )
        # Posterior mean = coef1 * x_0 + coef2 * x_t.
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alpha_bar_prev) * torch.sqrt(self.alphas) / (1.0 - self.alpha_bar)
        )

    @staticmethod
    def _gather(values, t, x):
        """Pick ``values[t]`` and reshape it to broadcast against ``x``.

        The index is moved to the device of ``values`` before indexing and the
        result is moved to the device of ``x``; the result has shape
        ``[B, 1, ..., 1]`` matching the rank of ``x``.
        """
        if torch.is_tensor(t):
            t = t.to(values.device)
        picked = values[t].to(x.device)
        return picked.reshape(-1, *([1] * (x.dim() - 1)))

    def predict_x0(self, x_t, t, noise_pred):
        """Estimate x_0 from x_t and the predicted noise.

        Inverts ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.
        """
        alpha_bar_t = self._gather(self.alpha_bar, t, x_t)
        return (
            x_t - torch.sqrt(1 - alpha_bar_t) * noise_pred
        ) / torch.sqrt(alpha_bar_t)

    def q_posterior_mean_variance(self, x_0, x_t, t):
        """Return mean, variance and log-variance of q(x_{t-1} | x_t, x_0).

        The variance terms have shape ``[B, 1, ..., 1]`` and broadcast against
        ``x_t``.
        """
        posterior_mean = (
            self._gather(self.posterior_mean_coef1, t, x_t) * x_0
            + self._gather(self.posterior_mean_coef2, t, x_t) * x_t
        )
        posterior_var = self._gather(self.posterior_variance, t, x_t)
        posterior_log_var = self._gather(self.posterior_log_variance, t, x_t)
        return posterior_mean, posterior_var, posterior_log_var

    def p_sample(self, x_t, t):
        """Draw x_{t-1} given x_t for the batch of timesteps ``t``."""
        # Predict the noise without tracking gradients.
        with torch.no_grad():
            noise_pred = self.model(x_t, t)
        # Estimate x_0 and clip it to [-1, 1].
        # NOTE(review): the clip presumes data scaled to [-1, 1] - confirm.
        x_0_pred = self.predict_x0(x_t, t, noise_pred)
        x_0_pred.clamp_(-1, 1)
        posterior_mean, _, posterior_log_var = self.q_posterior_mean_variance(
            x_0_pred, x_t, t
        )
        # At t = 0 the clamped log-variance gives a std of 1e-10, so the added
        # noise is negligible on the final step.
        noise = torch.randn_like(x_t)
        return posterior_mean + torch.exp(0.5 * posterior_log_var) * noise

    @torch.no_grad()
    def sample(self, shape, device=None):
        """Run the full reverse chain starting from standard normal noise.

        Args:
            shape: Shape of the batch to generate; ``shape[0]`` is the batch size.
            device: Optional device on which to sample (default device when
                omitted). The model is expected to live on the same device.

        Returns:
            The sample after iterating t = num_timesteps - 1, ..., 0.
        """
        x = torch.randn(shape, device=device)
        for t in reversed(range(self.num_timesteps)):
            t_batch = torch.full((shape[0],), t, dtype=torch.long, device=x.device)
            x = self.p_sample(x, t_batch)
        return x
DDPM(去噪扩散概率模型)
class DDPM(nn.Module):
    """Denoising diffusion probabilistic model: training-loss wrapper.

    Wraps a noise-prediction network ``unet`` (called as ``unet(x_t, t)``) and
    stores the linear noise schedule as buffers, so the schedule follows the
    module across devices and is included in its state dict.
    """

    def __init__(self, unet, num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        """Build the schedule and register it as buffers.

        Args:
            unet: Noise-prediction network, called as ``unet(x_t, t)``.
            num_timesteps: Number of diffusion steps T.
            beta_start: First value of the linear beta schedule.
            beta_end: Last value of the linear beta schedule.
        """
        super().__init__()
        self.unet = unet
        self.num_timesteps = num_timesteps
        betas = torch.linspace(beta_start, beta_end, num_timesteps)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bar', alpha_bar)

    def forward(self, x_0):
        """Compute the noise-prediction training loss for a batch.

        Args:
            x_0: Clean data of shape ``[B, ...]`` (e.g. ``[B, C, H, W]``).

        Returns:
            Scalar MSE between the network output and the noise that was added.
        """
        batch_size = x_0.shape[0]
        # One uniformly random timestep per sample, on the data's device.
        t = torch.randint(0, self.num_timesteps, (batch_size,), device=x_0.device)
        noise = torch.randn_like(x_0)
        # Reshape to [B, 1, ..., 1] so the coefficient broadcasts over every
        # non-batch dimension regardless of the input rank.
        alpha_bar_t = self.alpha_bar[t].reshape(-1, *([1] * (x_0.dim() - 1)))
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise
        noise_pred = self.unet(x_t, t)
        return nn.functional.mse_loss(noise_pred, noise)
得分匹配
class ScoreMatching:
    """Denoising score matching over a geometric ladder of noise levels.

    ``score_model`` is called as ``score_model(x_noisy, sigma)`` where
    ``sigma`` is a Python float.
    """

    def __init__(self, score_model, sigma_min=0.01, sigma_max=1.0, num_scales=10):
        """Build the noise levels.

        Args:
            score_model: Callable ``score_model(x_noisy, sigma)``.
            sigma_min: Smallest noise standard deviation; must be positive.
            sigma_max: Largest noise standard deviation; must be positive.
            num_scales: Number of noise levels, spaced evenly in log-space.

        Raises:
            ValueError: If a sigma is not positive or ``num_scales`` is < 1.
        """
        # log(sigma) and the division by sigma ** 2 are undefined otherwise.
        if sigma_min <= 0 or sigma_max <= 0:
            raise ValueError("sigma_min and sigma_max must be positive")
        if num_scales < 1:
            raise ValueError("num_scales must be at least 1")
        self.score_model = score_model
        self.sigmas = torch.exp(
            torch.linspace(
                float(np.log(sigma_min)), float(np.log(sigma_max)), num_scales
            )
        )

    def compute_score(self, x, sigma):
        """Evaluate the score model on a noise-perturbed copy of ``x``.

        Gaussian noise with standard deviation ``sigma`` is added to ``x``
        before the model is called, so the result is stochastic.
        """
        x_noisy = x + sigma * torch.randn_like(x)
        return self.score_model(x_noisy, sigma)

    def loss_function(self, x):
        """Return the denoising score matching loss at one random noise level.

        A single sigma is drawn for the whole batch. The regression target is
        ``-(x_noisy - x) / sigma**2``, the score of the Gaussian perturbation
        kernel, and the loss is half the mean squared error against it.
        """
        # One noise level, shared by every sample in the batch.
        sigma = self.sigmas[torch.randint(len(self.sigmas), (1,))].item()
        x_noisy = x + sigma * torch.randn_like(x)
        predicted_score = self.score_model(x_noisy, sigma)
        target_score = -(x_noisy - x) / (sigma ** 2)
        return 0.5 * ((predicted_score - target_score) ** 2).mean()
采样加速技术
class DDIMSampler:
    """DDIM sampler: generate samples using a sub-sequence of the timesteps.

    ``model`` must expose ``num_timesteps``, a 1-D ``alpha_bar`` tensor and a
    noise-prediction callable ``unet(x, t)``.
    """

    def __init__(self, model, num_inference_steps=50, eta=1.0):
        """Select the timesteps to visit.

        Args:
            model: Object providing ``num_timesteps``, ``alpha_bar`` and ``unet``.
            num_inference_steps: Number of sampling steps.
            eta: Scale of the per-step noise. ``1.0`` keeps the full sigma of
                the generalised update; ``0.0`` makes sampling deterministic.
        """
        self.model = model
        self.num_inference_steps = num_inference_steps
        self.eta = eta
        # Evenly spaced timesteps in ascending order, truncated to integers.
        self.timesteps = np.linspace(
            0, model.num_timesteps - 1, num_inference_steps
        ).astype(int)

    @torch.no_grad()
    def sample(self, shape):
        """Generate a batch by walking the selected timesteps from high to low.

        Args:
            shape: Shape of the batch to generate; ``shape[0]`` is the batch size.

        Returns:
            The sample after the last step. Initial noise is drawn on the
            device of ``model.alpha_bar``.
        """
        alpha_bar = self.model.alpha_bar
        x = torch.randn(shape, device=alpha_bar.device)
        # Descending schedule; schedule[i + 1] is the step we move to from
        # schedule[i].
        schedule = [int(t) for t in reversed(self.timesteps)]
        for i, t in enumerate(schedule):
            t_batch = torch.full((shape[0],), t, dtype=torch.long, device=x.device)
            noise_pred = self.model.unet(x, t_batch)

            alpha_bar_t = alpha_bar[t]
            if i + 1 < len(schedule):
                alpha_bar_t_prev = alpha_bar[schedule[i + 1]]
            else:
                # Final step lands on clean data, where alpha_bar is 1.
                alpha_bar_t_prev = torch.ones_like(alpha_bar_t)

            x0_pred = (x - torch.sqrt(1 - alpha_bar_t) * noise_pred) / torch.sqrt(alpha_bar_t)
            sigma_t = (
                self.eta
                * torch.sqrt((1 - alpha_bar_t_prev) / (1 - alpha_bar_t))
                * torch.sqrt(1 - alpha_bar_t / alpha_bar_t_prev)
            )
            # Coefficient of the "direction pointing to x_t" term; clamped so
            # rounding error cannot make the square root argument negative.
            direction_coef = torch.sqrt(
                torch.clamp(1 - alpha_bar_t_prev - sigma_t ** 2, min=0.0)
            )
            x = (
                torch.sqrt(alpha_bar_t_prev) * x0_pred
                + direction_coef * noise_pred
                + sigma_t * torch.randn_like(x)
            )
        return x
总结
扩散模型通过学习去噪过程实现了强大的生成能力。理解DDPM、得分匹配等核心概念,对于掌握现代图像生成技术至关重要。