← 返回首页
🤖

GAN基础入门

📂 ai ⏱ 2 min 355 words

GAN基础入门

什么是GAN

生成对抗网络(GAN)由生成器和判别器组成,通过对抗训练生成逼真的数据。

GAN基本结构

import torch
import torch.nn as nn

class Generator(nn.Module):
    """Fully connected generator that turns latent vectors into images.

    The network is a stack of Linear layers (128 -> 256 -> 512 -> 1024 ->
    number of image elements) with LeakyReLU(0.2) activations, BatchNorm1d
    after every hidden Linear except the first, and a final Tanh, so every
    output value lies in [-1, 1].

    Args:
        latent_dim: Size of each input latent vector.
        img_shape: Per-sample output shape; the flat network output is
            reshaped to ``(batch, *img_shape)``.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        # Total number of scalar values in one image.
        flat_size = 1
        for extent in img_shape:
            flat_size *= extent
        flat_size = int(flat_size)

        # The first block has no normalization; the remaining hidden blocks
        # all follow the same Linear -> BatchNorm1d -> LeakyReLU pattern.
        layers = [nn.Linear(latent_dim, 128), nn.LeakyReLU(0.2)]
        for n_in, n_out in ((128, 256), (256, 512), (512, 1024)):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.BatchNorm1d(n_out))
            layers.append(nn.LeakyReLU(0.2))
        layers.append(nn.Linear(1024, flat_size))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map latent vectors ``z`` to a batch shaped ``(batch, *img_shape)``."""
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)

class Discriminator(nn.Module):
    """Fully connected discriminator that scores images.

    Each image is flattened and passed through Linear layers
    (image elements -> 512 -> 256 -> 1) with LeakyReLU(0.2) between them and
    a final Sigmoid, so every score lies in [0, 1].

    Args:
        img_shape: Per-sample image shape; its element count sets the width
            of the first Linear layer.
    """

    def __init__(self, img_shape):
        super().__init__()

        # Number of scalar values in one flattened image.
        n_inputs = 1
        for extent in img_shape:
            n_inputs *= extent
        n_inputs = int(n_inputs)

        widths = (n_inputs, 512, 256)
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.2))
        layers.append(nn.Linear(256, 1))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return one score per sample, shaped ``(batch, 1)``."""
        flattened = img.view(img.size(0), -1)
        return self.model(flattened)

# Settings shared by the snippets below: the size of each latent vector and
# the per-sample image shape (presumably channels x height x width).
latent_dim = 100
img_shape = (1, 28, 28)

# Build the fully connected generator/discriminator pair from those settings.
generator = Generator(latent_dim=latent_dim, img_shape=img_shape)
discriminator = Discriminator(img_shape=img_shape)

DCGAN架构

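DCGAN使用卷积层代替全连接层来提升生成质量。下面的生成器把形状为 (batch, latent_dim, 1, 1) 的噪声逐步上采样为 32×32 的图像,判别器相应地接收 32×32 的输入(与前文 28×28 的 img_shape 不同):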

class DCGANGenerator(nn.Module):
    """Convolutional (DCGAN-style) generator.

    Four transposed convolutions upsample a 1x1 latent map to 4x4, 8x8,
    16x16 and finally 32x32, with BatchNorm2d + ReLU after the first three
    and Tanh at the end, so the output has shape ``(batch, channels, 32, 32)``
    with values in [-1, 1].

    Args:
        latent_dim: Number of channels of the latent input.
        channels: Number of channels of the generated images.
    """

    def __init__(self, latent_dim, channels=1):
        super(DCGANGenerator, self).__init__()

        self.main = nn.Sequential(
            # (latent_dim, 1, 1) -> (512, 4, 4)
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # (512, 4, 4) -> (256, 8, 8)
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # (256, 8, 8) -> (128, 16, 16)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # (128, 16, 16) -> (channels, 32, 32)
            nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        """Generate images from latent input.

        Args:
            x: Latent tensor shaped ``(batch, latent_dim, 1, 1)``. A 2-D
                ``(batch, latent_dim)`` tensor is also accepted and is
                treated as a 1x1 spatial map, so noise drawn with
                ``torch.randn(batch, latent_dim)`` works directly.

        Returns:
            The output of the convolutional stack for ``x``.
        """
        if x.dim() == 2:
            # Add the two trailing spatial axes the transposed convs expect.
            x = x.unsqueeze(-1).unsqueeze(-1)
        return self.main(x)

class DCGANDiscriminator(nn.Module):
    """Convolutional (DCGAN-style) discriminator.

    Three stride-2 convolutions (128, 256, 512 channels) each halve the
    spatial size, followed by a 4x4 convolution down to one channel and a
    Sigmoid. For 32x32 inputs the final map is 1x1, so the flattened output
    holds one score in [0, 1] per sample.

    Args:
        channels: Number of channels of the input images.
    """

    def __init__(self, channels=1):
        super().__init__()

        # The first block has no normalization layer.
        layers = [
            nn.Conv2d(channels, 128, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
        ]
        # Remaining downsampling blocks: Conv -> BatchNorm -> LeakyReLU.
        for c_in, c_out in ((128, 256), (256, 512)):
            layers.append(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        # Collapse to a single-channel score map and squash it into [0, 1].
        layers.append(nn.Conv2d(512, 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Score ``x`` and return the result flattened to one dimension."""
        scores = self.main(x)
        return scores.view(-1)

# Instantiate the convolutional pair; both use single-channel images.
dcgan_G = DCGANGenerator(latent_dim=latent_dim, channels=1)
dcgan_D = DCGANDiscriminator(channels=1)

GAN训练过程

def train_gan(generator, discriminator, dataloader, epochs=100, latent_dim=100):
    """Train a generator/discriminator pair with the binary cross-entropy GAN loss.

    For every batch the discriminator is updated first (real images labelled
    1, generated images labelled 0), then the generator is updated so that
    its images are scored as real. Both networks use Adam with
    ``lr=0.0002`` and ``betas=(0.5, 0.999)``.

    Args:
        generator: Module mapping ``(batch, latent_dim)`` noise to images.
        discriminator: Module mapping images to ``(batch, 1)`` probabilities
            in [0, 1] (required by ``nn.BCELoss`` and the label shape).
        dataloader: Iterable yielding ``(images, labels)`` pairs; the labels
            are ignored.
        epochs: Number of passes over ``dataloader``.
        latent_dim: Size of the noise vectors fed to the generator. It is an
            explicit parameter so the function does not depend on a
            module-level variable; the default matches the value used
            elsewhere in this file.

    Returns:
        None. The two modules are updated in place.
    """
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002,
                                   betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002,
                                   betas=(0.5, 0.999))

    adversarial_loss = nn.BCELoss()

    # Run everything on the device the generator lives on, so labels, noise
    # and real images never end up on a different device than the models.
    device = next(generator.parameters()).device

    for _epoch in range(epochs):
        for imgs, _labels in dataloader:
            imgs = imgs.to(device)
            batch_size = imgs.size(0)

            # Targets for "real" and "generated" samples.
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)

            z = torch.randn(batch_size, latent_dim, device=device)
            gen_imgs = generator(z)

            # Discriminator step: detach the generated images so this
            # backward pass does not reach the generator.
            real_loss = adversarial_loss(discriminator(imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # Generator step: reward images the updated discriminator
            # scores as real.
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

# Status message (Chinese for "training procedure defined"); no training is
# started in the code shown here.
print("训练流程定义完成")

损失函数:梯度惩罚(WGAN-GP)

下面的函数计算 WGAN-GP 中的梯度惩罚项:在真实样本与生成样本之间随机插值,并惩罚判别器在插值点处梯度的 L2 范数偏离 1 的程度。

def compute_gradient_penalty(discriminator, real_imgs, fake_imgs):
    """Compute the gradient penalty on random real/fake interpolations.

    Each sample is interpolated between its real and fake counterpart with
    its own random weight. The penalty is the batch mean of
    ``(||grad_x D(x)||_2 - 1) ** 2`` evaluated at those interpolated points.

    Args:
        discriminator: Callable mapping a batch of inputs to scores.
        real_imgs: Batch of real samples, shaped ``(batch, ...)`` with any
            number of trailing dimensions.
        fake_imgs: Batch of generated samples, broadcastable with
            ``real_imgs``.

    Returns:
        A scalar tensor holding the penalty. The graph is kept
        (``create_graph=True``) so the penalty can be back-propagated.
    """
    batch_size = real_imgs.size(0)

    # One mixing weight per sample, broadcast over all remaining axes and
    # created on the inputs' device. Floating inputs keep their dtype.
    alpha_shape = (batch_size,) + (1,) * (real_imgs.dim() - 1)
    alpha_dtype = (real_imgs.dtype if real_imgs.is_floating_point()
                   else torch.get_default_dtype())
    alpha = torch.rand(alpha_shape, device=real_imgs.device, dtype=alpha_dtype)

    interpolated = (alpha * real_imgs + (1 - alpha) * fake_imgs).requires_grad_(True)
    d_interpolated = discriminator(interpolated)

    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True
    )[0]

    # reshape (not view) so non-contiguous gradients are handled as well.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# Status message (Chinese for "gradient penalty computation done"); the
# penalty function is not called in the code shown here.
print("梯度惩罚计算完成")

生成样本

def generate_samples(generator, num_samples, latent_dim):
    """Draw ``num_samples`` outputs from ``generator`` without tracking gradients.

    The generator is put in evaluation mode while sampling, so normalization
    layers use their running statistics, a single sample can be drawn, and
    sampling does not modify those statistics. Every submodule's previous
    train/eval mode is restored afterwards.

    Args:
        generator: Module mapping ``(num_samples, latent_dim)`` noise to
            samples.
        num_samples: Number of samples to generate.
        latent_dim: Size of each noise vector.

    Returns:
        The generator output for freshly drawn standard-normal noise.
    """
    # Draw the noise on the generator's device; fall back to the default
    # device when the generator has no parameters.
    first_param = next(generator.parameters(), None)
    device = first_param.device if first_param is not None else None
    z = torch.randn(num_samples, latent_dim, device=device)

    # Remember each submodule's mode so mixed train/eval setups survive.
    previous_modes = [(module, module.training) for module in generator.modules()]
    generator.eval()
    try:
        with torch.no_grad():
            generated = generator(z)
    finally:
        for module, was_training in previous_modes:
            module.training = was_training
    return generated

# Sample 16 outputs from the fully connected generator and report their shape.
generated = generate_samples(generator, num_samples=16, latent_dim=latent_dim)
print("生成样本形状:", generated.shape)

总结

GAN是生成模型的重要突破。通过生成器和判别器的对抗训练,可以生成逼真的图像、文本等数据。DCGAN是GAN的经典变体,为后续研究奠定了基础。