当前位置:首页 > 技术 > 正文内容

机器学习分类评估指标详解:准确率、精确率、召回率与F1值

访客 技术 2026年7月4日 2

引言

在机器学习,特别是分类任务中,评估模型的性能至关重要。准确率、精确率、召回率和F1值是衡量分类模型效果的核心指标。理解这些指标的定义、计算方式及其相互关系,是优化模型和选择最佳方案的基础。

核心概念:混淆矩阵

这些指标的计算都基于一个名为"混淆矩阵"的表格,它展示了模型预测结果与真实标签之间的对应关系。在二分类问题中,我们定义:

  • 真正例 (TP - True Positive):模型预测为正例,且实际也是正例。
  • 假正例 (FP - False Positive):模型预测为正例,但实际是负例。
  • 假负例 (FN - False Negative):模型预测为负例,但实际是正例。
  • 真负例 (TN - True Negative):模型预测为负例,且实际也是负例。

为了更好地理解,我们假设一个场景:一个水果分类器需要区分苹果(正例)和橙子(负例)。模型对10个水果进行预测,结果如下:


# 真实标签: 1=苹果, 0=橙子
# 预测结果: 1=苹果, 0=橙子
actual_fruits = [1, 1, 0, 1, 0, 0, 1, 0, 0, 1]
predicted_fruits = [1, 0, 0, 1, 1, 0, 1, 0, 1, 1]

通过对比,我们可以统计出:TP=4, FP=2, FN=1, TN=3。

关键评估指标

1. 准确率 (Accuracy)

准确率衡量的是模型预测正确的总比例,即所有预测结果中,有多少是正确的。公式为:

Accuracy = (TP + TN) / (TP + FP + FN + TN)

在我们的水果例子中,准确率为 (4 + 3) / 10 = 0.7,即70%的预测是正确的。

2. 精确率 (Precision)

精确率关注的是模型预测为正例的结果中,有多少是真正正例。它衡量了预测为正例的"准度"。公式为:

Precision = TP / (TP + FP)

在水果例子中,精确率为 4 / (4 + 2) = 0.67,意味着当模型预测为苹果时,有67%的概率是正确的。

3. 召回率 (Recall)

召回率关注的是所有真实正例中,有多少被模型成功预测出来。它衡量了模型"找出"所有正例的能力。公式为:

Recall = TP / (TP + FN)

在水果例子中,召回率为 4 / (4 + 1) = 0.8,意味着所有真实的苹果中,有80%被模型成功识别。

4. F1值

精确率和召回率往往需要权衡。F1值是精确率和召回率的调和平均值,旨在综合两者,提供一个单一的评估分数。公式为:

F1 = 2 * (Precision * Recall) / (Precision + Recall)

在水果例子中,F1值为 2 * (0.67 * 0.8) / (0.67 + 0.8) ≈ 0.73。

代码实现

使用 Scikit-learn


from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# 真实标签和预测结果
true_labels = [1, 1, 0, 1, 0, 0, 1, 0, 0, 1]
predicted_labels = [1, 0, 0, 1, 1, 0, 1, 0, 1, 1]

# 计算各项指标
acc = accuracy_score(true_labels, predicted_labels)
prec = precision_score(true_labels, predicted_labels)
rec = recall_score(true_labels, predicted_labels)
f1 = f1_score(true_labels, predicted_labels)

print(f"准确率: {acc:.2f}")
print(f"精确率: {prec:.2f}")
print(f"召回率: {rec:.2f}")
print(f"F1值: {f1:.2f}")

使用 PySpark


from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql import SparkSession

# 初始化Spark
spark = SparkSession.builder.appName("MetricsExample").getOrCreate()
sc = spark.sparkContext

# 准备数据 (预测值, 真实值)
prediction_and_labels = sc.parallelize([
    (1.0, 1.0), (1.0, 1.0), (0.0, 0.0), (1.0, 1.0), (1.0, 0.0),
    (0.0, 0.0), (1.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0)
])

# 计算指标
metrics = MulticlassMetrics(prediction_and_labels)

print(f"准确率: {metrics.accuracy:.2f}")
print(f"精确率 (苹果): {metrics.precision(1.0):.2f}")
print(f"召回率 (苹果): {metrics.recall(1.0):.2f}")
print(f"F1值 (苹果): {metrics.fMeasure(1.0):.2f}")

使用 TensorFlow


import tensorflow as tf

# 定义真实值和预测值
true_values = tf.constant([1, 1, 0, 1, 0, 0, 1, 0, 0, 1], dtype=tf.float32)
predicted_values = tf.constant([1, 0, 0, 1, 1, 0, 1, 0, 1, 1], dtype=tf.float32)

