K近邻分类算法详解与Python实现
一、K近邻算法原理
K近邻(KNN)是一种基于实例的学习算法,核心思想简单直观:给定一个未知类别的样本,在训练集中找到与该样本最近的K个邻居,然后根据这K个邻居的类别通过多数投票决定目标样本的类别。
具体步骤:
- 准备带有标签的训练数据集
- 计算待分类样本与所有训练样本的距离
- 选取距离最近的K个训练样本
- 统计这K个样本中各类别出现的频率
- 返回频率最高的类别作为预测结果
二、算法优劣分析
优势:
- 理论直观,实现简便
- 支持线性和非线性分类,也可用于回归任务
- 训练阶段复杂度低(O(n)),仅需存储数据
- 天然支持多分类和多标签问题
- 对类别边界重叠较多的数据表现良好
局限:
- 预测时计算开销大:每次预测都要与全部训练样本计算距离,在样本量大或特征维度高时尤为明显
- 对不平衡数据集敏感,多数类可能主导预测
- K值选择影响显著:过小易受噪声干扰,过大则容易被较远的、不相关的样本影响,使分类边界变得模糊
- 模型可解释性一般
三、完整代码实现
以下代码实现了KNN分类器及相关辅助功能:
#encoding:utf-8
from numpy import *
import operator
from os import listdir
def knn_classifier(test_sample, train_data, train_labels, k):
    """Classify one sample by majority vote among its k nearest neighbours.

    Distances are Euclidean and are computed against every training row.

    Args:
        test_sample: Sample to classify, a length-N vector or a 1xN array.
        train_data: Training samples as an MxN array.
        train_labels: Sequence of M labels aligned with the rows of
            train_data.
        k: Number of neighbours that vote (an odd value is recommended).
            Values larger than M are clamped to M.

    Returns:
        The label with the most votes. When several labels tie, the one
        whose first vote came from the closest neighbour wins.

    Raises:
        ValueError: If k is smaller than 1 or train_data has no rows.
    """
    num_samples = train_data.shape[0]
    if num_samples == 0:
        raise ValueError("train_data must contain at least one sample")
    if k < 1:
        raise ValueError("k must be at least 1")
    # Asking for more neighbours than there are training rows would index
    # past the end of the sorted index array, so use every row instead.
    neighbor_count = min(k, num_samples)

    # Euclidean distance from the test sample to every training row.
    diff_matrix = tile(test_sample, (num_samples, 1)) - train_data
    distances = (diff_matrix ** 2).sum(axis=1) ** 0.5

    # Indices of the closest rows, nearest first.
    nearest_indices = distances.argsort()[:neighbor_count]

    # Tally the labels of the nearest rows; insertion order is nearest-first.
    vote_counter = {}
    for index in nearest_indices:
        neighbor_label = train_labels[index]
        vote_counter[neighbor_label] = vote_counter.get(neighbor_label, 0) + 1

    # max() returns the first key holding the highest count, which matches
    # the stable descending sort used previously.
    return max(vote_counter, key=vote_counter.get)
def create_sample_dataset():
    """Build a tiny two-class toy dataset.

    Returns:
        A tuple ``(data, labels)``: a 4x2 array of points and the list of
        their class labels, two 'A' points followed by two 'B' points.
    """
    class_a_points = [[1.0, 1.1], [1.0, 1.0]]
    class_b_points = [[0, 0], [0, 0.1]]
    samples = array(class_a_points + class_b_points)
    tags = ['A'] * len(class_a_points) + ['B'] * len(class_b_points)
    return samples, tags
def load_data_from_file(filename):
    """Load a tab-separated data file into a feature matrix and a label list.

    Every non-blank line contributes one sample: its first three fields are
    the features and its last field is the integer class label.

    Args:
        filename: Path of the text file to read.

    Returns:
        A tuple ``(feature_matrix, label_vector)``: an Rx3 float array and a
        list of R int labels, where R is the number of non-blank lines.

    Raises:
        ValueError: If a feature field is not a valid float or the last
            field is not a valid int.
    """
    # Read the file in a single pass. Blank lines (for example a trailing
    # newline at the end of the file) carry no sample and are skipped.
    with open(filename) as file_handler:
        stripped_lines = (line.strip() for line in file_handler)
        records = [line.split('\t') for line in stripped_lines if line]

    feature_matrix = zeros((len(records), 3))
    label_vector = []
    for index, parts in enumerate(records):
        feature_matrix[index, :] = [float(value) for value in parts[0:3]]
        label_vector.append(int(parts[-1]))
    return feature_matrix, label_vector
