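Below is Spark's SimpleTextClassificationPipeline example (Spark 1.x, from the org.apache.spark.examples.ml package). It assembles a three-stage ML pipeline, Tokenizer, HashingTF, and LogisticRegression, that learns to recognize the word "spark" in short documents: the tokenizer splits each document into words, HashingTF hashes the words into a 1000-dimensional term-frequency feature vector, and logistic regression fits labels to those features.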
// scalastyle:off println
package org.apache.spark.examples.ml

import scala.beans.BeanInfo

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}

@BeanInfo
case class LabeledDocument(id: Long, text: String, label: Double)

@BeanInfo
case class Document(id: Long, text: String)

/**
 * A simple text classification pipeline that recognizes "spark" from input text. This is to show
 * how to create and configure an ML pipeline. Run with
 * {{{
 * bin/run-example ml.SimpleTextClassificationPipeline
 * }}}
 */
object SimpleTextClassificationPipeline {

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Prepare training documents, which are labeled.
    val training = sc.parallelize(Seq(
      LabeledDocument(0L, "a b c d e spark", 1.0),
      LabeledDocument(1L, "b d", 0.0),
      LabeledDocument(2L, "spark f g h", 1.0),
      LabeledDocument(3L, "hadoop mapreduce", 0.0)))

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.001)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))

    // Fit the pipeline to training documents.
    val model = pipeline.fit(training.toDF())

    // Prepare test documents, which are unlabeled.
    val test = sc.parallelize(Seq(
      Document(4L, "spark i j k"),
      Document(5L, "l m n"),
      Document(6L, "spark hadoop spark"),
      Document(7L, "apache hadoop")))

    // Make predictions on test documents.
    model.transform(test.toDF())
      .select("id", "text", "probability", "prediction")
      .collect()
      .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
        println(s"($id, $text) --> prob=$prob, prediction=$prediction")
      }

    sc.stop()
  }
}
// scalastyle:on println
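On Spark 2.0 and later, the same pipeline is written against SparkSession rather than SQLContext: the @BeanInfo case classes can be replaced by plain tuples plus toDF column names, and the probability column's vector type moves from org.apache.spark.mllib.linalg.Vector to org.apache.spark.ml.linalg.Vector. Here is a minimal sketch of that updated version, assuming Spark 2.0+ (the object name SimpleTextClassificationPipeline2 is just for illustration):

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.{Row, SparkSession}

object SimpleTextClassificationPipeline2 {

  def main(args: Array[String]): Unit = {
    // SparkSession replaces SparkContext + SQLContext in Spark 2.x.
    val spark = SparkSession.builder
      .appName("SimpleTextClassificationPipeline")
      .getOrCreate()
    import spark.implicits._

    // Labeled training documents as (id, text, label) tuples.
    val training = Seq(
      (0L, "a b c d e spark", 1.0),
      (1L, "b d", 0.0),
      (2L, "spark f g h", 1.0),
      (3L, "hadoop mapreduce", 0.0)
    ).toDF("id", "text", "label")

    // Same three stages and parameters as the original example.
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.001)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))

    // Fit on the labeled training DataFrame.
    val model = pipeline.fit(training)

    // Unlabeled test documents.
    val test = Seq(
      (4L, "spark i j k"),
      (5L, "l m n"),
      (6L, "spark hadoop spark"),
      (7L, "apache hadoop")
    ).toDF("id", "text")

    // Predict; note the ml.linalg.Vector type in the pattern match.
    model.transform(test)
      .select("id", "text", "probability", "prediction")
      .collect()
      .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
        println(s"($id, $text) --> prob=$prob, prediction=$prediction")
      }

    spark.stop()
  }
}

The pipeline stages and parameters are unchanged; only the entry point and the vector type differ. Unlike the bundled example above, a standalone object like this would be packaged and launched with spark-submit rather than bin/run-example.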