- 阅读权限
- 255
- 威望
- 0 级
- 论坛币
- 50288 个
- 通用积分
- 83.6306
- 学术水平
- 253 点
- 热心指数
- 300 点
- 信用等级
- 208 点
- 经验
- 41518 点
- 帖子
- 3256
- 精华
- 14
- 在线时间
- 766 小时
- 注册时间
- 2006-5-4
- 最后登录
- 2022-11-6
|
# -*- coding: utf-8 -*-
"""
Simple demonstration of the Ellipsoid online learner on mnist
@author: shais
"""
import os

import matplotlib.pyplot as plt
import numpy as np
# Short, MATLAB-flavoured names for the NumPy routines used by the learner
# below; each name is bound to the very same NumPy function object.
(dot, sign, outer,
 zeros, eye, sqrt) = (np.dot, np.sign, np.outer,
                      np.zeros, np.eye, np.sqrt)
#%%
# read data
# The directory defaults to the original author's local path; set the
# MNIST_DATA_DIR environment variable to load the files from elsewhere.
# os.path.join also copes with a directory given without a trailing slash.
dataDir = os.environ.get("MNIST_DATA_DIR", "/Users/shais/data/mnist/mnist/")
X = np.loadtxt(os.path.join(dataDir, "train4vs7_data.txt.gz"))
Y = np.loadtxt(os.path.join(dataDir, "train4vs7_labels.txt.gz"))
# X holds one example per column (the learner below reads X[:, t]):
# d features by n examples.
d, n = X.shape
# One label per example is required; fail here with a clear message rather
# than with an IndexError / ambiguous comparison deep inside the learner.
if Y.shape != (n,):
    raise ValueError(
        "expected %d labels (one per column of X), got array of shape %s"
        % (n, Y.shape)
    )
#%%
# show some images
# Draw the first (up to) 25 examples in a 5x5 grid. Positive-label examples
# use the plain gray colormap and the others the reversed one, so the two
# classes are easy to tell apart at a glance.
plt.figure(1)
num_shown = min(25, n)  # avoid IndexError on tiny data sets
for i in range(num_shown):
    # subplot positions are 1-based while data columns are 0-based.
    ax = plt.subplot(5, 5, i + 1)
    ax.axis('off')
    # imshow rescales each image to its own min/max by default, so the
    # reversed colormap gives the same picture as plotting 255 - x.
    cmap = "gray" if Y[i] > 0 else "gray_r"
    ax.imshow(X[:, i].reshape(28, 28), cmap=cmap)  # 28x28 = 784 pixels
plt.draw()
#%%
# Initialise the Ellipsoid learner: centre w at the origin and start the
# matrix A from the d x d identity.
w, A = zeros(d), eye(d)
M = 0  # number of mistakes made so far
#%%
# Loop Ellipsoid over data
# A single online pass over the n examples. Whenever the prediction
# sign(<w, x>) disagrees with the label, the centre w is moved along y*A*x
# and A is shrunk along A*x (central-cut ellipsoid update); M counts the
# mistakes.
eta = d * d / (d * d - 1.0)  # d^2 / (d^2 - 1); requires d > 1
for t in range(n):
    x = X[:, t]  # slice the example once instead of three times
    y = Y[t]
    # sign(0) == 0, so while w is all zeros every example is a mistake.
    yhat = sign(dot(w, x))
    if y == yhat:
        continue
    M += 1
    Ax = dot(A, x)
    xAx = dot(x, Ax)
    if xAx <= 0:
        # x^T A x is 0 for an all-zero example (nothing to cut on) and can
        # only go negative if rounding has broken A's positive-definiteness.
        # Updating would divide by zero / take sqrt of a negative number and
        # fill w and A with inf/NaN for the rest of the run, so skip it.
        continue
    w = w + y / ((d + 1) * sqrt(xAx)) * Ax
    A = eta * (A - (2.0 / ((d + 1.0) * xAx)) * outer(Ax, Ax))
#%%
# show the mask learnt by ellipsoid
# Left: the raw weight vector w as a 28x28 image (assumes d == 784).
# Right: the same weights pushed through a steep sigmoid to boost contrast.
plt.figure(2)
ax1 = plt.subplot(1, 2, 1)
ax1.axis('off')  # no need for axis marks
ax2 = plt.subplot(1, 2, 2)
ax2.axis('off')  # no need for axis marks
ax1.imshow(w.reshape(28, 28), cmap="gray")
# Normalise by the largest weight. If no weight is positive, w.max() would
# divide by zero or invert the contrast, so fall back to the largest
# magnitude (or 1 when w is identically zero).
scale = w.max()
if scale <= 0:
    scale = np.abs(w).max() or 1.0
# 0.5 * (1 + tanh(z / 2)) equals 1 / (1 + exp(-z)) with z = 10 * w / scale,
# but cannot overflow the way exp(-z) does for strongly negative weights.
tmp = 0.5 * (1.0 + np.tanh(5.0 * w / scale))
ax2.imshow(tmp.reshape(28, 28), cmap="gray")
plt.draw()
#%%
复制代码
|
|