← 返回首页
🤖

自监督学习:无标签数据的表示学习

📂 ai ⏱ 3 min 441 words

自监督学习:无标签数据的表示学习

什么是自监督学习?

自监督学习是利用无标签数据本身创造监督信号的学习方法。它通过设计 pretext task(前置任务),让模型从数据中学习有意义的表示。

自监督学习vs监督学习vs无监督学习

方法 数据 监督信号 目标
监督学习 有标签 人工标注 预测标签
无监督学习 无标签 发现结构
自监督学习 无标签 数据自身 学习表示

自监督学习的主要方法

1. 对比学习

对比学习的核心思想是:让相似的样本在表示空间中靠近,不相似的样本远离。

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimCLR(nn.Module):
    def __init__(self, backbone, projection_dim=128):
        super().__init__()
        self.backbone = backbone
        self.projector = nn.Sequential(
            nn.Linear(backbone.output_dim, 256),
            nn.ReLU(),
            nn.Linear(256, projection_dim)
        )
    
    def forward(self, x1, x2):
        # 获取表示
        h1 = self.backbone(x1)
        h2 = self.backbone(x2)
        
        # 投影到对比空间
        z1 = self.projector(h1)
        z2 = self.projector(h2)
        
        return z1, z2

def nt_xent_loss(z1, z2, temperature=0.5):
    batch_size = z1.size(0)
    
    # 归一化
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    
    # 拼接
    z = torch.cat([z1, z2], dim=0)
    
    # 计算相似度矩阵
    sim = torch.mm(z, z.t()) / temperature
    
    # 创建标签(对角线为正样本)
    labels = torch.cat([
        torch.arange(batch_size, 2 * batch_size),
        torch.arange(batch_size)
    ])
    
    # 移除对角线
    mask = torch.eye(2 * batch_size, dtype=torch.bool)
    sim = sim.masked_fill(mask, -1e9)
    
    # 交叉熵损失
    loss = F.cross_entropy(sim, labels)
    
    return loss

2. 掩码语言模型(MLM)

BERT使用的预训练方法,随机掩码部分token,让模型预测被掩码的内容。

class MaskedLanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(512, d_model)
        
        encoder_layer = nn.TransformerEncoderLayer(d_model, num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        
        self.head = nn.Linear(d_model, vocab_size)
    
    def forward(self, input_ids, attention_mask=None):
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        
        x = self.embedding(input_ids) + self.position_embedding(positions)
        x = x.transpose(0, 1)
        x = self.transformer(x, src_key_padding_mask=~attention_mask)
        x = x.transpose(0, 1)
        
        return self.head(x)

def mlm_loss(model, input_ids, attention_mask, masked_positions, masked_labels):
    logits = model(input_ids, attention_mask)
    
    # 只计算被掩码位置的损失
    masked_logits = logits[:, masked_positions]
    loss = F.cross_entropy(masked_logits, masked_labels)
    
    return loss

3. 旋转预测

让模型预测图像旋转的角度(0°, 90°, 180°, 270°)。

class RotationPrediction(nn.Module):
    def __init__(self, backbone, num_classes=4):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(backbone.output_dim, num_classes)
    
    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

def rotate_images(images, angles):
    rotated = []
    labels = []
    
    for angle in angles:
        if angle == 0:
            rotated.append(images)
            labels.append(0)
        elif angle == 90:
            rotated.append(torch.rot90(images, 1, [2, 3]))
            labels.append(1)
        elif angle == 180:
            rotated.append(torch.rot90(images, 2, [2, 3]))
            labels.append(2)
        elif angle == 270:
            rotated.append(torch.rot90(images, 3, [2, 3]))
            labels.append(3)
    
    return torch.cat(rotated), torch.tensor(labels)

4. 拼图预测

打乱图像patch的顺序,让模型预测正确的排列。

class JigsawPuzzle(nn.Module):
    def __init__(self, backbone, num_permutations=100):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(backbone.output_dim, num_permutations)
    
    def forward(self, patches):
        # patches: (batch, num_patches, channels, height, width)
        batch_size, num_patches = patches.size(0), patches.size(1)
        
        # 处理每个patch
        features = []
        for i in range(num_patches):
            feat = self.backbone(patches[:, i])
            features.append(feat)
        
        # 拼接特征
        features = torch.cat(features, dim=1)
        
        return self.classifier(features)

def generate_jigsaw(patches, num_permutations=100):
    num_patches = patches.size(1)
    
    # 生成随机排列
    permutations = torch.randperm(num_patches)
    
    # 打乱patches
    shuffled_patches = patches[:, permutations]
    
    return shuffled_patches, permutations

5. 未来帧预测

在视频中,让模型预测未来的帧。

class FutureFramePrediction(nn.Module):
    def __init__(self, backbone, hidden_dim):
        super().__init__()
        self.backbone = backbone
        self.decoder = nn.Sequential(
            nn.Linear(backbone.output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 3 * 64 * 64)  # 预测下一帧
        )
    
    def forward(self, video_frames):
        # 使用前面的帧预测最后一帧
        features = self.backbone(video_frames[:, :-1])
        features = features.mean(dim=1)
        
        predicted_frame = self.decoder(features)
        return predicted_frame.view(-1, 3, 64, 64)

自监督学习的应用

1. NLP预训练

2. 视觉预训练

3. 多模态预训练

自监督学习的优势

  1. 减少对标注数据的依赖:利用大量无标签数据
  2. 学习通用表示:预训练模型可以迁移到多种下游任务
  3. 提高数据效率:在小数据集上也能取得好效果

挑战与未来

  1. 预训练任务设计:如何设计更好的 pretext task
  2. 计算效率:大规模预训练需要大量计算资源
  3. 理论理解:为什么自监督学习有效

总结

自监督学习通过利用数据自身的结构创造监督信号,在无标签数据上学习有意义的表示。从对比学习到掩码语言模型,自监督学习已成为现代AI系统的核心技术,为大规模预训练模型的成功奠定了基础。