package org.apache.spark.examples.ml

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkContext, SparkConf}
// $example on$
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
// $example off$

object DecisionTreeClassificationExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("DecisionTreeClassificationExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // $example on$
    // Load the data stored in LIBSVM format as a DataFrame.
    val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

    // Index labels, adding metadata to the label column.
    // Fit on whole dataset to include all labels in index.
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(data)

    // Automatically identify categorical features, and index them.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4) // features with > 4 distinct values are treated as continuous
      .fit(data)

    // Split the data into training and test sets (30% held out for testing).
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

    // Train a DecisionTree model.
    val dt = new DecisionTreeClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")

    // Convert indexed labels back to original labels.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)

    // Chain indexers and tree in a Pipeline.
    val pipeline = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

    // Train model. This also runs the indexers.
    val model = pipeline.fit(trainingData)

    // Make predictions.
    val predictions = model.transform(testData)

    // Select example rows to display.
    predictions.select("predictedLabel", "label", "features").show(5)

    // Select (prediction, true label) and compute test error.
    // In the Spark 1.x API used here, "precision" is the overall precision,
    // which equals accuracy. Spark 2.0+ names this metric "accuracy".
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("precision")
    val accuracy = evaluator.evaluate(predictions)
    println("Test Error = " + (1.0 - accuracy))

    // Stage 2 of the fitted pipeline is the decision tree model
    // (stages 0 and 1 are the two indexers).
    val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
    println("Learned classification tree model:\n" + treeModel.toDebugString)
    // $example off$
  }
}
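The last step of the example prints `1.0 - accuracy` as the test error. As a rough illustration of what that number means, here is a minimal plain-Scala sketch that needs no Spark. It assumes the evaluator's metric is the share of rows whose prediction equals the label. The object `TestErrorSketch`, its methods, and the sample pairs are hypothetical and are not part of the Spark API.

```scala
// Plain-Scala sketch; all names and data here are made up for illustration.
object TestErrorSketch {
  /** Fraction of (prediction, label) pairs that agree. */
  def accuracy(pairs: Seq[(Double, Double)]): Double = {
    require(pairs.nonEmpty, "need at least one (prediction, label) pair")
    pairs.count { case (prediction, label) => prediction == label }.toDouble / pairs.size
  }

  /** Test error in the sense used by the example: 1.0 - accuracy. */
  def testError(pairs: Seq[(Double, Double)]): Double = 1.0 - accuracy(pairs)

  def main(args: Array[String]): Unit = {
    // Four made-up (prediction, indexedLabel) pairs; the third one is wrong.
    val pairs = Seq((0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0))
    println("Accuracy   = " + accuracy(pairs))
    println("Test Error = " + testError(pairs))
  }
}
```

With three of the four pairs correct, the accuracy is 0.75 and the test error is 0.25.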
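The `setMaxCategories(4)` call in the example decides which features `VectorIndexer` treats as categorical: a feature with at most that many distinct values is indexed as categorical, and the rest stay continuous. The plain-Scala sketch below imitates only that decision rule on a tiny made-up dataset. `MaxCategoriesSketch` and `categoricalFeatures` are hypothetical names, and the sketch does not reproduce how Spark assigns the category indices.

```scala
// Plain-Scala sketch of the "at most maxCategories distinct values" rule.
// Names and data are made up; this is not Spark's implementation.
object MaxCategoriesSketch {
  /** Indices of the features that have at most maxCategories distinct values. */
  def categoricalFeatures(rows: Seq[Array[Double]], maxCategories: Int): Set[Int] = {
    require(rows.nonEmpty, "need at least one row")
    val numFeatures = rows.head.length
    (0 until numFeatures).filter { j =>
      rows.map(row => row(j)).distinct.size <= maxCategories
    }.toSet
  }

  def main(args: Array[String]): Unit = {
    // Feature 0 takes 2 distinct values, feature 1 takes 5.
    val rows = Seq(
      Array(0.0, 0.1),
      Array(1.0, 0.2),
      Array(0.0, 0.3),
      Array(1.0, 0.4),
      Array(0.0, 0.5)
    )
    // With maxCategories = 4, only feature 0 qualifies as categorical.
    println(categoricalFeatures(rows, maxCategories = 4))
  }
}
```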