当前位置:首页 > 技术 > 正文内容

Spark MLlib 支持向量机分类实战

访客 技术 2026年5月27日 3

数据集说明

支持向量机算法需要特定格式的输入数据。数据采用LIBSVM格式,每行包含一个标签和多个特征值对:

标签 特征索引1:特征值1 特征索引2:特征值2 ...

示例数据:

0 128:51 129:159 130:253 131:159 132:50 ...
1 159:124 160:253 161:255 162:63 ...

基础训练与评估实现

以下代码演示了如何加载数据、训练SVM模型并计算预测精度:

// 加载LIBSVM格式数据集
val datasetPath = "/user/tmp/sample_libsvm_data.txt"
val rawData = MLUtils.loadLibSVMFile(sc, datasetPath).cache()

// 划分训练集和测试集
val dataSplits = rawData.randomSplit(Array(0.6, 0.4), seed = 11L)
val trainSet = dataSplits(0).cache()
val testSet = dataSplits(1)

println(s"训练样本数: ${trainSet.count()}, 测试样本数: ${testSet.count()}")

// 配置训练参数
val maxIterations = 1000
val learningRate = 1
val batchSizeRatio = 1.0

// 训练SVM模型
val trainedModel = SVMWithSGD.train(trainSet, maxIterations, learningRate, batchSizeRatio)

// 执行预测
val predictions = trainedModel.predict(testSet.map(_.features))
val predictionLabels = predictions.zip(testSet.map(_.label))

// 计算准确率
val evaluator = new MulticlassMetrics(predictionLabels)
val accuracy = evaluator.precision
println("模型准确率 = " + accuracy)

二分类评估指标计算

使用ROC曲线下面积等指标评估模型性能:

import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.util.MLUtils

// 加载并分割数据
val input = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val partitions = input.randomSplit(Array(0.6, 0.4), seed = 11L)
val trainingData = partitions(0).cache()
val testData = partitions(1)

// 模型训练
val iterationCount = 100
val classifier = SVMWithSGD.train(trainingData, iterationCount)

// 清除默认阈值以获得原始分数
classifier.clearThreshold()

// 计算测试集得分
val scoresAndLabels = testData.map { sample =>
  val score = classifier.predict(sample.features)
  (score, sample.label)
}

// 计算AUC指标
val metrics = new BinaryClassificationMetrics(scoresAndLabels)
val aucValue = metrics.areaUnderROC()
println("ROC曲线下面积 = " + aucValue)

// 模型持久化
classifier.save(sc, "modelStoragePath")
val loadedModel = SVMModel.load(sc, "modelStoragePath")

自定义正则化参数

SVM默认使用L2正则化,可通过以下方式修改为L1正则化:

import org.apache.spark.mllib.optimization.L1Updater

val svmTrainer = new SVMWithSGD()
svmTrainer.optimizer
  .setNumIterations(200)
  .setRegParam(0.1)
  .setUpdater(new L1Updater)
  
val l1RegularizedModel = svmTrainer.run(trainingData)

处理普通文本格式数据

对于非LIBSVM格式的数据文件,需要先进行预处理:

import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// 数据示例格式:
// 0 2.857 0 2.061 2.619 ...
// 1 2.857 0 2.061 2.619 ...

// 加载并解析数据
val rawDataFile = sc.textFile("mllib/data/sample_svm_data.txt")
val processedData = rawDataFile.map { line =>
  val elements = line.split(' ')
  LabeledPoint(elements(0).toDouble, Vectors.dense(elements.tail.map(_.toDouble)))
}

// 训练模型
val iterations = 20
val svmModel = SVMWithSGD.train(processedData, iterations)

// 计算训练误差
val labelsAndPredictions = processedData.map { point =>
  val prediction = svmModel.predict(point.features)
  (point.label, prediction)
}

val errorRate = labelsAndPredictions.filter(r => r._1 != r._2).count().toDouble / processedData.count()
println("训练误差 = " + errorRate)
SVM分类结果可视化
标签: SparkMLlibSVM

相关文章

Linux crontab 详解

1) crontab 是什么cron 是 Linux 的定时任务守护进程;crontab 是用来编辑/查看“按时间周期执行命令”的表(cron table)。常见两类:用户 crontab:每个用户一份(crontab -e 编辑)系统级 crontab / cron.d:可指定执行用户(/etc/crontab、/etc/cron.d/*)2) crontab 时间...

富文本里可以允许的 HTML 属性

一、所有标签默认允许的安全属性(极少)class        (可选)id           (通常建议禁用)title️ 注意:id 容易被滥用做锚点注入,很多系统直接禁用class 允许的话最好只允许固定前缀(如 editor-*)二、a 标签允许属性<a href="" t...

Mac 安装 Node.js 指南

方法一:通过官网安装包(最简单,适合初学者)如果你只是想快速安装并开始使用,这是最直接的方法。访问 Node.js 官网。页面会显示两个版本:LTS (Recommended For Most Users):长期支持版,最稳定。建议选这个。Current:最新特性版,包含最新功能但可能不够稳定。下载 .pkg 安装包并运行。按照安装向导点击“下一步”即可完成。方法二:使用 Homebrew 安装(...

Dom\HTML_NO_DEFAULT_NS 的副作用:自动加闭合标签

在使用Dom\HTMLDocument时,Dom\HTML_NO_DEFAULT_NS 将禁止在解析过程中设置元素的命名空间, 此设置是为了与DOMDocument向后兼容而存在的。当使用它时,已知的一个副作用就是:自动加闭合标签例如 </img> 为什么会这样?当你使用:Dom\HTML_NO_DEFAULT_NS文档会变成 无命名空间模式,此时内部更接近 XML...

Laravel 事件和监听器创建

在 Laravel 中,使用 Artisan 命令创建 Events(事件) 和 Listeners(监听器) 是非常高效的。你可以通过以下几种方式来实现:1. 手动创建单个 Event如果你只想创建一个事件类,可以使用 make:event 命令:Bashphp artisan make:event UserRegistered执行后,文件将生成在 app/Even...

自定义域名解析神器 dnsmasq

什么是 dnsmasq?dnsmasq 是一个轻量级、功能强大的网络服务工具,专为小型和中等规模网络设计。它是一个综合的网络基础设施解决方案[1]。dnsmasq 能做什么?功能说明应用场景DNS 转发与缓存将 DNS 查询转发到上游服务器(ISP、Google DNS 等),并在本地缓存结果加快 DNS 查询速度,减少外部 DNS 流量本地 DNS解析本地网络设备的主机名,无需编辑&n...

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。