# 计算各项指标
accuracy_metric = tf.keras.metrics.Accuracy()
precision_metric = tf.keras.metrics.Precision()
recall_metric = tf.keras.metrics.Recall()

accuracy_metric.update_state(true_values, predicted_values)
precision_metric.update_state(true_values, predicted_values)
recall_metric.update_state(true_values, predicted_values)

acc = accuracy_metric.result().numpy()
prec = precision_metric.result().numpy()
rec = recall_metric.result().numpy()
f1 = 2 * (prec * rec) / (prec + rec + 1e-7) # 防止除以零

print(f"准确率: {acc:.2f}")
print(f"精确率: {prec:.2f}")
print(f"召回率: {rec:.2f}")
print(f"F1值: {f1:.2f}")

多分类场景下的应用

对于多分类问题(例如,识别苹果、橙子、香蕉),我们可以将每个类别视为一个独立的二分类问题(例如,"这是苹果吗?"),然后计算每个类别的精确率和召回率,最后通过"宏平均"或"微平均"来得到整体的评估指标。

宏平均是先计算每个类别的指标,再求平均值,它 treats all classes equally。微平均则是先汇总所有类别的TP、FP、FN,再进行计算,它更受大类别的影响。


from sklearn.metrics import classification_report

# 假设我们有三个类别:0, 1, 2
y_true_multi = [0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
y_pred_multi = [0, 2, 1, 0, 1, 2, 0, 0, 2, 1]

# 使用classification_report可以方便地获取宏平均和微平均
report = classification_report(y_true_multi, y_pred_multi, output_dict=True)
print(f"宏平均精确率: {report['macro avg']['precision']:.2f}")
print(f"宏平均召回率: {report['macro avg']['recall']:.2f}")
print(f"宏平均F1值: {report['macro avg']['f1-score']:.2f}")

相关文章

Linux crontab 详解

1) crontab 是什么cron 是 Linux 的定时任务守护进程;crontab 是用来编辑/查看“按时间周期执行命令”的表(cron table)。常见两类:用户 crontab:每个用户一份(crontab -e 编辑)系统级 crontab / cron.d:可指定执行用户(/etc/crontab、/etc/cron.d/*)2) crontab 时间...

富文本里可以允许的 HTML 属性

一、所有标签默认允许的安全属性(极少)class        (可选)id           (通常建议禁用)title️ 注意:id 容易被滥用做锚点注入,很多系统直接禁用class 允许的话最好只允许固定前缀(如 editor-*)二、a 标签允许属性<a href="" t...

Mac 安装 Node.js 指南

方法一:通过官网安装包(最简单,适合初学者)如果你只是想快速安装并开始使用,这是最直接的方法。访问 Node.js 官网。页面会显示两个版本:LTS (Recommended For Most Users):长期支持版,最稳定。建议选这个。Current:最新特性版,包含最新功能但可能不够稳定。下载 .pkg 安装包并运行。按照安装向导点击“下一步”即可完成。方法二:使用 Homebrew 安装(...

Dom\HTML_NO_DEFAULT_NS 的副作用:自动加闭合标签

在使用Dom\HTMLDocument时,Dom\HTML_NO_DEFAULT_NS 将禁止在解析过程中设置元素的命名空间, 此设置是为了与DOMDocument向后兼容而存在的。当使用它时,已知的一个副作用就是:自动加闭合标签例如 </img> 为什么会这样?当你使用:Dom\HTML_NO_DEFAULT_NS文档会变成 无命名空间模式,此时内部更接近 XML...

Laravel 事件和监听器创建

在 Laravel 中,使用 Artisan 命令创建 Events(事件) 和 Listeners(监听器) 是非常高效的。你可以通过以下几种方式来实现:1. 手动创建单个 Event如果你只想创建一个事件类,可以使用 make:event 命令:Bashphp artisan make:event UserRegistered执行后,文件将生成在 app/Even...

自定义域名解析神器 dnsmasq

什么是 dnsmasq?dnsmasq 是一个轻量级、功能强大的网络服务工具,专为小型和中等规模网络设计。它是一个综合的网络基础设施解决方案[1]。dnsmasq 能做什么?功能说明应用场景DNS 转发与缓存将 DNS 查询转发到上游服务器(ISP、Google DNS 等),并在本地缓存结果加快 DNS 查询速度,减少外部 DNS 流量本地 DNS解析本地网络设备的主机名,无需编辑&n...

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。