GAN基础入门
GAN基础入门
什么是GAN
生成对抗网络(GAN)由生成器和判别器组成,通过对抗训练生成逼真的数据。
GAN基本结构
import torch
import torch.nn as nn
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to images.

    The hidden widths grow 128 -> 256 -> 512 -> 1024, each followed by
    LeakyReLU(0.2); every hidden layer except the first is batch-normalised.
    The output layer ends in Tanh, so generated values lie in [-1, 1].

    Args:
        latent_dim: Number of features in each input noise vector.
        img_shape: Shape of a single generated image (for example
            ``(channels, height, width)``); the flat network output is
            reshaped to it.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape
        num_pixels = int(torch.tensor(img_shape).prod())

        # First hidden layer deliberately has no batch normalisation.
        layers = [nn.Linear(latent_dim, 128), nn.LeakyReLU(0.2)]
        in_features = 128
        for out_features in (256, 512, 1024):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.BatchNorm1d(out_features))
            layers.append(nn.LeakyReLU(0.2))
            in_features = out_features
        layers.append(nn.Linear(in_features, num_pixels))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return a batch of images shaped ``(batch, *img_shape)`` for noise ``z``."""
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)
class Discriminator(nn.Module):
    """Fully connected discriminator that scores images as real or fake.

    Images are flattened and passed through 512- and 256-unit hidden layers
    with LeakyReLU(0.2); a final Sigmoid yields one value in [0, 1] per image.

    Args:
        img_shape: Shape of a single input image; its element count sets the
            width of the first linear layer.
    """

    def __init__(self, img_shape):
        super().__init__()
        num_pixels = int(torch.tensor(img_shape).prod())

        layers = [
            nn.Linear(num_pixels, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return a ``(batch, 1)`` tensor of scores for ``img``."""
        batch = img.size(0)
        return self.model(img.view(batch, -1))
# Noise vector size and single-channel 28x28 image shape for the MLP models.
latent_dim = 100
img_shape = (1, 28, 28)

generator = Generator(latent_dim=latent_dim, img_shape=img_shape)
discriminator = Discriminator(img_shape=img_shape)
DCGAN架构
DCGAN使用卷积层提升生成质量:
class DCGANGenerator(nn.Module):
    """Convolutional (DCGAN-style) generator built from transposed convolutions.

    For an input of shape ``(batch, latent_dim, 1, 1)`` the first layer
    (kernel 4, stride 1, no padding) produces a 4x4 map, and each of the three
    following layers (kernel 4, stride 2, padding 1) doubles the spatial size,
    giving a ``(batch, channels, 32, 32)`` output squashed to [-1, 1] by Tanh.

    Args:
        latent_dim: Number of channels of the input noise tensor.
        channels: Number of channels of the generated image.
    """

    def __init__(self, latent_dim, channels=1):
        super().__init__()

        # Project the latent tensor to a 512-channel feature map.
        stages = [
            nn.ConvTranspose2d(latent_dim, 512, kernel_size=4, stride=1,
                               padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        ]
        # Upsampling stages: halve the channel count, double the resolution.
        for in_ch, out_ch in ((512, 256), (256, 128)):
            stages.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2,
                                   padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        # Output stage: no batch norm, Tanh activation.
        stages.append(
            nn.ConvTranspose2d(128, channels, kernel_size=4, stride=2,
                               padding=1, bias=False)
        )
        stages.append(nn.Tanh())

        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Run the latent tensor ``x`` through the transposed-conv stack."""
        return self.main(x)
class DCGANDiscriminator(nn.Module):
    """Convolutional (DCGAN-style) discriminator.

    Three stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    size while widening to 128, 256 and 512 channels; a final unpadded 4x4
    convolution followed by Sigmoid reduces a 4x4 map to a single score.
    A 32x32 input therefore yields exactly one value in [0, 1] per image.

    Args:
        channels: Number of channels of the input images.
    """

    def __init__(self, channels=1):
        super().__init__()

        # Input stage: no batch norm on the raw image.
        stages = [
            nn.Conv2d(channels, 128, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.LeakyReLU(negative_slope=0.2),
        ]
        # Downsampling stages: double the channel count, halve the resolution.
        for in_ch, out_ch in ((128, 256), (256, 512)):
            stages.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1,
                          bias=False)
            )
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.LeakyReLU(negative_slope=0.2))
        # Scoring stage.
        stages.append(
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False)
        )
        stages.append(nn.Sigmoid())

        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Return the scores for ``x`` flattened to a 1-D tensor."""
        scores = self.main(x)
        return scores.view(-1)
# Single-channel DCGAN pair sharing the latent size defined above.
dcgan_G = DCGANGenerator(latent_dim=latent_dim)
dcgan_D = DCGANDiscriminator(channels=1)
GAN训练过程
def train_gan(generator, discriminator, dataloader, epochs=100, latent_dim=100):
    """Train a generator/discriminator pair with the binary cross-entropy GAN loss.

    Each batch performs one discriminator update (average of the loss on real
    images labelled 1 and on detached generated images labelled 0) followed by
    one generator update (generated images labelled 1). Both networks use Adam
    with ``lr=0.0002`` and ``betas=(0.5, 0.999)``.

    Args:
        generator: Module mapping noise of shape ``(batch, latent_dim)`` to
            images. Generators that expect another noise layout must be
            wrapped by the caller.
        discriminator: Module mapping images to probabilities in [0, 1]
            (``nn.BCELoss`` is applied to its output). Any output shape is
            accepted, e.g. ``(batch, 1)`` or ``(batch,)``.
        dataloader: Iterable yielding ``(images, labels)`` pairs; the labels
            are ignored.
        epochs: Number of passes over ``dataloader``.
        latent_dim: Size of each noise vector fed to the generator.

    Returns:
        None. The two networks are updated in place.
    """
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002,
                                   betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002,
                                   betas=(0.5, 0.999))
    adversarial_loss = nn.BCELoss()

    # Create noise on the generator's device and move real batches there, so
    # training also works when the models do not live on the CPU.
    device = next(generator.parameters()).device

    for _ in range(epochs):
        for imgs, _ in dataloader:
            real_imgs = imgs.to(device)
            batch_size = real_imgs.size(0)

            z = torch.randn(batch_size, latent_dim, device=device)
            gen_imgs = generator(z)

            # --- Discriminator step -------------------------------------
            # Targets are shaped like the predictions, so discriminators
            # with flat (batch,) outputs work as well as (batch, 1) ones.
            real_pred = discriminator(real_imgs)
            # detach(): the discriminator loss must not reach the generator.
            fake_pred = discriminator(gen_imgs.detach())
            real_loss = adversarial_loss(real_pred, torch.ones_like(real_pred))
            fake_loss = adversarial_loss(fake_pred, torch.zeros_like(fake_pred))
            d_loss = (real_loss + fake_loss) / 2

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # --- Generator step -----------------------------------------
            # Re-score the same fakes with the freshly updated discriminator.
            gen_pred = discriminator(gen_imgs)
            g_loss = adversarial_loss(gen_pred, torch.ones_like(gen_pred))

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()


