← 返回首页
🤖

图像分类实践详解

📂 ai ⏱ 4 min 747 words

图像分类实践详解

图像分类是计算机视觉中最基础的任务,通过深度学习模型将图像分配到预定义的类别中。

图像分类流程

数据准备

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models, transforms, datasets
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

# 检查CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")

# 创建模拟图像数据(实际项目中应使用真实数据集)
np.random.seed(42)
n_samples = 1000
n_classes = 10
img_size = 32

# 生成模拟图像
X = np.random.randn(n_samples, 3, img_size, img_size).astype(np.float32)
y = np.random.randint(0, n_classes, n_samples)

print(f"数据集大小: {X.shape}")
print(f"类别数量: {n_classes}")
print(f"图像大小: {img_size}x{img_size}")

数据预处理

# 数据预处理
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                       std=[0.229, 0.224, 0.225])
])

# 数据增强
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                       std=[0.229, 0.224, 0.225])
])

# 应用预处理
X_processed = np.array([transform(img).numpy() for img in X])

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(
    X_processed, y, test_size=0.2, random_state=42
)

# 转换为PyTorch张量
X_train_tensor = torch.FloatTensor(X_train)
y_train_tensor = torch.LongTensor(y_train)
X_test_tensor = torch.FloatTensor(X_test)
y_test_tensor = torch.LongTensor(y_test)

# 创建数据加载器
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")

模型构建

简单CNN模型

class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        
        self.features = nn.Sequential(
            # 卷积层1
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            # 卷积层2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            # 卷积层3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# 创建模型
model = SimpleCNN(num_classes=n_classes).to(device)
print(f"模型参数: {sum(p.numel() for p in model.parameters()):,}")

使用预训练模型

# 使用预训练的ResNet
resnet = models.resnet18(pretrained=True)

# 修改最后一层
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, n_classes)

# 冻结除最后一层外的所有层
for param in resnet.parameters():
    param.requires_grad = False

# 只训练最后一层
for param in resnet.fc.parameters():
    param.requires_grad = True

resnet = resnet.to(device)

print(f"ResNet参数: {sum(p.numel() for p in resnet.parameters()):,}")
print(f"可训练参数: {sum(p.numel() for p in resnet.parameters() if p.requires_grad):,}")

模型训练

训练循环

def train_model(model, train_loader, test_loader, epochs=10, lr=0.001):
    """训练模型"""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    train_losses = []
    train_accs = []
    test_accs = []
    
    for epoch in range(epochs):
        # 训练阶段
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_X, batch_y in train_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += batch_y.size(0)
            correct += (predicted == batch_y).sum().item()
        
        train_loss = running_loss / len(train_loader)
        train_acc = 100 * correct / total
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        
        # 测试阶段
        test_acc = evaluate_model(model, test_loader)
        test_accs.append(test_acc)
        
        print(f'Epoch [{epoch+1}/{epochs}], '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
              f'Test Acc: {test_acc:.2f}%')
    
    return train_losses, train_accs, test_accs

def evaluate_model(model, test_loader):
    """评估模型"""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            outputs = model(batch_X)
            _, predicted = torch.max(outputs.data, 1)
            total += batch_y.size(0)
            correct += (predicted == batch_y).sum().item()
    
    return 100 * correct / total

# 训练简单CNN
print("训练简单CNN:")
train_losses, train_accs, test_accs = train_model(
    model, train_loader, test_loader, epochs=10, lr=0.001
)

训练预训练模型

# 训练预训练模型
print("\n训练预训练ResNet:")
train_losses_pt, train_accs_pt, test_accs_pt = train_model(
    resnet, train_loader, test_loader, epochs=10, lr=0.001
)

模型评估

性能可视化

# 可视化训练过程
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# 损失曲线
ax1.plot(train_losses, 'b-', label='简单CNN', linewidth=2)
ax1.plot(train_losses_pt, 'r-', label='预训练ResNet', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('损失')
ax1.set_title('训练损失曲线')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 准确率曲线
ax2.plot(test_accs, 'b-', label='简单CNN', linewidth=2)
ax2.plot(test_accs_pt, 'r-', label='预训练ResNet', linewidth=2)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('准确率')
ax2.set_title('测试准确率曲线')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# 最终性能比较
print("\n最终性能比较:")
print(f"简单CNN: {test_accs[-1]:.2f}%")
print(f"预训练ResNet: {test_accs_pt[-1]:.2f}%")

混淆矩阵

from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_confusion_matrix(model, test_loader, class_names=None):
    """绘制混淆矩阵"""
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            outputs = model(batch_X)
            _, predicted = torch.max(outputs.data, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(batch_y.cpu().numpy())
    
    # 计算混淆矩阵
    cm = confusion_matrix(all_labels, all_preds)
    
    # 可视化
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    plt.title('混淆矩阵')
    plt.tight_layout()
    plt.show()
    
    # 计算每类准确率
    if class_names:
        print("\n每类准确率:")
        for i, name in enumerate(class_names):
            class_acc = cm[i, i] / cm[i].sum() * 100
            print(f"{name}: {class_acc:.2f}%")

# 绘制混淆矩阵(使用模拟类别名称)
class_names = [f'类别{i}' for i in range(n_classes)]
plot_confusion_matrix(model, test_loader, class_names)

模型部署

保存和加载模型

# 保存模型
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optim.Adam(model.parameters()).state_dict(),
    'num_classes': n_classes,
    'img_size': img_size,
}, 'image_classifier.pth')

print("模型已保存")

# 加载模型
checkpoint = torch.load('image_classifier.pth')
loaded_model = SimpleCNN(num_classes=checkpoint['num_classes'])
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_model.eval()

print("模型已加载")

# 验证加载的模型
test_acc_loaded = evaluate_model(loaded_model, test_loader)
print(f"加载模型的准确率: {test_acc_loaded:.2f}%")

预测新图像

def predict_image(model, image, class_names=None):
    """预测单张图像"""
    model.eval()
    
    # 预处理
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    # 转换为tensor
    if isinstance(image, np.ndarray):
        image = transform(image).unsqueeze(0)
    
    # 预测
    with torch.no_grad():
        outputs = model(image.to(device))
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probabilities, 1)
    
    # 返回结果
    result = {
        'predicted_class': predicted.item(),
        'confidence': confidence.item(),
        'probabilities': probabilities[0].cpu().numpy()
    }
    
    if class_names:
        result['class_name'] = class_names[predicted.item()]
    
    return result

# 测试预测
sample_image = X_test[0]
result = predict_image(model, sample_image, class_names)

print("\n预测结果:")
print(f"预测类别: {result['class_name']}")
print(f"置信度: {result['confidence']:.2%}")

# 可视化预测概率
plt.figure(figsize=(10, 6))
plt.bar(class_names, result['probabilities'], color='skyblue', edgecolor='black')
plt.xlabel('类别')
plt.ylabel('概率')
plt.title('预测概率分布')
plt.xticks(rotation=45)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

图像分类最佳实践

  1. 数据增强:增加训练数据多样性
  2. 迁移学习:使用预训练模型提高性能
  3. 正则化:使用Dropout、数据增强防止过拟合
  4. 学习率调度:动态调整学习率
  5. 集成学习:结合多个模型提高准确率

图像分类是计算机视觉的基础,掌握图像分类技术对于目标检测、语义分割等高级任务至关重要。