- 阅读权限
- 255
- 威望
- 1 级
- 论坛币
- 49517 个
- 通用积分
- 53.5804
- 学术水平
- 370 点
- 热心指数
- 273 点
- 信用等级
- 335 点
- 经验
- 57815 点
- 帖子
- 4006
- 精华
- 21
- 在线时间
- 582 小时
- 注册时间
- 2005-5-8
- 最后登录
- 2023-11-26
|
/**
 * Loads the example CSV, splits it into cross-validation rounds, trains a
 * KNN model per round and prints the false prediction rate of each round.
 *
 * @param args command-line arguments (unused)
 */
def main(args: Array[String]): Unit = {
  val basePath = "/.../KNN_Example_1.csv"
  val testData = getDataFromCSV(new File(basePath))

  // Initialise the cross validation over all labelled samples.
  // NOTE(review): `validationRounds` is defined outside this block; it is
  // presumably the number of rounds (2 in the accompanying text) - confirm.
  val cv = new CrossValidation(testData._2.length, validationRounds)

  // Indexed snapshots of the data points and their classifiers, so that every
  // index handed out by the cross validation is resolved by position instead
  // of by a linear scan over (value, index) pairs.
  val dataPoints = testData._1.toIndexedSeq
  val classifiers = testData._2.toIndexedSeq

  // Resolve the index lists of every round into the actual values.
  val trainingDPSets =
    cv.train.map(indexList => indexList.map(index => dataPoints(index)))
  val trainingClassifierSets =
    cv.train.map(indexList => indexList.map(index => classifiers(index)))
  val testingDPSets =
    cv.test.map(indexList => indexList.map(index => dataPoints(index)))
  val testingClassifierSets =
    cv.test.map(indexList => indexList.map(index => classifiers(index)))

  // One record per round:
  // (training points, training classifiers, test points, test classifiers).
  val validationRoundRecords = trainingDPSets.zipWithIndex.map {
    case (trainingDPs, round) =>
      (
        trainingDPs,
        trainingClassifierSets(round),
        testingDPSets(round),
        testingClassifierSets(round)
      )
  }

  validationRoundRecords.foreach {
    case (trainingDPs, trainingClassifiers, testingDPs, testingClassifiers) =>
      // Train a model on this round's training data.
      // NOTE(review): the literal 3 is presumably the neighbour count k - confirm.
      val knn = KNN.learn(trainingDPs, trainingClassifiers, 3)

      // For each test data point make a prediction with the model.
      val predictions = testingDPs.map(point => knn.predict(point))

      // Count the predictions that differ from the expected classifier.
      val error = predictions.indices
        .count(i => predictions(i) != testingClassifiers(i))

      // Divide as Double: Int division would truncate the ratio to 0 for any
      // round that is not entirely misclassified. An empty test set is
      // reported as 0 instead of dividing by zero.
      val falseRate =
        if (predictions.length == 0) 0.0
        else error.toDouble / predictions.length * 100

      println(s"False prediction rate: ${falseRate}%")
  }
}
复制代码
|
|