当前位置:首页 > 技术 > 正文内容

K近邻分类算法详解与Python实现

访客 技术 2026年7月5日 4

一、K近邻算法原理

K近邻(KNN)是一种基于实例的学习算法,核心思想简单直观:给定一个未知类别的样本,在训练集中找到与该样本最近的K个邻居,然后根据这K个邻居的类别通过多数投票决定目标样本的类别。

具体步骤:

  • 准备带有标签的训练数据集
  • 计算待分类样本与所有训练样本的距离
  • 选取距离最近的K个训练样本
  • 统计这K个样本中各类别出现的频率
  • 返回频率最高的类别作为预测结果

二、算法优劣分析

优势:

  • 理论直观,实现简便
  • 支持线性和非线性分类,也可用于回归任务
  • 训练阶段复杂度低(O(n)),仅需存储数据
  • 天然支持多分类和多标签问题
  • 对类别边界重叠较多的数据表现良好

局限:

  • 预测时计算开销大,尤其在高维特征空间中
  • 对不平衡数据集敏感,多数类可能主导预测
  • K值选择影响显著,过小易受噪声干扰
  • 模型可解释性一般

三、完整代码实现

以下代码实现了KNN分类器及相关辅助功能:

#encoding:utf-8
from numpy import *
import operator
from os import listdir

