// scalastyle:off println
package org.apache.spark.examples.ml;

// $example on$
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
// $example off$

public class JavaDecisionTreeClassificationExample {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    SQLContext sqlContext = new SQLContext(jsc);

    // $example on$
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");

    // Index labels, adding metadata to the label column.
    // Fit on whole dataset to include all labels in index.
    StringIndexerModel labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(data);

    // Automatically identify categorical features, and index them.
    VectorIndexerModel featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4) // features with > 4 distinct values are treated as continuous
      .fit(data);

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a DecisionTree model.
    DecisionTreeClassifier dt = new DecisionTreeClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures");

    // Convert indexed labels back to original labels.
    IndexToString labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels());

    // Chain indexers and tree in a Pipeline
    Pipeline pipeline = new Pipeline()
      .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter});

    // Train model. This also runs the indexers.
    PipelineModel model = pipeline.fit(trainingData);

    // Make predictions.
    DataFrame predictions = model.transform(testData);

    // Select example rows to display.
    predictions.select("predictedLabel", "label", "features").show(5);

    // Select (prediction, true label) and compute test error
    MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("precision");
    double accuracy = evaluator.evaluate(predictions);
    System.out.println("Test Error = " + (1.0 - accuracy));

    DecisionTreeClassificationModel treeModel =
      (DecisionTreeClassificationModel) (model.stages()[2]);
    System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());
    // $example off$
  }
}
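The listing stores the result of the "precision" metric in a variable named accuracy and prints 1.0 - accuracy as the test error. As far as I understand it, the overall multiclass precision in this DataFrame-era API is simply the fraction of rows whose prediction matches the label. The plain-Java sketch below shows that arithmetic without any Spark dependency. The class name TestErrorSketch, the method accuracy, and the sample arrays are made up for illustration and are not part of the Spark example.

```java
public class TestErrorSketch {

    // Fraction of positions where the predicted class index equals the true one.
    // Hypothetical helper for illustration; Spark computes this inside its evaluator.
    public static double accuracy(double[] predictions, double[] labels) {
        if (predictions.length != labels.length || predictions.length == 0) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
        int correct = 0;
        for (int i = 0; i < predictions.length; i++) {
            // Class indices are whole numbers stored as doubles, so == is safe here.
            if (predictions[i] == labels[i]) {
                correct++;
            }
        }
        return (double) correct / predictions.length;
    }

    public static void main(String[] args) {
        // Made-up indexed predictions and labels: three of four rows match.
        double[] predictions = {0.0, 1.0, 1.0, 0.0};
        double[] labels = {0.0, 1.0, 0.0, 0.0};

        double acc = accuracy(predictions, labels);
        System.out.println("Test Error = " + (1.0 - acc));
    }
}
```

With the made-up arrays above, three of four predictions match, so the accuracy is 0.75 and the test error is 0.25.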