- 阅读权限
- 255
- 威望
- 1 级
- 论坛币
- 49422 个
- 通用积分
- 52.2304
- 学术水平
- 370 点
- 热心指数
- 273 点
- 信用等级
- 335 点
- 经验
- 57815 点
- 帖子
- 4006
- 精华
- 21
- 在线时间
- 582 小时
- 注册时间
- 2005-5-8
- 最后登录
- 2023-11-26
学术权威
还不是VIP/贵宾
TA的文库 其他... R资源总汇
Panel Data Analysis
Experimental Design
- 威望
- 1 级
- 论坛币
- 49422 个
- 通用积分
- 52.2304
- 学术水平
- 370 点
- 热心指数
- 273 点
- 信用等级
- 335 点
- 经验
- 57815 点
- 帖子
- 4006
- 精华
- 21
- 在线时间
- 582 小时
- 注册时间
- 2005-5-8
- 最后登录
- 2023-11-26
| 开心 2017-10-21 10:25:33 |
---|
签到天数: 1 天 连续签到: 1 天 [LV.1]初来乍到
|
经管之家送您一份
应届毕业生专属福利!
求职就业群
感谢您参与论坛问题回答
经管之家送您两个论坛币!
+2 论坛币
- // scalastyle:off println
- package org.apache.spark.examples.ml
- import scala.collection.mutable
- import scala.language.reflectiveCalls
- import scopt.OptionParser
- import org.apache.spark.{SparkConf, SparkContext}
- import org.apache.spark.examples.mllib.AbstractParams
- import org.apache.spark.ml.{Pipeline, PipelineStage}
- import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
- import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
- import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
- import org.apache.spark.sql.DataFrame
/**
 * An example runner for random forests. Run with
 * {{{
 * ./bin/run-example ml.RandomForestExample [options]
 * }}}
 * Decision Trees and ensembles can take a large amount of memory. If the run-example command
 * above fails, try running via spark-submit and specifying the amount of memory as at least 1g.
 * For local mode, run
 * {{{
 * ./bin/spark-submit --class org.apache.spark.examples.ml.RandomForestExample --driver-memory 1g
 *   [examples JAR path] [options]
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object RandomForestExample {

  /**
   * Command-line configuration for the example.
   *
   * @param input path to the labeled training examples (required positional argument)
   * @param testInput path to a separate test dataset; when given, `fracTest` is ignored
   * @param dataFormat input data format ("libsvm" by default)
   * @param algo learning task: "classification" or "regression" (compared case-insensitively)
   * @param maxDepth max depth of each tree
   * @param maxBins max number of bins
   * @param minInstancesPerNode min number of instances required at child nodes to create a split
   * @param minInfoGain min info gain required to create a split
   * @param numTrees number of trees in the ensemble
   * @param featureSubsetStrategy number of features to use per node
   * @param fracTest fraction of data to hold out for testing; must lie in [0, 1)
   * @param cacheNodeIds whether to use the node Id cache during training
   * @param checkpointDir checkpoint directory where intermediate node Id caches are stored
   * @param checkpointInterval how often to checkpoint the node Id cache
   */
  case class Params(
      input: String = null,
      testInput: String = "",
      dataFormat: String = "libsvm",
      algo: String = "classification",
      maxDepth: Int = 5,
      maxBins: Int = 32,
      minInstancesPerNode: Int = 1,
      minInfoGain: Double = 0.0,
      numTrees: Int = 10,
      featureSubsetStrategy: String = "auto",
      fracTest: Double = 0.2,
      cacheNodeIds: Boolean = false,
      checkpointDir: Option[String] = None,
      checkpointInterval: Int = 10) extends AbstractParams[Params]

  /**
   * Parses the command-line arguments into [[Params]] and runs the example.
   * Exits the JVM with status 1 when the arguments cannot be parsed or fail validation.
   *
   * @param args raw command-line arguments
   */
  def main(args: Array[String]): Unit = {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("RandomForestExample") {
      head("RandomForestExample: an example random forest app.")
      opt[String]("algo")
        .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
        .action((x, c) => c.copy(algo = x))
      opt[Int]("maxDepth")
        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
        .action((x, c) => c.copy(maxDepth = x))
      opt[Int]("maxBins")
        .text(s"max number of bins, default: ${defaultParams.maxBins}")
        .action((x, c) => c.copy(maxBins = x))
      opt[Int]("minInstancesPerNode")
        .text(s"min number of instances required at child nodes to create the parent split," +
          s" default: ${defaultParams.minInstancesPerNode}")
        .action((x, c) => c.copy(minInstancesPerNode = x))
      opt[Double]("minInfoGain")
        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
        .action((x, c) => c.copy(minInfoGain = x))
      opt[Int]("numTrees")
        .text(s"number of trees in ensemble, default: ${defaultParams.numTrees}")
        .action((x, c) => c.copy(numTrees = x))
      opt[String]("featureSubsetStrategy")
        .text(s"number of features to use per node (supported:" +
          s" ${RandomForestClassifier.supportedFeatureSubsetStrategies.mkString(",")})," +
          // Report this option's own default (the original printed numTrees here by mistake).
          s" default: ${defaultParams.featureSubsetStrategy}")
        .action((x, c) => c.copy(featureSubsetStrategy = x))
      opt[Double]("fracTest")
        .text(s"fraction of data to hold out for testing. If given option testInput, " +
          s"this option is ignored. default: ${defaultParams.fracTest}")
        .action((x, c) => c.copy(fracTest = x))
      opt[Boolean]("cacheNodeIds")
        .text(s"whether to use node Id cache during training, " +
          s"default: ${defaultParams.cacheNodeIds}")
        .action((x, c) => c.copy(cacheNodeIds = x))
      opt[String]("checkpointDir")
        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
          s"default: ${defaultParams.checkpointDir.getOrElse("None")}")
        .action((x, c) => c.copy(checkpointDir = Some(x)))
      opt[Int]("checkpointInterval")
        .text(s"how often to checkpoint the node Id cache, " +
          s"default: ${defaultParams.checkpointInterval}")
        .action((x, c) => c.copy(checkpointInterval = x))
      opt[String]("testInput")
        .text(s"input path to test dataset. If given, option fracTest is ignored." +
          s" default: ${defaultParams.testInput}")
        .action((x, c) => c.copy(testInput = x))
      opt[String]("dataFormat")
        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
        .action((x, c) => c.copy(dataFormat = x))
      arg[String]("<input>")
        .text("input path to labeled examples")
        .required()
        .action((x, c) => c.copy(input = x))
      checkConfig { params =>
        if (params.fracTest < 0 || params.fracTest >= 1) {
          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
        } else {
          success
        }
      }
    }

    parser.parse(args, defaultParams) match {
      case Some(params) => run(params)
      case None => sys.exit(1)
    }
  }

  /**
   * Trains a random forest inside an ML Pipeline and prints the model and evaluation results.
   *
   * Data loading and evaluation are delegated to `DecisionTreeExample`, which is defined
   * outside this file. The SparkContext created here is always stopped, even when a step fails.
   *
   * @param params parsed command-line configuration
   * @throws IllegalArgumentException if `params.algo` is neither "classification" nor
   *                                  "regression"
   */
  def run(params: Params): Unit = {
    val conf = new SparkConf().setAppName(s"RandomForestExample with $params")
    val sc = new SparkContext(conf)
    try {
      params.checkpointDir.foreach(sc.setCheckpointDir)
      val algo = params.algo.toLowerCase

      println(s"RandomForestExample with parameters:\n$params")

      // Load training and test data and cache it.
      val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc,
        params.input, params.dataFormat, params.testInput, algo, params.fracTest)

      // Set up Pipeline
      val stages = new mutable.ArrayBuffer[PipelineStage]()

      // (1) For classification, re-index classes.
      val labelColName = if (algo == "classification") "indexedLabel" else "label"
      if (algo == "classification") {
        val labelIndexer = new StringIndexer()
          .setInputCol("label")
          .setOutputCol(labelColName)
        stages += labelIndexer
      }

      // (2) Identify categorical features using VectorIndexer.
      //     Features with more than maxCategories values will be treated as continuous.
      val featuresIndexer = new VectorIndexer()
        .setInputCol("features")
        .setOutputCol("indexedFeatures")
        .setMaxCategories(10)
      stages += featuresIndexer

      // (3) Learn Random Forest
      val rf = algo match {
        case "classification" =>
          new RandomForestClassifier()
            .setFeaturesCol("indexedFeatures")
            .setLabelCol(labelColName)
            .setMaxDepth(params.maxDepth)
            .setMaxBins(params.maxBins)
            .setMinInstancesPerNode(params.minInstancesPerNode)
            .setMinInfoGain(params.minInfoGain)
            .setCacheNodeIds(params.cacheNodeIds)
            .setCheckpointInterval(params.checkpointInterval)
            .setFeatureSubsetStrategy(params.featureSubsetStrategy)
            .setNumTrees(params.numTrees)
        case "regression" =>
          new RandomForestRegressor()
            .setFeaturesCol("indexedFeatures")
            .setLabelCol(labelColName)
            .setMaxDepth(params.maxDepth)
            .setMaxBins(params.maxBins)
            .setMinInstancesPerNode(params.minInstancesPerNode)
            .setMinInfoGain(params.minInfoGain)
            .setCacheNodeIds(params.cacheNodeIds)
            .setCheckpointInterval(params.checkpointInterval)
            .setFeatureSubsetStrategy(params.featureSubsetStrategy)
            .setNumTrees(params.numTrees)
        // The `s` interpolator is required so the offending value appears in the message.
        case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
      }
      stages += rf
      val pipeline = new Pipeline().setStages(stages.toArray)

      // Fit the Pipeline
      val startTime = System.nanoTime()
      val pipelineModel = pipeline.fit(training)
      val elapsedTime = (System.nanoTime() - startTime) / 1e9
      println(s"Training time: $elapsedTime seconds")

      // Get the trained Random Forest from the fitted PipelineModel.
      // The forest was appended last, so it is the final stage of the fitted pipeline.
      algo match {
        case "classification" =>
          val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestClassificationModel]
          if (rfModel.totalNumNodes < 30) {
            println(rfModel.toDebugString) // Print full model.
          } else {
            println(rfModel) // Print model summary.
          }
        case "regression" =>
          val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestRegressionModel]
          if (rfModel.totalNumNodes < 30) {
            println(rfModel.toDebugString) // Print full model.
          } else {
            println(rfModel) // Print model summary.
          }
        case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
      }

      // Evaluate model on training, test data
      algo match {
        case "classification" =>
          println("Training data results:")
          DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
          println("Test data results:")
          DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)
        case "regression" =>
          println("Training data results:")
          DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
          println("Test data results:")
          DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)
        case _ =>
          throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
      }
    } finally {
      // Release cluster resources even if loading, training or evaluation failed.
      sc.stop()
    }
  }
}
- // scalastyle:on println
复制代码
扫码加我 拉你入群
请注明:姓名-公司-职位
以便审核进群资格,未注明则拒绝
|
|
|