def knn_classifier(test_sample, train_data, train_labels, k):
    """
    K近邻分类器
    参数:
        test_sample: 待分类样本(1xN向量)
        train_data: 训练数据集(MxN矩阵)
        train_labels: 训练标签(Mx1向量)
        k: 最近邻居数(建议奇数)
    返回:
        预测类别标签
    """
    num_samples = train_data.shape[0]
    # 计算欧氏距离
    diff_matrix = tile(test_sample, (num_samples, 1)) - train_data
    squared_diff = diff_matrix ** 2
    squared_dist = squared_diff.sum(axis=1)
    distances = squared_dist ** 0.5
    
    # 获取排序后的索引
    sorted_indices = distances.argsort()
    
    # 统计K个最近邻的类别
    vote_counter = {}
    for i in range(k):
        neighbor_label = train_labels[sorted_indices[i]]
        vote_counter[neighbor_label] = vote_counter.get(neighbor_label, 0) + 1
    
    # 按投票数降序排列,返回票数最高的类别
    sorted_votes = sorted(vote_counter.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_votes[0][0]

def create_sample_dataset():
    """创建示例数据集"""
    data = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return data, labels

def load_data_from_file(filename):
    """从文件加载数据"""
    with open(filename) as file_handler:
        lines = file_handler.readlines()
        line_count = len(lines)
    
    feature_matrix = zeros((line_count, 3))
    label_vector = []
    
    with open(filename) as file_handler:
        index = 0
        for line in file_handler:
            line = line.strip()
            parts = line.split('\t')
            feature_matrix[index, :] = parts[0:3]
            label_vector.append(int(parts[-1]))
            index += 1
    
    return feature_matrix, label_vector

def normalize_dataset(data):
    """数据归一化处理"""
    min_values = data.min(0)
    max_values = data.max(0)
    ranges = max_values - min_values
    normalized_data = zeros(shape(data))
    row_count = data.shape[0]
    normalized_data = data - tile(min_values, (row_count, 1))
    normalized_data = normalized_data / tile(ranges, (row_count, 1))
    return normalized_data, ranges, min_values

def evaluate_classifier():
    """评估分类器性能"""
    test_ratio = 0.10
    dating_data, dating_labels = load_data_from_file('datingTestSet2.txt')
    norm_data, ranges, min_vals = normalize_dataset(dating_data)
    
    total_samples = norm_data.shape[0]
    test_count = int(total_samples * test_ratio)
    error_count = 0.0
    
    for i in range(test_count):
        predicted = knn_classifier(norm_data[i,:], 
                                    norm_data[test_count:total_samples,:],
                                    dating_labels[test_count:total_samples], 
                                    3)
        print("预测结果: %d, 实际标签: %d" % (predicted, dating_labels[i]))
        if predicted != dating_labels[i]:
            error_count += 1.0
    
    print("总错误率: %f" % (error_count / float(test_count)))

def image_to_vector(filename):
    """将32x32图像转换为1x1024向量"""
    result_vector = zeros((1, 1024))
    with open(filename) as file_handler:
        for i in range(32):
            line = file_handler.readline()
            for j in range(32):
                result_vector[0, 32*i + j] = int(line[j])
    return result_vector

def handwritten_digit_recognition():
    """手写数字识别测试"""
    hw_labels = []
    training_files = listdir('trainingDigits')
    m = len(training_files)
    training_matrix = zeros((m, 1024))
    
    for i in range(m):
        filename = training_files[i]
        class_number = int(filename.split('.')[0].split('_')[0])
        hw_labels.append(class_number)
        training_matrix[i,:] = image_to_vector('trainingDigits/%s' % filename)
    
    test_files = listdir('testDigits')
    error_count = 0.0
    test_count = len(test_files)
    
    for i in range(test_count):
        filename = test_files[i]
        class_number = int(filename.split('.')[0].split('_')[0])
        test_vector = image_to_vector('testDigits/%s' % filename)
        predicted = knn_classifier(test_vector, training_matrix, hw_labels, 3)
        print("预测: %d, 实际: %d" % (predicted, class_number))
        if predicted != class_number:
            error_count += 1.0
    
    print("错误数: %d" % error_count)
    print("错误率: %f" % (error_count / float(test_count)))

def test_core_functions():
    """测试核心功能函数"""
    # 测试矩阵运算
    training_matrix = zeros((10, 4))
    for i in range(10):
        training_matrix[i,:] = [1, 2, 3, 4]
    
    # 测试距离计算
    sample_array = array([[3, 2, 3, 4], [3, 4, 5, 6]])
    squared_array = sample_array ** 2
    
    # 测试排序功能
    sample_values = array([1, 4, 3, -1, 6, 9])
    sorted_indices = sample_values.argsort()
    
    # 测试投票统计
    vote_counts = {0: 3, 5: 2, 4: 6}
    sorted_votes = sorted(vote_counts.items(), key=operator.itemgetter(1), reverse=True)

if __name__ == '__main__':
    # 运行手写数字识别测试
    handwritten_digit_recognition()
    # 运行约会数据集分类测试
    evaluate_classifier()

四、可视化数据分布

以下代码展示了如何利用matplotlib绘制特征散点图,以直观了解数据分布:

from numpy import *
import matplotlib.pyplot as plt

# 生成模拟数据
sample_count = 1000
category1_x = []
category1_y = []
category2_x = []
category2_y = []
category3_x = []
category3_y = []
point_sizes = []
point_colors = []

output_file = open('testSet.txt', 'w')

for i in range(sample_count):
    r0, r1 = random.standard_normal(2)
    class_value = random.uniform(0, 1)
    
    if class_value <= 0.16:
        miles = random.uniform(22000, 60000)
        game_time = 3 + 1.6 * r1
        point_sizes.append(20)
        point_colors.append(2.1)
        category1_x.append(miles)
        category1_y.append(game_time)
        
    elif class_value <= 0.33:
        miles = 6000 * r0 + 70000
        game_time = 10 + 3 * r1 + 2 * r0
        point_sizes.append(20)
        point_colors.append(1.1)
        if game_time < 0: game_time = 0
        if miles < 0: miles = 0
        category1_x.append(miles)
        category1_y.append(game_time)
        
    elif class_value <= 0.66:
        miles = 5000 * r0 + 10000
        game_time = 3 + 2.8 * r1
        point_sizes.append(30)
        point_colors.append(1.1)
        if game_time < 0: game_time = 0
        if miles < 0: miles = 0
        category2_x.append(miles)
        category2_y.append(game_time)
        
    else:
        miles = 10000 * r0 + 35000
        game_time = 10 + 2.0 * r1
        point_sizes.append(50)
        point_colors.append(0.1)
        if game_time < 0: game_time = 0
        if miles < 0: miles = 0
        category3_x.append(miles)
        category3_y.append(game_time)

output_file.close()

# 创建可视化图表
plt.figure(figsize=(10, 6))
scatter1 = plt.scatter(category1_x, category1_y, s=20, c='red', label='不喜欢')
scatter2 = plt.scatter(category2_x, category2_y, s=30, c='green', label='一般喜欢')
scatter3 = plt.scatter(category3_x, category3_y, s=50, c='blue', label='非常喜欢')

plt.legend(loc=2)
plt.axis([-5000, 100000, -2, 25])
plt.xlabel('每年飞行里程数')
plt.ylabel('玩视频游戏时间占比')
plt.title('约会数据散点图')
plt.show()
标签: K近邻算法

相关文章

Linux crontab 详解

1) crontab 是什么cron 是 Linux 的定时任务守护进程;crontab 是用来编辑/查看“按时间周期执行命令”的表(cron table)。常见两类:用户 crontab:每个用户一份(crontab -e 编辑)系统级 crontab / cron.d:可指定执行用户(/etc/crontab、/etc/cron.d/*)2) crontab 时间...

富文本里可以允许的 HTML 属性

一、所有标签默认允许的安全属性(极少)class        (可选)id           (通常建议禁用)title️ 注意:id 容易被滥用做锚点注入,很多系统直接禁用class 允许的话最好只允许固定前缀(如 editor-*)二、a 标签允许属性<a href="" t...

Mac 安装 Node.js 指南

方法一:通过官网安装包(最简单,适合初学者)如果你只是想快速安装并开始使用,这是最直接的方法。访问 Node.js 官网。页面会显示两个版本:LTS (Recommended For Most Users):长期支持版,最稳定。建议选这个。Current:最新特性版,包含最新功能但可能不够稳定。下载 .pkg 安装包并运行。按照安装向导点击“下一步”即可完成。方法二:使用 Homebrew 安装(...

Dom\HTML_NO_DEFAULT_NS 的副作用:自动加闭合标签

在使用Dom\HTMLDocument时,Dom\HTML_NO_DEFAULT_NS 将禁止在解析过程中设置元素的命名空间, 此设置是为了与DOMDocument向后兼容而存在的。当使用它时,已知的一个副作用就是:自动加闭合标签例如 </img> 为什么会这样?当你使用:Dom\HTML_NO_DEFAULT_NS文档会变成 无命名空间模式,此时内部更接近 XML...

Laravel 事件和监听器创建

在 Laravel 中,使用 Artisan 命令创建 Events(事件) 和 Listeners(监听器) 是非常高效的。你可以通过以下几种方式来实现:1. 手动创建单个 Event如果你只想创建一个事件类,可以使用 make:event 命令:Bashphp artisan make:event UserRegistered执行后,文件将生成在 app/Even...

自定义域名解析神器 dnsmasq

什么是 dnsmasq?dnsmasq 是一个轻量级、功能强大的网络服务工具,专为小型和中等规模网络设计。它是一个综合的网络基础设施解决方案[1]。dnsmasq 能做什么?功能说明应用场景DNS 转发与缓存将 DNS 查询转发到上游服务器(ISP、Google DNS 等),并在本地缓存结果加快 DNS 查询速度,减少外部 DNS 流量本地 DNS解析本地网络设备的主机名,无需编辑&n...

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。