239 0

[编程问题求助] 复现10-fold cross validation of KNN 失败,请教下是什么原因 [推广有奖]

  • 0关注
  • 0粉丝

大专生

78%

还不是VIP/贵宾

-

威望
0
论坛币
2421 个
通用积分
0.0000
学术水平
0 点
热心指数
0 点
信用等级
0 点
经验
189 点
帖子
6
精华
0
在线时间
118 小时
注册时间
2016-9-20
最后登录
2024-4-8

相似文件 换一批

+2 论坛币
k人 参与回答

经管之家送您一份

应届毕业生专属福利!

求职就业群
赵安豆老师微信:zhaoandou666

经管之家联合CDA

送您一个全额奖学金名额~ !

感谢您参与论坛问题回答

经管之家送您两个论坛币!

+2 论坛币
求助一下,作业让手撕一下10-fold cross validation of KNN的实现,编程思路非常直接,在随机分组后按照:
选定特定样本组号-剔除组号内样本-按照x排序取临近k位x对应y的均值做预测值-计算CVErrk然后选定optimal k的思路做的,命令如下。

问题是算出来的结果非常非常的奇怪,随着k的波动(横轴)CVErr完全不是预期的情况,应该是代码有错?实在是看不出来问题,想请大神诊断一下!十分感谢
  1. di "Question2:K-nearest-neighbors regression with k=100"

  2.         gen predict_y_k100=.

  3.         sort x
  4.         drop x_sort
  5.         gen x_sort=_n
  6.         forvalue i=1/100 {
  7.                 preserve
  8.                 keep if id==`i'
  9.                 local t=x_sort
  10.                 restore
  11.                 preserve
  12.                 keep if x_sort<=`t'+100&x_sort>=`t'-100 //k=100情况下,仅保留相近距离为1的三项
  13.                 egen m_y= mean(y)                                                //计算此时y的样本均值
  14.                 local tt=m_y                                                        //提取预测结果       
  15.                 restore
  16.                 replace predict_y=`tt' if id==`i'                //将所得的预测值写入对应预测结果栏
  17.                 }

  18.         twoway(scatter y x)(line predict_y x),legend(order(1 "观测值y" 2 "拟合值y,k=100"))
  19.         save question2.gph,replace

  20.         di"Question3:10-fold cross validation"
  21.                
  22.                 //randomly dividing 10 fold by generating random number
  23.                 gen ttt=runiform(0,1)*10
  24.                 sort ttt
  25.                 gen n=_n
  26.                 gen rc=.                                                                        //分组变量rc
  27.                 forvalue i=0(1)9 {
  28.                         replace rc=`i'+1 if n>=`i'*10&n<=(`i'+1)*10&n!=(`i'+1)*10
  29.                         }                                                                                //基于均匀分布随机数的随机分组
  30.                 drop ttt n
  31.                
  32.                 //estiamte the 10-fold cross knn
  33.                 gen pred_y_cross=.
  34.                 gen k=.
  35.                 forvalue k=1/100 {
  36.                         forvalue i=1/100{
  37.                                 preserve
  38.                                         keep if id==`i'
  39.                                         local t=rc
  40.                                 restore                                                                //提取样本点所在分组
  41.                                 preserve
  42.                                         drop if rc==`t'                                        //剔除样本点所在分组数据后,提取X序列位置
  43.                                         sort x
  44.                                         gen ord_x=_n
  45.                                         keep if id==`i'
  46.                                         local tt=ord_x
  47.                                 restore
  48.                                 preserve
  49.                                         drop if rc==`t'                                        //剔除样本点所在分组数据后,提取X序列位置
  50.                                         sort x
  51.                                         gen ord_x=_n
  52.                                         keep if ord_x<=`tt'+`k'&ord_x>=`tt'-`k'
  53.                                         egen m_y=mean(y)
  54.                                         local ttt_`i'=m_y                                //记录下给定k距离下序号i样本的预测值
  55.                                 restore
  56.                                 replace pred_y_cross=`ttt_`i'' if id==`i'
  57.                                 }
  58.                         forvalue i=1/10 {
  59.                                 preserve
  60.                                         keep if rc==`i'
  61.                                         gen d2_error_y=(y-pred_y_cross)^2
  62.                                         egen cve_y=mean(d2_error_y)
  63.                                         local cve_y`i'=cve_y
  64.                                 restore
  65.                                 }       
  66.                         dis
  67.                         replace k=(`cve_y1'+`cve_y2'+`cve_y3'+`cve_y4'+`cve_y5'+`cve_y6'+`cve_y7'+`cve_y8'+`cve_y9'+`cve_y10')/10 if id==`k'
  68.                         }
复制代码




二维码

扫码加我 拉你入群

请注明:姓名-公司-职位

以便审核进群资格,未注明则拒绝

关键词:Validation ATION 是什么原因 Cross Valid

屏幕截图 2023-03-22 212312.png (88.98 KB)

屏幕截图 2023-03-22 212312.png

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
加好友,备注jltj
拉您入交流群

京ICP备16021002-2号 京B2-20170662号 京公网安备 11010802022788号 论坛法律顾问:王进律师 知识产权保护声明   免责及隐私声明

GMT+8, 2024-6-15 18:45