print("训练流程定义完成")
损失函数:WGAN-GP梯度惩罚
def compute_gradient_penalty(discriminator, real_imgs, fake_imgs):
    """Return the gradient penalty ``mean((||grad||_2 - 1) ** 2)``.

    The gradient is that of the discriminator output with respect to random
    per-sample interpolations between ``real_imgs`` and ``fake_imgs``.

    Args:
        discriminator: Callable mapping a batch of samples to scores.
        real_imgs: Batch of real samples; the first dimension is the batch
            and any number of trailing dimensions is supported.
        fake_imgs: Batch of generated samples, broadcastable with
            ``real_imgs``.

    Returns:
        A scalar tensor. The gradient is computed with ``create_graph=True``
        so the penalty can itself be back-propagated.
    """
    batch_size = real_imgs.size(0)

    # One mixing coefficient per sample, with a singleton axis for every
    # non-batch dimension so it broadcasts over inputs of any rank, created
    # on the same device as the data.
    alpha_shape = (batch_size,) + (1,) * (real_imgs.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_imgs.device)

    interpolated = (alpha * real_imgs + (1 - alpha) * fake_imgs).requires_grad_(True)
    d_interpolated = discriminator(interpolated)

    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (not view): autograd may return a non-contiguous gradient.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty


print("梯度惩罚计算完成")
生成样本
def generate_samples(generator, num_samples, latent_dim):
    """Draw ``num_samples`` outputs from ``generator`` without tracking gradients.

    When ``generator`` is an ``nn.Module`` it is switched to evaluation mode
    for the forward pass and every submodule's previous mode is restored
    afterwards. This keeps batch-norm layers from updating their running
    statistics during sampling and lets ``num_samples == 1`` work with
    ``BatchNorm1d``, which rejects single-sample batches in training mode.

    Args:
        generator: Module (or plain callable) mapping noise of shape
            ``(num_samples, latent_dim)`` to samples.
        num_samples: Number of noise vectors to draw.
        latent_dim: Size of each noise vector.

    Returns:
        The generator output for the sampled noise, detached from autograd.
    """
    if not isinstance(generator, torch.nn.Module):
        # Plain callables have no mode or device to manage.
        z = torch.randn(num_samples, latent_dim)
        with torch.no_grad():
            return generator(z)

    # Sample the noise on the generator's device (CPU if it has no parameters).
    first_param = next(generator.parameters(), None)
    device = first_param.device if first_param is not None else None
    z = torch.randn(num_samples, latent_dim, device=device)

    # Remember each submodule's mode so mixed train/eval setups are restored
    # exactly, not overwritten by a blanket train()/eval() call.
    previous_modes = [(module, module.training) for module in generator.modules()]
    generator.eval()
    try:
        with torch.no_grad():
            generated = generator(z)
    finally:
        for module, was_training in previous_modes:
            module.training = was_training
    return generated
# Sample a 16-image batch from the MLP generator and report its shape.
generated = generate_samples(generator, num_samples=16, latent_dim=latent_dim)
print("生成样本形状:", generated.shape)
总结
GAN是生成模型的重要突破。通过生成器和判别器的对抗训练,可以生成逼真的图像、文本等数据。DCGAN是GAN的经典变体,为后续研究奠定了基础。