Original poster: Lisrelchen

[Apache Spark] Spark MLlib - Predict Store Sales with ML Pipelines

#1
Lisrelchen posted on 2017-4-30 09:38:35

Spark MLLib - Predict Store Sales with ML Pipelines.pdf (474.22 KB)

Overview

Recently I had to work on a machine learning problem for class and found it a good opportunity for a Spark tutorial. Using the Rossmann store sales data from Kaggle, we are going to set up a machine learning pipeline that covers everything from preprocessing all the way to making and saving predictions. In a sense, this is a "full-stack" machine learning project, probably fairly similar to something we might do in the real world. Spark's ML Pipelines API makes this very easy. You can follow along with the code on my GitHub or below.

#2
Lisrelchen posted on 2017-4-30 09:40:04
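Linear Regression

The pipeline we've set up is pretty self-explanatory. It runs a series of feature-preparation stages followed by the linear regression, and the whole thing is wrapped in a TrainValidationSplit that searches a parameter grid.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

def preppedLRPipeline(): TrainValidationSplit = {
  val lr = new LinearRegression()

  val paramGrid = new ParamGridBuilder()
    .addGrid(lr.regParam, Array(0.1, 0.01))
    .addGrid(lr.fitIntercept) // tries both true and false
    .addGrid(lr.elasticNetParam, Array(0.0, 0.25, 0.5, 0.75, 1.0))
    .build()

  // The indexer, encoder and assembler stages are defined elsewhere in the
  // project (not shown in this thread).
  val pipeline = new Pipeline()
    .setStages(Array(stateHolidayIndexer, schoolHolidayIndexer,
      stateHolidayEncoder, schoolHolidayEncoder, storeEncoder,
      dayOfWeekEncoder, dayOfMonthEncoder,
      assembler, lr))

  // Pick the grid point with the best RegressionEvaluator score on a 75/25 split
  new TrainValidationSplit()
    .setEstimator(pipeline)
    .setEvaluator(new RegressionEvaluator)
    .setEstimatorParamMaps(paramGrid)
    .setTrainRatio(0.75)
}
```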
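This step does a fair amount of work. ParamGridBuilder expands the grids into their Cartesian product. TrainValidationSplit then does three things:

- It fits the whole pipeline once per combination on 75% of its input.
- It scores each fit on the other 25% with RegressionEvaluator, which uses RMSE unless another metric is set.
- It refits the best combination on the full input.

For reference, an elasticNetParam of 0.0 is a pure L2 (ridge) penalty and 1.0 is a pure L1 (lasso) penalty. The plain-Scala sketch below enumerates the same grid and needs no Spark. Its Map-per-candidate representation is only an illustration and is not Spark's ParamMap type.

```scala
// Illustration only: Spark builds an Array[ParamMap]; here each candidate is a plain Map.
val regParams        = Seq(0.1, 0.01)
val fitIntercepts    = Seq(true, false) // addGrid(lr.fitIntercept) tries both values
val elasticNetParams = Seq(0.0, 0.25, 0.5, 0.75, 1.0)

// Cartesian product of the three grids: one entry per candidate model
val grid: Seq[Map[String, Any]] = for {
  reg <- regParams
  fit <- fitIntercepts
  en  <- elasticNetParams
} yield Map("regParam" -> reg, "fitIntercept" -> fit, "elasticNetParam" -> en)

println(s"${grid.size} candidate models") // 2 * 2 * 5 = 20
```

So a single call to tvs.fit trains the pipeline 21 times: once for each of the 20 candidates, plus the final refit of the winner.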

#3
Lisrelchen posted on 2017-4-30 09:40:37

Data

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.hive.HiveContext

// Load train.csv with the spark-csv package and keep the columns the model uses
def loadTrainingData(sqlContext: HiveContext): DataFrame = {
  val trainRaw = sqlContext
    .read.format("com.databricks.spark.csv")
    .option("header", "true")
    .load("../mlproject/rossman/train.csv")
    .repartition(6)
  trainRaw.registerTempTable("raw_training_data")

  // Cast the string columns and pull the day of month out of the Date string
  sqlContext.sql("""SELECT
    double(Sales) label, double(Store) Store, int(Open) Open, double(DayOfWeek) DayOfWeek,
    StateHoliday, SchoolHoliday, (double(regexp_extract(Date, '\\d+-\\d+-(\\d+)', 1))) DayOfMonth
    FROM raw_training_data
  """).na.drop()
}

def loadKaggleTestData(sqlContext: HiveContext): Array[DataFrame] = {
  val testRaw = sqlContext
    .read.format("com.databricks.spark.csv")
    .option("header", "true")
    .load("../mlproject/rossman/test.csv")
    .repartition(6)
  testRaw.registerTempTable("raw_test_data")

  // The test set has no Sales column; keep Id so predictions can be matched back
  val testData = sqlContext.sql("""SELECT
    Id, double(Store) Store, int(Open) Open, double(DayOfWeek) DayOfWeek, StateHoliday,
    SchoolHoliday, (double(regexp_extract(Date, '\\d+-\\d+-(\\d+)', 1))) DayOfMonth
    FROM raw_test_data
    WHERE !(ISNULL(Id) OR ISNULL(Store) OR ISNULL(Open) OR ISNULL(DayOfWeek)
      OR ISNULL(StateHoliday) OR ISNULL(SchoolHoliday))
  """).na.drop() // odd things happen if you don't filter out the null values manually

  // Return testRaw too, so that every prediction Id can be submitted to Kaggle
  Array(testRaw, testData)
}
```
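The DayOfMonth column comes from regexp_extract with the pattern \d+-\d+-(\d+), which captures the last run of digits in a date such as 2015-07-31. The backslashes are doubled inside the SQL text because Spark SQL processes backslash escapes in string literals. When the pattern does not match, regexp_extract returns an empty string. The double cast then turns that into null, and na.drop() removes the row.

Here is a plain-Scala sketch of the same extraction, with Option standing in for SQL null. The helper name dayOfMonth is hypothetical.

```scala
// Same regex as the SQL query; Scala triple-quoted strings need no extra escaping
val DatePattern = """\d+-\d+-(\d+)""".r

// Hypothetical helper: like regexp_extract(Date, ..., 1) followed by double(...),
// with None playing the role of SQL null
def dayOfMonth(date: String): Option[Double] =
  DatePattern.findFirstMatchIn(date).map(_.group(1).toDouble)

val dates  = Seq("2015-07-31", "2015-08-01", "not-a-date")
val parsed = dates.flatMap(dayOfMonth) // the unparseable value is dropped, like na.drop()
```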

#4
Lisrelchen posted on 2017-4-30 09:41:07

Save

```scala
import org.apache.spark.sql.DataFrame

def savePredictions(predictions: DataFrame, testRaw: DataFrame): Unit = {
  val ids = testRaw.select("Id").distinct()
  // Outer join so that test rows dropped earlier for null values still appear
  val tdOut = ids
    .join(predictions, ids("Id") === predictions("PredId"), "outer")
    .select("Id", "Sales")
    .na.fill(0.0) // some of our inputs were null, so fill their Sales with 0

  // coalesce(1) writes a single part file inside the
  // linear_regression_predictions.csv output directory
  tdOut
    .coalesce(1)
    .write.format("com.databricks.spark.csv")
    .option("header", "true")
    .save("linear_regression_predictions.csv")
}
```
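Some test rows are filtered out before prediction, so savePredictions joins the predictions back onto every distinct test Id and fills the gaps with 0. The sketch below does the same bookkeeping in plain Scala. A Map stands in for the predictions DataFrame, and submissionRows is a hypothetical name that is not part of the project.

```scala
// Every distinct test Id gets a row; Ids without a prediction default to 0.0,
// mirroring the outer join followed by na.fill(0.0). Sorting only makes the
// output deterministic; the Spark join does not order its rows.
def submissionRows(testIds: Seq[Int], predictions: Map[Int, Double]): Seq[(Int, Double)] =
  testIds.distinct.sorted.map(id => id -> predictions.getOrElse(id, 0.0))

val rows = submissionRows(Seq(1, 2, 2, 3), Map(1 -> 5000.0, 3 -> 7200.5))
rows.foreach { case (id, sales) => println(s"$id,$sales") }
```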

#5
Lisrelchen posted on 2017-4-30 09:42:13

Fit

```scala
import org.apache.spark.ml.tuning.{TrainValidationSplit, TrainValidationSplitModel}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.DataFrame

// Assumes a `logger` is in scope (defined elsewhere in the project)
def fitModel(tvs: TrainValidationSplit, data: DataFrame): TrainValidationSplitModel = {
  // Hold out 20% of the labeled data for a final test that tuning never sees
  val Array(training, test) = data.randomSplit(Array(0.8, 0.2), seed = 12345)
  logger.info("Fitting data")
  val model = tvs.fit(training)
  logger.info("Now performing test on hold out set")
  val holdout = model.transform(test).select("prediction", "label")

  // RegressionMetrics needs an RDD of (prediction, label) pairs, so convert the Rows
  val rm = new RegressionMetrics(holdout.rdd.map(row => (row.getDouble(0), row.getDouble(1))))

  logger.info("Test Metrics")
  logger.info(s"Test Explained Variance: ${rm.explainedVariance}")
  logger.info(s"Test R^2 Coef: ${rm.r2}")
  logger.info(s"Test MSE: ${rm.meanSquaredError}")
  logger.info(s"Test RMSE: ${rm.rootMeanSquaredError}")

  model
}
```
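RegressionMetrics takes (prediction, observation) pairs, which is why holdout selects prediction before label. The plain-Scala sketch below shows how three of the logged numbers (MSE, RMSE and R^2) are defined. It runs on a tiny made-up holdout set whose values were chosen only for the example.

```scala
// Each pair is (prediction, label), the same order RegressionMetrics expects
def mse(pairs: Seq[(Double, Double)]): Double =
  pairs.map { case (p, y) => math.pow(p - y, 2) }.sum / pairs.size

def rmse(pairs: Seq[(Double, Double)]): Double = math.sqrt(mse(pairs))

// R^2 = 1 - (residual sum of squares) / (total sum of squares around the label mean)
def r2(pairs: Seq[(Double, Double)]): Double = {
  val labels = pairs.map(_._2)
  val mean   = labels.sum / labels.size
  val ssErr  = pairs.map { case (p, y) => math.pow(y - p, 2) }.sum
  val ssTot  = labels.map(y => math.pow(y - mean, 2)).sum
  1.0 - ssErr / ssTot
}

val holdout = Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0))
println(f"MSE=${mse(holdout)}%.4f RMSE=${rmse(holdout)}%.4f R2=${r2(holdout)}%.4f")
```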

#6
Lisrelchen posted on 2017-4-30 09:43:05

The Linear Regression Pipeline

```scala
// sqlContext (a HiveContext) and logger are assumed to be set up by the enclosing application
val data = loadTrainingData(sqlContext)
val Array(testRaw, testData) = loadKaggleTestData(sqlContext)

// The linear regression pipeline
val linearTvs = preppedLRPipeline()
logger.info("evaluating linear regression")
val lrModel = fitModel(linearTvs, data)
logger.info("Generating kaggle predictions")
// Rename columns to what savePredictions expects: PredId to join on, Sales to write
val lrOut = lrModel.transform(testData)
  .withColumnRenamed("prediction", "Sales")
  .withColumnRenamed("Id", "PredId")
  .select("PredId", "Sales")
savePredictions(lrOut, testRaw)
```
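Taken together, fitModel and the TrainValidationSplit carve the labeled data into three parts. First, randomSplit(Array(0.8, 0.2)) holds out 20% for the final test. Then setTrainRatio(0.75) splits the remaining 80% again during tuning. The arithmetic below is in whole percentages of the training rows. These shares are approximate, because randomSplit samples rows at random rather than cutting exact fractions.

```scala
// Percent of the labeled rows in each role (approximate shares, since the splits are random)
val holdoutPct  = 20                      // randomSplit(Array(0.8, 0.2)) in fitModel
val tuningPct   = 100 - holdoutPct        // what tvs.fit(training) sees
val tvsTrainPct = tuningPct * 75 / 100    // setTrainRatio(0.75): fits each grid candidate
val tvsValidPct = tuningPct - tvsTrainPct // scores the candidates with RegressionEvaluator

println(s"train $tvsTrainPct% / validate $tvsValidPct% / holdout $holdoutPct%")
```

The winning parameters are then refit on the whole 80% before the holdout test. As a result, the logged test metrics come from data that played no part in tuning.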

#7
franky_sas posted on 2017-4-30 10:37:18
Thanks a lot for sharing!

#8
colongkong posted on 2017-4-30 11:13:34
Thank you very much for sharing!

#9
MouJack007 posted on 2017-4-30 11:40:22
Thanks to the OP for sharing!

#10
MouJack007 posted on 2017-4-30 11:41:28
