下面只对该算法做一个简单的介绍:
1、原理:待预测类别与训练集中距离最近的K个类别总数最多的那个类别被认为是预测元组的类别
下面的这个kNN1函数对一个待预测元组进行分类预测
本帖隐藏的内容
- function [ maxClass ] = kNN1( toPredictionVector,trainingSet,classVector,k )
- %% ******产品信息****************************
- % ------DateTime:2015-3-7-----------------------
- %% ******函数说明***************************
- % ------toPredictionVector是一个1×n的向量-------
- % ------trainingSet是一个m×n的训练集-------
- % ------classVector是一个n×1的类别向量----------
- %%
- % 归一化处理
- dataToOne=[toPredictionVector;trainingSet];
- newData=toOne(dataToOne);
- toPredicionVector_new=newData(1,:);
- trainingSet_new=newData(2:end,:);
- % 计算向量toPredicionVector_new与矩阵trainingSet_new中各行向量的欧式距离
- Distance=pdist2(trainingSet_new,toPredicionVector_new);
- % 对距离进行升序排序,并获取原索引值
- [~,index]=sort(Distance,'ascend');
- sortClass=cell(k,1);
- % 获取排序后的类别
- for i=1:k
- sortClass(i)=classVector(index(i));
- end
- % 统计各个类别的数量
- stateClass=tabulate(sortClass);
- Value=stateClass(:,1);%类别名称
- Count=stateClass(:,2);%类别数量
- % 获取最大数量类别的名称
- [~,maxIndex]=max(cell2mat(Count));
- maxClass=Value(maxIndex);
- end
2、其它函数:toOne()函数对数据归一化处理,kNN()函数对多个元组进行分类预测
- function [ newDataSet ] = toOne( dataSet )
- %% 对数据集归一化处理
- % ------DateTime:2015-3-7-----------------------
- % ------用于对数据集归一化处理,划为0-1之间的数----
- %%
- minVector=min(dataSet);
- maxVector=max(dataSet);
- min_max_D=maxVector-minVector;
- dataSetSize=size(dataSet,1);
- min_max_D_M=repmat(min_max_D,dataSetSize,1);
- minDataSet=repmat(minVector,dataSetSize,1);
- newDataSet=(dataSet-minDataSet)./min_max_D_M;
- end
- function [ maxClass ] = kNN( toPredictionSet,trainingSet,classVector,k )
- %% ******产品信息****************************
- % ------DateTime:2015-3-7-----------------------
- %% ******函数说明***************************
- % ------toPredictionSet是一个m1×n的向量-------
- % ------trainingSet是一个m2×n的训练集矩阵-------
- % ------classVector是一个n×1的类别向量----------
- %%
- predicionSize=size(toPredictionSet,1);
- maxClass=cell(predicionSize,1);
- for i=1:predicionSize
- maxClass(i)=kNN1(toPredictionSet(i,:),trainingSet,classVector,k);
- end
- end
3、运行脚本:
- clc,clear
- [trainingData,trainingClass]=xlsread('trainingData.xlsx');
- [testData,testClass]=xlsread('toPredictionData.xlsx');
- toPredictionData=xlsread('toPredictionData.xlsx');
- % 归一化处理
- % 绘制分类图
- pSize=size(toPredictionData,1);
- dataToOne=[toPredictionData;trainingData];
- newData=toOne(dataToOne);
- toPredicionSet_new=newData(1:pSize,:);
- trainingSet_new=newData(pSize+1:end,:);
- H1=plot(trainingSet_new(:,2),trainingSet_new(:,1),'o');
- hold on
- H2=plot(toPredicionSet_new(:,2),toPredicionSet_new(:,1),'.');
- title('kNN算法实现偏好分类', 'FontWeight','Bold', 'FontSize', 15);
- M = {'训练集';'预测集'};
- legend([H1,H2],M);
- predictionClass_Test=kNN(testData,trainingData,trainingClass,20);
- tSize=size(testData,1);
- T_count=0;
- for i=1:tSize
- A=predictionClass_Test{i};
- B=testClass{i};
- if size(A,2)==size(B,2)
- C=A==B;
- C=C+0;
- if sum(C)==size(A,2)
- T_count=T_count+1;
- end
- end
- end
- r=T_count/tSize;
- if r>=0.90
- predictionClass=kNN(toPredictionData,trainingData,trainingClass,20);
- end
写的不是很好,欢迎大家讨论,一起进步!
量化投资板块支持C/C++/C#/Java/Matlab/R/Splus/Python/VBA/Perl/PHP/JavaScript/.NET/MySQL/SqlServer/Oracle等各类编程原创,较大奖励加分,好贴直接精华!