OP: ReneeBK

[Case Study] Random Forest using Scala


// scalastyle:off println
package org.apache.spark.examples.ml

import scala.collection.mutable
import scala.language.reflectiveCalls

import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
import org.apache.spark.sql.DataFrame


/**
 * An example runner for random forests. Run with
 * {{{
 * ./bin/run-example ml.RandomForestExample [options]
 * }}}
 * Decision trees and ensembles can take a large amount of memory. If the run-example command
 * above fails, try running via spark-submit and specifying the amount of memory as at least 1g.
 * For local mode, run
 * {{{
 * ./bin/spark-submit --class org.apache.spark.examples.ml.RandomForestExample --driver-memory 1g
 *   [examples JAR path] [options]
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object RandomForestExample {

  case class Params(
      input: String = null,
      testInput: String = "",
      dataFormat: String = "libsvm",
      algo: String = "classification",
      maxDepth: Int = 5,
      maxBins: Int = 32,
      minInstancesPerNode: Int = 1,
      minInfoGain: Double = 0.0,
      numTrees: Int = 10,
      featureSubsetStrategy: String = "auto",
      fracTest: Double = 0.2,
      cacheNodeIds: Boolean = false,
      checkpointDir: Option[String] = None,
      checkpointInterval: Int = 10) extends AbstractParams[Params]

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("RandomForestExample") {
      head("RandomForestExample: an example random forest app.")
      opt[String]("algo")
        .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
        .action((x, c) => c.copy(algo = x))
      opt[Int]("maxDepth")
        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
        .action((x, c) => c.copy(maxDepth = x))
      opt[Int]("maxBins")
        .text(s"max number of bins, default: ${defaultParams.maxBins}")
        .action((x, c) => c.copy(maxBins = x))
      opt[Int]("minInstancesPerNode")
        .text(s"min number of instances required at child nodes to create the parent split," +
        s" default: ${defaultParams.minInstancesPerNode}")
        .action((x, c) => c.copy(minInstancesPerNode = x))
      opt[Double]("minInfoGain")
        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
        .action((x, c) => c.copy(minInfoGain = x))
      opt[Int]("numTrees")
        .text(s"number of trees in ensemble, default: ${defaultParams.numTrees}")
        .action((x, c) => c.copy(numTrees = x))
      opt[String]("featureSubsetStrategy")
        .text(s"number of features to use per node (supported:" +
        s" ${RandomForestClassifier.supportedFeatureSubsetStrategies.mkString(",")})," +
        s" default: ${defaultParams.featureSubsetStrategy}")
        .action((x, c) => c.copy(featureSubsetStrategy = x))
      opt[Double]("fracTest")
        .text(s"fraction of data to hold out for testing. If given option testInput, " +
        s"this option is ignored. default: ${defaultParams.fracTest}")
        .action((x, c) => c.copy(fracTest = x))
      opt[Boolean]("cacheNodeIds")
        .text(s"whether to use node Id cache during training, " +
        s"default: ${defaultParams.cacheNodeIds}")
        .action((x, c) => c.copy(cacheNodeIds = x))
      opt[String]("checkpointDir")
        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
        s"default: ${
          defaultParams.checkpointDir match {
            case Some(strVal) => strVal
            case None => "None"
          }
        }")
        .action((x, c) => c.copy(checkpointDir = Some(x)))
      opt[Int]("checkpointInterval")
        .text(s"how often to checkpoint the node Id cache, " +
        s"default: ${defaultParams.checkpointInterval}")
        .action((x, c) => c.copy(checkpointInterval = x))
      opt[String]("testInput")
        .text(s"input path to test dataset. If given, option fracTest is ignored." +
        s" default: ${defaultParams.testInput}")
        .action((x, c) => c.copy(testInput = x))
      opt[String]("dataFormat")
        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
        .action((x, c) => c.copy(dataFormat = x))
      arg[String]("<input>")
        .text("input path to labeled examples")
        .required()
        .action((x, c) => c.copy(input = x))
      checkConfig { params =>
        if (params.fracTest < 0 || params.fracTest >= 1) {
          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
        } else {
          success
        }
      }
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    }.getOrElse {
      sys.exit(1)
    }
  }

  def run(params: Params) {
    val conf = new SparkConf().setAppName(s"RandomForestExample with $params")
    val sc = new SparkContext(conf)
    params.checkpointDir.foreach(sc.setCheckpointDir)
    val algo = params.algo.toLowerCase

    println(s"RandomForestExample with parameters:\n$params")

    // Load training and test data and cache it.
    val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
      params.dataFormat, params.testInput, algo, params.fracTest)

    // Set up Pipeline.
    val stages = new mutable.ArrayBuffer[PipelineStage]()
    // (1) For classification, re-index classes.
    val labelColName = if (algo == "classification") "indexedLabel" else "label"
    if (algo == "classification") {
      val labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol(labelColName)
      stages += labelIndexer
    }
    // (2) Identify categorical features using VectorIndexer.
    //     Features with more than maxCategories values will be treated as continuous.
    val featuresIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(10)
    stages += featuresIndexer
    // (3) Learn Random Forest.
    val dt = algo match {
      case "classification" =>
        new RandomForestClassifier()
          .setFeaturesCol("indexedFeatures")
          .setLabelCol(labelColName)
          .setMaxDepth(params.maxDepth)
          .setMaxBins(params.maxBins)
          .setMinInstancesPerNode(params.minInstancesPerNode)
          .setMinInfoGain(params.minInfoGain)
          .setCacheNodeIds(params.cacheNodeIds)
          .setCheckpointInterval(params.checkpointInterval)
          .setFeatureSubsetStrategy(params.featureSubsetStrategy)
          .setNumTrees(params.numTrees)
      case "regression" =>
        new RandomForestRegressor()
          .setFeaturesCol("indexedFeatures")
          .setLabelCol(labelColName)
          .setMaxDepth(params.maxDepth)
          .setMaxBins(params.maxBins)
          .setMinInstancesPerNode(params.minInstancesPerNode)
          .setMinInfoGain(params.minInfoGain)
          .setCacheNodeIds(params.cacheNodeIds)
          .setCheckpointInterval(params.checkpointInterval)
          .setFeatureSubsetStrategy(params.featureSubsetStrategy)
          .setNumTrees(params.numTrees)
      case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
    }
    stages += dt
    val pipeline = new Pipeline().setStages(stages.toArray)

    // Fit the Pipeline.
    val startTime = System.nanoTime()
    val pipelineModel = pipeline.fit(training)
    val elapsedTime = (System.nanoTime() - startTime) / 1e9
    println(s"Training time: $elapsedTime seconds")

    // Get the trained Random Forest from the fitted PipelineModel.
    algo match {
      case "classification" =>
        val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestClassificationModel]
        if (rfModel.totalNumNodes < 30) {
          println(rfModel.toDebugString) // Print full model.
        } else {
          println(rfModel) // Print model summary.
        }
      case "regression" =>
        val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestRegressionModel]
        if (rfModel.totalNumNodes < 30) {
          println(rfModel.toDebugString) // Print full model.
        } else {
          println(rfModel) // Print model summary.
        }
      case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
    }

    // Evaluate model on training, test data.
    algo match {
      case "classification" =>
        println("Training data results:")
        DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
        println("Test data results:")
        DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)
      case "regression" =>
        println("Training data results:")
        DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
        println("Test data results:")
        DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)
      case _ =>
        throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
    }

    sc.stop()
  }
}
// scalastyle:on println
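
For reference, a concrete classification run might look like the following, in the same spark-submit style as the scaladoc above. The `[examples JAR path]` placeholder is kept as-is, and `data/mllib/sample_libsvm_data.txt` is Spark's bundled sample dataset (an assumption about your install, not a path from the original post):

./bin/spark-submit --class org.apache.spark.examples.ml.RandomForestExample \
  --driver-memory 1g \
  [examples JAR path] \
  --algo classification --numTrees 20 --maxDepth 5 --fracTest 0.2 \
  data/mllib/sample_libsvm_data.txt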
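For interactive experimentation, the same three-stage pipeline can be tried in spark-shell. This is a minimal sketch, assuming Spark 1.6+ (for the `libsvm` DataFrame reader), an available `sqlContext`, and the bundled sample data path; none of it is part of the original listing:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}

// Load Spark's bundled sample data (path is an assumption; adjust to your install).
val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = 42L)

// Same three pipeline stages as the full example, other parameters left at defaults.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
val featuresIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(10)
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setNumTrees(10)

// Fit the pipeline on the training split.
val model = new Pipeline()
  .setStages(Array(labelIndexer, featuresIndexer, rf))
  .fit(train)

// Inspect a few predictions against the indexed labels.
model.transform(test).select("prediction", "indexedLabel").show(5)

The full example above adds the regression branch, command-line parsing, and evaluation via the DecisionTreeExample helpers on top of this core.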

