自监督学习:无标签数据的表示学习
自监督学习:无标签数据的表示学习
什么是自监督学习?
自监督学习是利用无标签数据本身创造监督信号的学习方法。它通过设计 pretext task(前置任务),让模型从数据中学习有意义的表示。
自监督学习vs监督学习vs无监督学习
| 方法 | 数据 | 监督信号 | 目标 |
|---|---|---|---|
| 监督学习 | 有标签 | 人工标注 | 预测标签 |
| 无监督学习 | 无标签 | 无 | 发现结构 |
| 自监督学习 | 无标签 | 数据自身 | 学习表示 |
自监督学习的主要方法
1. 对比学习
对比学习的核心思想是:让相似的样本在表示空间中靠近,不相似的样本远离。
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimCLR(nn.Module):
def __init__(self, backbone, projection_dim=128):
super().__init__()
self.backbone = backbone
self.projector = nn.Sequential(
nn.Linear(backbone.output_dim, 256),
nn.ReLU(),
nn.Linear(256, projection_dim)
)
def forward(self, x1, x2):
# 获取表示
h1 = self.backbone(x1)
h2 = self.backbone(x2)
# 投影到对比空间
z1 = self.projector(h1)
z2 = self.projector(h2)
return z1, z2
def nt_xent_loss(z1, z2, temperature=0.5):
batch_size = z1.size(0)
# 归一化
z1 = F.normalize(z1, dim=1)
z2 = F.normalize(z2, dim=1)
# 拼接
z = torch.cat([z1, z2], dim=0)
# 计算相似度矩阵
sim = torch.mm(z, z.t()) / temperature
# 创建标签(对角线为正样本)
labels = torch.cat([
torch.arange(batch_size, 2 * batch_size),
torch.arange(batch_size)
])
# 移除对角线
mask = torch.eye(2 * batch_size, dtype=torch.bool)
sim = sim.masked_fill(mask, -1e9)
# 交叉熵损失
loss = F.cross_entropy(sim, labels)
return loss
2. 掩码语言模型(MLM)
BERT使用的预训练方法,随机掩码部分token,让模型预测被掩码的内容。
class MaskedLanguageModel(nn.Module):
def __init__(self, vocab_size, d_model, num_heads, num_layers):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = nn.Embedding(512, d_model)
encoder_layer = nn.TransformerEncoderLayer(d_model, num_heads)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
self.head = nn.Linear(d_model, vocab_size)
def forward(self, input_ids, attention_mask=None):
seq_len = input_ids.size(1)
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
x = self.embedding(input_ids) + self.position_embedding(positions)
x = x.transpose(0, 1)
x = self.transformer(x, src_key_padding_mask=~attention_mask)
x = x.transpose(0, 1)
return self.head(x)
def mlm_loss(model, input_ids, attention_mask, masked_positions, masked_labels):
logits = model(input_ids, attention_mask)
# 只计算被掩码位置的损失
masked_logits = logits[:, masked_positions]
loss = F.cross_entropy(masked_logits, masked_labels)
return loss
3. 旋转预测
让模型预测图像旋转的角度(0°, 90°, 180°, 270°)。
class RotationPrediction(nn.Module):
def __init__(self, backbone, num_classes=4):
super().__init__()
self.backbone = backbone
self.classifier = nn.Linear(backbone.output_dim, num_classes)
def forward(self, x):
features = self.backbone(x)
return self.classifier(features)
def rotate_images(images, angles):
rotated = []
labels = []
for angle in angles:
if angle == 0:
rotated.append(images)
labels.append(0)
elif angle == 90:
rotated.append(torch.rot90(images, 1, [2, 3]))
labels.append(1)
elif angle == 180:
rotated.append(torch.rot90(images, 2, [2, 3]))
labels.append(2)
elif angle == 270:
rotated.append(torch.rot90(images, 3, [2, 3]))
labels.append(3)
return torch.cat(rotated), torch.tensor(labels)
4. 拼图预测
打乱图像patch的顺序,让模型预测正确的排列。
class JigsawPuzzle(nn.Module):
def __init__(self, backbone, num_permutations=100):
super().__init__()
self.backbone = backbone
self.classifier = nn.Linear(backbone.output_dim, num_permutations)
def forward(self, patches):
# patches: (batch, num_patches, channels, height, width)
batch_size, num_patches = patches.size(0), patches.size(1)
# 处理每个patch
features = []
for i in range(num_patches):
feat = self.backbone(patches[:, i])
features.append(feat)
# 拼接特征
features = torch.cat(features, dim=1)
return self.classifier(features)
def generate_jigsaw(patches, num_permutations=100):
num_patches = patches.size(1)
# 生成随机排列
permutations = torch.randperm(num_patches)
# 打乱patches
shuffled_patches = patches[:, permutations]
return shuffled_patches, permutations
5. 未来帧预测
在视频中,让模型预测未来的帧。
class FutureFramePrediction(nn.Module):
def __init__(self, backbone, hidden_dim):
super().__init__()
self.backbone = backbone
self.decoder = nn.Sequential(
nn.Linear(backbone.output_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 3 * 64 * 64) # 预测下一帧
)
def forward(self, video_frames):
# 使用前面的帧预测最后一帧
features = self.backbone(video_frames[:, :-1])
features = features.mean(dim=1)
predicted_frame = self.decoder(features)
return predicted_frame.view(-1, 3, 64, 64)
自监督学习的应用
1. NLP预训练
- BERT:掩码语言模型 + 下一句预测
- GPT:自回归语言模型
- T5:文本到文本的预训练
2. 视觉预训练
- SimCLR:对比学习
- MAE:掩码自编码器
- DINO:自蒸馏
3. 多模态预训练
- CLIP:图像-文本对比学习
- ALIGN:大规模图像-文本对齐
自监督学习的优势
- 减少对标注数据的依赖:利用大量无标签数据
- 学习通用表示:预训练模型可以迁移到多种下游任务
- 提高数据效率:在小数据集上也能取得好效果
挑战与未来
- 预训练任务设计:如何设计更好的 pretext task
- 计算效率:大规模预训练需要大量计算资源
- 理论理解:为什么自监督学习有效
总结
自监督学习通过利用数据自身的结构创造监督信号,在无标签数据上学习有意义的表示。从对比学习到掩码语言模型,自监督学习已成为现代AI系统的核心技术,为大规模预训练模型的成功奠定了基础。