Spark MLlib 支持向量机分类实战
数据集说明
支持向量机算法需要特定格式的输入数据。数据采用LIBSVM格式,每行包含一个标签和多个特征值对:
标签 特征索引1:特征值1 特征索引2:特征值2 ...
示例数据:
0 128:51 129:159 130:253 131:159 132:50 ... 1 159:124 160:253 161:255 162:63 ...
基础训练与评估实现
以下代码演示了如何加载数据、训练SVM模型并计算预测精度:
// 加载LIBSVM格式数据集
val datasetPath = "/user/tmp/sample_libsvm_data.txt"
val rawData = MLUtils.loadLibSVMFile(sc, datasetPath).cache()
// 划分训练集和测试集
val dataSplits = rawData.randomSplit(Array(0.6, 0.4), seed = 11L)
val trainSet = dataSplits(0).cache()
val testSet = dataSplits(1)
println(s"训练样本数: ${trainSet.count()}, 测试样本数: ${testSet.count()}")
// 配置训练参数
val maxIterations = 1000
val learningRate = 1
val batchSizeRatio = 1.0
// 训练SVM模型
val trainedModel = SVMWithSGD.train(trainSet, maxIterations, learningRate, batchSizeRatio)
// 执行预测
val predictions = trainedModel.predict(testSet.map(_.features))
val predictionLabels = predictions.zip(testSet.map(_.label))
// 计算准确率
val evaluator = new MulticlassMetrics(predictionLabels)
val accuracy = evaluator.precision
println("模型准确率 = " + accuracy)
二分类评估指标计算
使用ROC曲线下面积等指标评估模型性能:
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.util.MLUtils
// 加载并分割数据
val input = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val partitions = input.randomSplit(Array(0.6, 0.4), seed = 11L)
val trainingData = partitions(0).cache()
val testData = partitions(1)
// 模型训练
val iterationCount = 100
val classifier = SVMWithSGD.train(trainingData, iterationCount)
// 清除默认阈值以获得原始分数
classifier.clearThreshold()
// 计算测试集得分
val scoresAndLabels = testData.map { sample =>
val score = classifier.predict(sample.features)
(score, sample.label)
}
// 计算AUC指标
val metrics = new BinaryClassificationMetrics(scoresAndLabels)
val aucValue = metrics.areaUnderROC()
println("ROC曲线下面积 = " + aucValue)
// 模型持久化
classifier.save(sc, "modelStoragePath")
val loadedModel = SVMModel.load(sc, "modelStoragePath")
自定义正则化参数
SVM默认使用L2正则化,可通过以下方式修改为L1正则化:
import org.apache.spark.mllib.optimization.L1Updater
val svmTrainer = new SVMWithSGD()
svmTrainer.optimizer
.setNumIterations(200)
.setRegParam(0.1)
.setUpdater(new L1Updater)
val l1RegularizedModel = svmTrainer.run(trainingData)
处理普通文本格式数据
对于非LIBSVM格式的数据文件,需要先进行预处理:
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
// 数据示例格式:
// 0 2.857 0 2.061 2.619 ...
// 1 2.857 0 2.061 2.619 ...
// 加载并解析数据
val rawDataFile = sc.textFile("mllib/data/sample_svm_data.txt")
val processedData = rawDataFile.map { line =>
val elements = line.split(' ')
LabeledPoint(elements(0).toDouble, Vectors.dense(elements.tail.map(_.toDouble)))
}
// 训练模型
val iterations = 20
val svmModel = SVMWithSGD.train(processedData, iterations)
// 计算训练误差
val labelsAndPredictions = processedData.map { point =>
val prediction = svmModel.predict(point.features)
(point.label, prediction)
}
val errorRate = labelsAndPredictions.filter(r => r._1 != r._2).count().toDouble / processedData.count()
println("训练误差 = " + errorRate)