def normalize_dataset(data):
    """Rescale every feature column to the [0, 1] range (min-max scaling).

    Args:
        data: 2-D array with one sample per row and one feature per column.

    Returns:
        A tuple ``(normalized_data, ranges, min_values)``. normalized_data is
        ``(data - min) / range`` computed column by column, ranges is the
        per-column ``max - min`` and min_values is the per-column minimum.
        A column whose values are all equal has a range of 0 and is
        normalized to 0; its entry in the returned ranges stays 0.
    """
    min_values = data.min(0)
    max_values = data.max(0)
    ranges = max_values - min_values
    # Dividing by a zero range would produce NaN for constant columns, so
    # divide those columns by 1 instead (their numerator is already 0).
    safe_ranges = where(ranges == 0, 1, ranges)
    # Broadcasting applies the per-column vectors to every row.
    normalized_data = (data - min_values) / safe_ranges
    return normalized_data, ranges, min_values
def evaluate_classifier(filename='datingTestSet2.txt', test_ratio=0.10, k=3):
    """Print the hold-out error rate of the KNN classifier on a data file.

    The file is loaded and normalized, the leading ``test_ratio`` share of
    the rows is used as test samples and the remaining rows serve as the
    training set. Every prediction and the overall error rate are printed.

    Args:
        filename: Data file passed to load_data_from_file.
        test_ratio: Fraction of the rows (taken from the start) to test on.
        k: Number of neighbours passed to knn_classifier.

    Raises:
        ZeroDivisionError: If the test split is empty, i.e.
            ``int(row_count * test_ratio)`` is 0.
    """
    dating_data, dating_labels = load_data_from_file(filename)
    norm_data, _, _ = normalize_dataset(dating_data)
    total_samples = norm_data.shape[0]
    test_count = int(total_samples * test_ratio)

    # The training split is the same for every test sample, so slice it once.
    train_data = norm_data[test_count:total_samples, :]
    train_labels = dating_labels[test_count:total_samples]

    error_count = 0.0
    for i in range(test_count):
        predicted = knn_classifier(norm_data[i, :], train_data, train_labels, k)
        print("预测结果: %d, 实际标签: %d" % (predicted, dating_labels[i]))
        if predicted != dating_labels[i]:
            error_count += 1.0
    print("总错误率: %f" % (error_count / float(test_count)))
def image_to_vector(filename):
    """Flatten a 32x32 text image of digits into a 1x1024 row vector.

    The first 32 characters of each of the first 32 lines are read; the
    character at row r, column c lands at position ``32 * r + c``.

    Args:
        filename: Path of the text image file.

    Returns:
        A 1x1024 float array holding the digit values row by row.
    """
    side = 32
    pixels = zeros((1, side * side))
    with open(filename) as image_file:
        for row in range(side):
            text = image_file.readline()
            offset = side * row
            # Index each column explicitly so a short line fails the same
            # way a per-character loop would.
            pixels[0, offset:offset + side] = [int(text[col]) for col in range(side)]
    return pixels
def handwritten_digit_recognition(training_dir='trainingDigits', test_dir='testDigits', k=3):
    """Run the KNN classifier on the handwritten-digit image files.

    Every file in ``training_dir`` becomes one training row and every file
    in ``test_dir`` is classified against them. The true digit is parsed
    from the file name: the part before the first '.' and, within that,
    before the first '_'. Each prediction, the error count and the error
    rate are printed.

    Args:
        training_dir: Directory containing the training image files.
        test_dir: Directory containing the test image files.
        k: Number of neighbours passed to knn_classifier.

    Raises:
        ValueError: If a file name does not start with an integer label.
        ZeroDivisionError: If ``test_dir`` contains no files.
    """
    # Build the training matrix, one flattened image per row.
    hw_labels = []
    training_files = listdir(training_dir)
    training_matrix = zeros((len(training_files), 1024))
    for row, filename in enumerate(training_files):
        hw_labels.append(int(filename.split('.')[0].split('_')[0]))
        training_matrix[row, :] = image_to_vector('%s/%s' % (training_dir, filename))

    # Classify every test image and count the mismatches.
    test_files = listdir(test_dir)
    test_count = len(test_files)
    error_count = 0.0
    for filename in test_files:
        class_number = int(filename.split('.')[0].split('_')[0])
        test_vector = image_to_vector('%s/%s' % (test_dir, filename))
        predicted = knn_classifier(test_vector, training_matrix, hw_labels, k)
        print("预测: %d, 实际: %d" % (predicted, class_number))
        if predicted != class_number:
            error_count += 1.0
    print("错误数: %d" % error_count)
    print("错误率: %f" % (error_count / float(test_count)))
