← 返回首页
🤖

过拟合与欠拟合:诊断与解决

📂 ai ⏱ 2 min 400 words

过拟合与欠拟合:诊断与解决

过拟合和欠拟合是机器学习中最常见的问题,理解它们对构建高性能模型至关重要。

偏差-方差权衡

模型误差可分解为偏差、方差和噪声三部分。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split, learning_curve
from sklearn.metrics import mean_squared_error

# 生成带噪声的数据
np.random.seed(42)
n_samples = 100
X = np.sort(np.random.uniform(0, 10, n_samples)).reshape(-1, 1)
y_true = np.sin(X).ravel()
y = y_true + np.random.normal(0, 0.3, n_samples)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

print(f"训练集大小: {len(X_train)}")
print(f"测试集大小: {len(X_test)}")

欠拟合示例

欠拟合模型过于简单,无法捕捉数据中的模式。

# 欠拟合:线性模型拟合非线性数据
linear = LinearRegression()
linear.fit(X_train, y_train)

train_score = mean_squared_error(y_train, linear.predict(X_train))
test_score = mean_squared_error(y_test, linear.predict(X_test))

print("欠拟合(线性模型):")
print(f"训练集MSE: {train_score:.4f}")
print(f"测试集MSE: {test_score:.4f}")

# 可视化
plt.figure(figsize=(10, 6))
plt.scatter(X_train, y_train, label='训练数据', alpha=0.6)
plt.scatter(X_test, y_test, label='测试数据', alpha=0.6)
X_plot = np.linspace(0, 10, 100).reshape(-1, 1)
plt.plot(X_plot, linear.predict(X_plot), 'r-', label='线性模型', linewidth=2)
plt.plot(X_plot, np.sin(X_plot), 'g--', label='真实函数', linewidth=2)
plt.xlabel('X')
plt.ylabel('y')
plt.title('欠拟合示例')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

过拟合示例

过拟合模型过于复杂,学习了训练数据中的噪声。

# 过拟合:高次多项式
high_degree = Pipeline([
    ('poly', PolynomialFeatures(degree=15)),
    ('linear', LinearRegression())
])
high_degree.fit(X_train, y_train)

train_score = mean_squared_error(y_train, high_degree.predict(X_train))
test_score = mean_squared_error(y_test, high_degree.predict(X_test))

print("\n过拟合(15次多项式):")
print(f"训练集MSE: {train_score:.4f}")
print(f"测试集MSE: {test_score:.4f}")

# 可视化
plt.figure(figsize=(10, 6))
plt.scatter(X_train, y_train, label='训练数据', alpha=0.6)
plt.scatter(X_test, y_test, label='测试数据', alpha=0.6)
plt.plot(X_plot, high_degree.predict(X_plot), 'r-', label='15次多项式', linewidth=2)
plt.plot(X_plot, np.sin(X_plot), 'g--', label='真实函数', linewidth=2)
plt.ylim(-2, 2)
plt.xlabel('X')
plt.ylabel('y')
plt.title('过拟合示例')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

模型复杂度分析

# 不同多项式次数的比较
degrees = [1, 3, 5, 10, 15]
train_scores = []
test_scores = []

for degree in degrees:
    model = Pipeline([
        ('poly', PolynomialFeatures(degree=degree)),
        ('linear', LinearRegression())
    ])
    model.fit(X_train, y_train)
    
    train_scores.append(mean_squared_error(y_train, model.predict(X_train)))
    test_scores.append(mean_squared_error(y_test, model.predict(X_test)))

plt.figure(figsize=(10, 6))
plt.plot(degrees, train_scores, 'bo-', label='训练集MSE')
plt.plot(degrees, test_scores, 'ro-', label='测试集MSE')
plt.xlabel('多项式次数')
plt.ylabel('MSE')
plt.title('模型复杂度与误差关系')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# 找到最佳次数
best_degree = degrees[np.argmin(test_scores)]
print(f"最佳多项式次数: {best_degree}")

学习曲线诊断

# 学习曲线
def plot_learning_curve(estimator, X, y, title="学习曲线"):
    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=5, n_jobs=-1,
        train_sizes=np.linspace(0.1, 1.0, 10),
        scoring='neg_mean_squared_error'
    )
    
    train_mean = -train_scores.mean(axis=1)
    train_std = train_scores.std(axis=1)
    test_mean = -test_scores.mean(axis=1)
    test_std = test_scores.std(axis=1)
    
    plt.figure(figsize=(10, 6))
    plt.plot(train_sizes, train_mean, 'o-', label='训练误差')
    plt.fill_between(train_sizes, train_mean - train_std,
                     train_mean + train_std, alpha=0.1)
    plt.plot(train_sizes, test_mean, 'o-', label='验证误差')
    plt.fill_between(train_sizes, test_mean - test_std,
                     test_mean + test_std, alpha=0.1)
    plt.xlabel('训练样本数')
    plt.ylabel('误差')
    plt.title(title)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

# 欠拟合的学习曲线
plot_learning_curve(LinearRegression(), X, y, "欠拟合学习曲线")

# 过拟合的学习曲线
plot_learning_curve(
    Pipeline([('poly', PolynomialFeatures(degree=15)), ('linear', LinearRegression())]),
    X, y, "过拟合学习曲线"
)

解决方案

解决欠拟合

# 增加模型复杂度
from sklearn.ensemble import GradientBoostingRegressor

# 使用更复杂的模型
gbr = GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=42)
gbr.fit(X_train, y_train)

train_score = mean_squared_error(y_train, gbr.predict(X_train))
test_score = mean_squared_error(y_test, gbr.predict(X_test))

print("解决欠拟合(增加模型复杂度):")
print(f"训练集MSE: {train_score:.4f}")
print(f"测试集MSE: {test_score:.4f}")

解决过拟合

from sklearn.linear_model import Ridge, Lasso
from sklearn.model_selection import GridSearchCV

# 正则化
ridge = Ridge(alpha=1.0)
ridge.fit(PolynomialFeatures(degree=15).fit_transform(X_train), y_train)

train_score = mean_squared_error(y_train, 
    ridge.predict(PolynomialFeatures(degree=15).transform(X_train)))
test_score = mean_squared_error(y_test, 
    ridge.predict(PolynomialFeatures(degree=15).transform(X_test)))

print("\n解决过拟合(Ridge正则化):")
print(f"训练集MSE: {train_score:.4f}")
print(f"测试集MSE: {test_score:.4f}")

诊断总结

# 诊断指标
print("过拟合诊断:")
print("- 训练误差低,验证误差高")
print("- 模型复杂度过高")
print("- 解决方案:正则化、减少特征、增加数据")

print("\n欠拟合诊断:")
print("- 训练误差和验证误差都高")
print("- 模型复杂度过低")
print("- 解决方案:增加特征、使用更复杂模型、减少正则化")

总结

问题 表现 解决方案
欠拟合 训练/验证误差都高 增加模型复杂度、添加特征
过拟合 训练误差低,验证误差高 正则化、减少特征、增加数据

理解偏差-方差权衡是构建优秀机器学习模型的基础。