逻辑回归详解与实战
逻辑回归是机器学习中一个重要的二分类算法。尽管名字中有"回归"二字,但它的主要用途是解决分类问题。本文将深入探讨逻辑回归的原理,并通过实际案例展示其应用。
逻辑回归的核心概念
逻辑回归主要用于处理二分类问题,例如判断邮件是否为垃圾邮件或预测用户是否会购买商品。它通过结合线性模型和Sigmoid函数来实现分类任务。
逻辑回归与线性回归的区别
| 对比维度 | 线性回归 | 逻辑回归 |
|---|---|---|
| 任务类型 | 预测连续值 | 预测离散类别 |
| 输出范围 | 任意实数 | 0-1之间的概率 |
| 核心目标 | 最小化误差平方和 | 最小化交叉熵损失 |
为什么不能用线性回归做分类?
- 输出范围不合适:线性回归可能产生超出[0,1]范围的值,这在分类问题中没有意义。
- 对异常值敏感:异常值可能导致线性回归拟合效果变差。
逻辑回归通过Sigmoid函数解决了这些问题,将线性输出映射到[0,1]区间。
逻辑回归的工作原理
逻辑回归分为三个步骤:
- 线性组合计算得分。
- 使用Sigmoid函数转换得分。
- 定义损失函数优化模型。
线性组合计算得分
公式为:
z = w₁x₁ + w₂x₂ + ... + wₙxₙ + b
其中:
x₁~xₙ:输入特征。w₁~wₙ:特征权重。b:偏置项。z:线性得分。
Sigmoid函数
Sigmoid函数将线性得分压缩到[0,1]之间,表示样本属于正类的概率。
Sigmoid函数的数学表达式
σ(z) = 1 / (1 + e^(-z))
绘制Sigmoid函数曲线
import numpy as np
import matplotlib.pyplot as plt
def sigmoid(x):
return 1 / (1 + np.exp(-x))
x = np.linspace(-10, 10, 100)
y = sigmoid(x)
plt.plot(x, y, label='Sigmoid Function', color='blue')
plt.axhline(y=0.5, color='red', linestyle='--', label='Probability = 0.5')
plt.xlabel('Input')
plt.ylabel('Output Probability')
plt.title('Sigmoid Function Visualization')
plt.legend()
plt.grid(True)
plt.show()
损失函数
逻辑回归采用交叉熵损失函数来衡量预测值与真实值之间的差距。
交叉熵损失函数公式
L(y, ŷ) = -[y·log(ŷ) + (1-y)·log(1-ŷ)]
实战演练:乳腺癌诊断
我们将使用Scikit-learn中的乳腺癌数据集进行逻辑回归的实战演练。
数据准备
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
data = load_breast_cancer()
X = data.data
y = data.target
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(
X_scaled, y, test_size=0.3, random_state=42, stratify=y
)
模型训练
model = LogisticRegression(max_iter=2000, random_state=42)
model.fit(X_train, y_train)
参数解读
weights = model.coef_[0]
bias = model.intercept_[0]
feature_names = data.feature_names
weight_df = pd.DataFrame({'Feature': feature_names, 'Weight': weights})
print(weight_df.sort_values(by='Weight', ascending=False).head(10))
模型评估
y_pred = model.predict(X_test)
y_prob = model.predict_proba(X_test)[:, 1]
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')
常见问题及解决方案
- 模型不收敛:增加迭代次数或调整正则化参数。
- 类别不平衡:调整分类阈值或使用过采样/欠采样技术。
- 参数解读错误:确保特征已标准化。
总结
逻辑回归因其简单性和高效性,在工业界得到了广泛应用。通过本文的学习,你应该能够理解逻辑回归的基本原理并将其应用于实际问题中。