Original poster: Lisrelchen

[Case Study] Bayesian Robust Regression in PyMC3

1
Lisrelchen posted on 2016-12-18 07:03:51

Author: Thomas Wiecki


In this blog post I will cover:

  • How a few outliers can strongly distort the fit of a linear regression model.
  • How replacing the normal likelihood with a Student T distribution makes the regression robust.
  • How to do this in PyMC3 with its new glm module by passing a family object.
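The first point can be checked without any sampling. The sketch below is not part of the original post: it rebuilds the thread's synthetic data (true intercept 1, true slope 2, three outliers at small x) and compares ordinary least-squares fits with and without the outliers. The fixed seed and the helper name `ols_fit` are assumptions made for illustration.

```python
import numpy as np

# Fixed seed is an assumption, added only for reproducibility
rng = np.random.RandomState(0)

# Same setup as the thread's data: y = 1 + 2*x plus Gaussian noise
size = 100
x = np.linspace(0, 1, size)
y = 1 + 2 * x + rng.normal(scale=0.5, size=size)

# Three outliers at small x, copied from the thread's data
x_out = np.append(x, [0.1, 0.15, 0.2])
y_out = np.append(y, [8, 6, 9])

def ols_fit(x, y):
    """Hypothetical helper: return (intercept, slope) of the least-squares line."""
    slope, intercept = np.polyfit(x, y, deg=1)  # highest degree first
    return intercept, slope

clean_intercept, clean_slope = ols_fit(x, y)
out_intercept, out_slope = ols_fit(x_out, y_out)

print("clean data:    intercept=%.2f slope=%.2f" % (clean_intercept, clean_slope))
print("with outliers: intercept=%.2f slope=%.2f" % (out_intercept, out_slope))
```

Because the outliers sit at small x with large y, they pull the intercept up and the slope down, even though they make up only 3 of 103 points.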

Attachment: Bayesian Robust Regression in PyMC3.pdf (439.23 KB)






2
Lisrelchen posted on 2016-12-18 07:04:35
%matplotlib inline

import pymc3 as pm

import matplotlib.pyplot as plt
import numpy as np

import theano

3
Lisrelchen posted on 2016-12-18 07:04:54
size = 100
true_intercept = 1
true_slope = 2

x = np.linspace(0, 1, size)
# y = a + b*x
true_regression_line = true_intercept + true_slope * x
# add noise
y = true_regression_line + np.random.normal(scale=.5, size=size)

# Add outliers
x_out = np.append(x, [.1, .15, .2])
y_out = np.append(y, [8, 6, 9])

data = dict(x=x_out, y=y_out)

4
Lisrelchen posted on 2016-12-18 07:05:28
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111, xlabel='x', ylabel='y', title='Generated data and underlying model')
ax.plot(x_out, y_out, 'x', label='sampled data')
ax.plot(x, true_regression_line, label='true regression line', lw=2.)
plt.legend(loc=0);

5
Lisrelchen posted on 2016-12-18 07:06:03
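with pm.Model() as model:
    # Linear model y ~ x; the default family uses a normal likelihood
    pm.glm.glm('y ~ x', data)
    # Find the MAP estimate and use it to scale the NUTS step method
    start = pm.find_MAP()
    step = pm.NUTS(scaling=start)
    trace = pm.sample(2000, step, progressbar=False)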

6
Lisrelchen posted on 2016-12-18 07:06:30
plt.subplot(111, xlabel='x', ylabel='y',
            title='Posterior predictive regression lines')
plt.plot(x_out, y_out, 'x', label='data')
pm.glm.plot_posterior_predictive(trace, samples=100,
                                 label='posterior predictive regression lines')
plt.plot(x, true_regression_line,
         label='true regression line', lw=3., c='y')

plt.legend(loc=0);

7
Lisrelchen posted on 2016-12-18 07:06:57
Let's look at those two distributions, the normal and the Student T, to get a feel for them.

normal_dist = pm.Normal.dist(mu=0, sd=1)
t_dist = pm.StudentT.dist(mu=0, lam=1, nu=1)
x_eval = np.linspace(-8, 8, 300)
# exp(logp) gives the probability density at each evaluation point
plt.plot(x_eval, theano.tensor.exp(normal_dist.logp(x_eval)).eval(), label='Normal', lw=2.)
plt.plot(x_eval, theano.tensor.exp(t_dist.logp(x_eval)).eval(), label='Student T', lw=2.)
plt.xlabel('x')
plt.ylabel('Probability density')
plt.legend();
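The difference between the two densities is easiest to see in the tails. As a rough sketch that needs neither PyMC3 nor Theano, the snippet below evaluates both densities in closed form with the same parameters as above (mu=0, sd=1 for the normal; mu=0, lam=1, nu=1 for the Student T, where lam is a precision). The function names are hypothetical and not part of any library.

```python
import math

def normal_pdf(x, mu=0.0, sd=1.0):
    """Density of a normal distribution with mean mu and standard deviation sd."""
    z = (x - mu) / sd
    return math.exp(-0.5 * z * z) / (sd * math.sqrt(2.0 * math.pi))

def student_t_pdf(x, nu=1.0, mu=0.0, lam=1.0):
    """Density of a Student T with location mu, precision lam and nu degrees of freedom."""
    coef = (math.gamma((nu + 1.0) / 2.0) / math.gamma(nu / 2.0)
            * math.sqrt(lam / (math.pi * nu)))
    return coef * (1.0 + lam * (x - mu) ** 2 / nu) ** (-(nu + 1.0) / 2.0)

# Compare the densities at the centre and increasingly far out in the tail
for x in (0.0, 2.0, 4.0, 6.0):
    n, t = normal_pdf(x), student_t_pdf(x)
    print("x=%.0f  normal=%.3e  student_t=%.3e  ratio=%.1e" % (x, n, t, t / n))
```

At x = 6 the Student T density is on the order of a million times larger than the normal density, so an observation that far from the regression line is far less surprising under a Student T likelihood and pulls the fit much less.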

8
Lisrelchen posted on 2016-12-18 07:07:33
To use a T distribution in PyMC3, we can pass a family object, T, which specifies that our data is Student T-distributed (see glm.families for more choices). Note that this is the same syntax that R and statsmodels use.

with pm.Model() as model_robust:
    # Same linear model, but with a Student T likelihood instead of a normal one
    family = pm.glm.families.T()
    pm.glm.glm('y ~ x', data, family=family)
    start = pm.find_MAP()
    step = pm.NUTS(scaling=start)
    trace_robust = pm.sample(2000, step, progressbar=False)

plt.figure(figsize=(5, 5))
plt.plot(x_out, y_out, 'x')
pm.glm.plot_posterior_predictive(trace_robust,
                                 label='posterior predictive regression lines')
plt.plot(x, true_regression_line,
         label='true regression line', lw=3., c='y')
plt.legend()
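To see why a Student T likelihood is robust, it can help to fit the same kind of model without a sampler. The sketch below is not what pm.glm does: PyMC3 samples the full posterior with NUTS, whereas this computes only a maximum-likelihood point estimate with an EM-style iteratively reweighted least squares. The function name `t_regression_em`, the fixed degrees of freedom nu=4, the iteration count and the seed are all assumptions chosen for illustration.

```python
import numpy as np

def t_regression_em(x, y, nu=4.0, n_iter=50):
    """Hypothetical helper: fit y = a + b*x under Student T noise with fixed nu.

    Returns (intercept, slope, scale) as a maximum-likelihood point estimate.
    """
    X = np.column_stack([np.ones_like(x), x])
    beta = np.linalg.lstsq(X, y, rcond=None)[0]  # start from the OLS fit
    sigma2 = np.mean((y - X @ beta) ** 2)
    for _ in range(n_iter):
        resid = y - X @ beta
        # E-step: points with large residuals receive small weights
        w = (nu + 1.0) / (nu + resid ** 2 / sigma2)
        # M-step: weighted least squares, then update the scale
        Xw = X * w[:, None]
        beta = np.linalg.solve(X.T @ Xw, Xw.T @ y)
        sigma2 = np.mean(w * (y - X @ beta) ** 2)
    return beta[0], beta[1], np.sqrt(sigma2)

# Rebuild the thread's data (fixed seed is an assumption for reproducibility)
rng = np.random.RandomState(0)
x = np.linspace(0, 1, 100)
y = 1 + 2 * x + rng.normal(scale=0.5, size=100)
x_out = np.append(x, [0.1, 0.15, 0.2])
y_out = np.append(y, [8, 6, 9])

ols_slope, ols_intercept = np.polyfit(x_out, y_out, deg=1)
t_intercept, t_slope, t_scale = t_regression_em(x_out, y_out)

print("OLS with outliers:       intercept=%.2f slope=%.2f" % (ols_intercept, ols_slope))
print("Student T with outliers: intercept=%.2f slope=%.2f" % (t_intercept, t_slope))
```

The weights are the key design choice: a point whose squared residual is large relative to the scale gets a weight close to zero, so the three outliers contribute very little to the weighted fit, while a normal likelihood corresponds to giving every point the same weight.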

9
franky_sas posted on 2016-12-18 11:30:30