def test_core_functions():
    """Exercise the array and sorting primitives used by the classifier.

    The results are neither returned nor asserted; the function only runs
    the operations.
    """
    # Fill a pre-allocated 10x4 matrix with the same row everywhere.
    filled_matrix = zeros((10, 4))
    filled_matrix[:] = [1, 2, 3, 4]

    # Element-wise squaring, as used in the distance computation.
    squares = array([[3, 2, 3, 4], [3, 4, 5, 6]]) ** 2

    # Index order that would sort the values ascending.
    order = array([1, 4, 3, -1, 6, 9]).argsort()

    # Rank (label, votes) pairs by vote count, highest first.
    ranked_votes = list({0: 3, 5: 2, 4: 6}.items())
    ranked_votes.sort(key=operator.itemgetter(1), reverse=True)
if __name__ == '__main__':
    # Run the demos in order: handwritten-digit recognition first, then the
    # dating-data classifier evaluation.
    for run_demo in (handwritten_digit_recognition, evaluate_classifier):
        run_demo()
四、可视化数据分布
以下代码展示了如何利用matplotlib绘制特征散点图,以直观了解数据分布:
from numpy import *
import matplotlib.pyplot as plt
# Generate simulated data: each sample is a (miles, game_time) point that is
# assigned to one of three plotted categories.
sample_count = 1000
category1_x = []
category1_y = []
category2_x = []
category2_y = []
category3_x = []
category3_y = []
# Per-sample marker sizes and colour codes; they are collected but not used
# by the scatter calls below.
point_sizes = []
point_colors = []

# The with-block guarantees the file is closed even if the loop raises.
# NOTE(review): testSet.txt is opened for writing (and therefore truncated)
# but nothing is ever written to it - confirm whether the generated samples
# were meant to be saved here.
with open('testSet.txt', 'w') as output_file:
    for i in range(sample_count):
        r0, r1 = random.standard_normal(2)
        # Uniform draw that selects which cluster this sample belongs to.
        class_value = random.uniform(0, 1)
        if class_value <= 0.16:
            # Unlike the other branches, negative values are not clamped here.
            miles = random.uniform(22000, 60000)
            game_time = 3 + 1.6 * r1
            point_sizes.append(20)
            point_colors.append(2.1)
            category1_x.append(miles)
            category1_y.append(game_time)
        elif class_value <= 0.33:
            miles = 6000 * r0 + 70000
            game_time = 10 + 3 * r1 + 2 * r0
            point_sizes.append(20)
            point_colors.append(1.1)
            if game_time < 0:
                game_time = 0
            if miles < 0:
                miles = 0
            category1_x.append(miles)
            category1_y.append(game_time)
        elif class_value <= 0.66:
            miles = 5000 * r0 + 10000
            game_time = 3 + 2.8 * r1
            point_sizes.append(30)
            point_colors.append(1.1)
            if game_time < 0:
                game_time = 0
            if miles < 0:
                miles = 0
            category2_x.append(miles)
            category2_y.append(game_time)
        else:
            miles = 10000 * r0 + 35000
            game_time = 10 + 2.0 * r1
            point_sizes.append(50)
            point_colors.append(0.1)
            if game_time < 0:
                game_time = 0
            if miles < 0:
                miles = 0
            category3_x.append(miles)
            category3_y.append(game_time)

# Draw one scatter series per category, then label and show the figure.
plt.figure(figsize=(10, 6))
scatter1 = plt.scatter(category1_x, category1_y, s=20, c='red', label='不喜欢')
scatter2 = plt.scatter(category2_x, category2_y, s=30, c='green', label='一般喜欢')
scatter3 = plt.scatter(category3_x, category3_y, s=50, c='blue', label='非常喜欢')
plt.legend(loc=2)
plt.axis([-5000, 100000, -2, 25])
plt.xlabel('每年飞行里程数')
plt.ylabel('玩视频游戏时间占比')
plt.title('约会数据散点图')
plt.show()