楼主: ReneeBK
941 2

PLS Regression using JavaScript [推广有奖]

  • 1关注
  • 62粉丝

VIP

已卖:4898份资源

学术权威

14%

还不是VIP/贵宾

-

TA的文库  其他...

R资源总汇

Panel Data Analysis

Experimental Design

威望
1
论坛币
49640 个
通用积分
55.8137
学术水平
370 点
热心指数
273 点
信用等级
335 点
经验
57805 点
帖子
4005
精华
21
在线时间
582 小时
注册时间
2005-5-8
最后登录
2023-11-26

楼主
ReneeBK 发表于 2016-6-26 23:57:05 |AI写论文

+2 论坛币
k人 参与回答

经管之家送您一份

应届毕业生专属福利!

求职就业群
赵安豆老师微信:zhaoandou666

经管之家联合CDA

送您一个全额奖学金名额~ !

感谢您参与论坛问题回答

经管之家送您两个论坛币!

+2 论坛币
  1. PLS regression algorithm based on the Yi Cao Matlab implementation:

  2. Partial Least-Squares and Discriminant Analysis

  3. installation

  4. $ npm install ml-pls

  5. Methods

  6. new PLS(X, Y)

  7. pls.train(options)

  8. Example

  9. var X = [[0.1, 0.02], [0.25, 1.01] ,[0.95, 0.01], [1.01, 0.96]];
  10. var Y = [[1, 0], [1, 0], [1, 0], [0, 1]];
  11. var options = {
  12.   latentVectors: 10,
  13.   tolerance: 1e-4
  14. };

  15. var pls = new PLS(X, Y);
  16. pls.train(options);
  17. predict(dataset)

  18. Predict the values of the dataset.

  19. Arguments

  20. dataset - A matrix that contains the dataset.
  21. Example

  22. var dataset = [[0, 0], [0, 1], [1, 0], [1, 1]];

  23. var ans = pls.predict(dataset);
  24. getExplainedVariance()

  25. Returns the explained variance on training

  26. License

  27. MIT
复制代码

本帖隐藏的内容

pls-master.zip (8.57 KB)


二维码

扫码加我 拉你入群

请注明:姓名-公司-职位

以便审核进群资格,未注明则拒绝

关键词:Javascript regression regressio regress script matrix values

本帖被以下文库推荐

沙发
ReneeBK 发表于 2016-6-26 23:57:40
  1. 'use strict';

  2. var Matrix = require('ml-matrix');
  3. var Utils = require('./utils');

  4. class PLS {
  5.     constructor(X, Y) {
  6.         if (X === true) {
  7.             const model = Y;
  8.             this.meanX = model.meanX;
  9.             this.stdDevX = model.stdDevX;
  10.             this.meanY = model.meanY;
  11.             this.stdDevY = model.stdDevY;
  12.             this.PBQ = Matrix.checkMatrix(model.PBQ);
  13.             this.R2X = model.R2X;
  14.         } else {
  15.             if (X.length !== Y.length)
  16.                 throw new RangeError('The number of X rows must be equal to the number of Y rows');

  17.             const resultX = Utils.featureNormalize(X);
  18.             this.X = resultX.result;
  19.             this.meanX = resultX.means;
  20.             this.stdDevX = resultX.std;

  21.             const resultY = Utils.featureNormalize(Y);
  22.             this.Y = resultY.result;
  23.             this.meanY = resultY.means;
  24.             this.stdDevY = resultY.std;
  25.         }
  26.     }

  27.     /**
  28.      * Fits the model with the given data and predictions, in this function is calculated the
  29.      * following outputs:
  30.      *
  31.      * T - Score matrix of X
  32.      * P - Loading matrix of X
  33.      * U - Score matrix of Y
  34.      * Q - Loading matrix of Y
  35.      * B - Matrix of regression coefficient
  36.      * W - Weight matrix of X
  37.      *
  38.      * @param {Object} options - recieves the latentVectors and the tolerance of each step of the PLS
  39.      */
  40.     train(options) {
  41.         if(options === undefined) options = {};

  42.         var latentVectors = options.latentVectors;
  43.         if (latentVectors === undefined) {
  44.             latentVectors = Math.min(this.X.length - 1, this.X[0].length);
  45.         }

  46.         var tolerance = options.tolerance;
  47.         if (tolerance === undefined) {
  48.             tolerance = 1e-5;
  49.         }
  50.         
  51.         var X = this.X;
  52.         var Y = this.Y;

  53.         var rx = X.rows;
  54.         var cx = X.columns;
  55.         var ry = Y.rows;
  56.         var cy = Y.columns;

  57.         var ssqXcal = X.clone().mul(X).sum(); // for the r2
  58.         var sumOfSquaresY = Y.clone().mul(Y).sum();

  59.         var n = latentVectors; //Math.max(cx, cy); // components of the pls
  60.         var T = Matrix.zeros(rx, n);
  61.         var P = Matrix.zeros(cx, n);
  62.         var U = Matrix.zeros(ry, n);
  63.         var Q = Matrix.zeros(cy, n);
  64.         var B = Matrix.zeros(n, n);
  65.         var W = P.clone();
  66.         var k = 0;

  67.         while(Utils.norm(Y) > tolerance && k < n) {
  68.             var transposeX = X.transpose();
  69.             var transposeY = Y.transpose();

  70.             var tIndex = maxSumColIndex(X.clone().mulM(X));
  71.             var uIndex = maxSumColIndex(Y.clone().mulM(Y));

  72.             var t1 = X.getColumnVector(tIndex);
  73.             var u = Y.getColumnVector(uIndex);
  74.             var t = Matrix.zeros(rx, 1);

  75.             while(Utils.norm(t1.clone().sub(t)) > tolerance) {
  76.                 var w = transposeX.mmul(u);
  77.                 w.div(Utils.norm(w));
  78.                 t = t1;
  79.                 t1 = X.mmul(w);
  80.                 var q = transposeY.mmul(t1);
  81.                 q.div(Utils.norm(q));
  82.                 u = Y.mmul(q);
  83.             }

  84.             t = t1;
  85.             var num = transposeX.mmul(t);
  86.             var den = (t.transpose().mmul(t))[0][0];
  87.             var p = num.div(den);
  88.             var pnorm = Utils.norm(p);
  89.             p.div(pnorm);
  90.             t.mul(pnorm);
  91.             w.mul(pnorm);

  92.             num = u.transpose().mmul(t);
  93.             den = (t.transpose().mmul(t))[0][0];
  94.             var b = (num.div(den))[0][0];
  95.             X.sub(t.mmul(p.transpose()));
  96.             Y.sub(t.clone().mul(b).mmul(q.transpose()));

  97.             T.setColumn(k, t);
  98.             P.setColumn(k, p);
  99.             U.setColumn(k, u);
  100.             Q.setColumn(k, q);
  101.             W.setColumn(k, w);

  102.             B[k][k] = b;
  103.             k++;
  104.         }

  105.         k--;
  106.         T = T.subMatrix(0, T.rows - 1, 0, k);
  107.         P = P.subMatrix(0, P.rows - 1, 0, k);
  108.         U = U.subMatrix(0, U.rows - 1, 0, k);
  109.         Q = Q.subMatrix(0, Q.rows - 1, 0, k);
  110.         W = W.subMatrix(0, W.rows - 1, 0, k);
  111.         B = B.subMatrix(0, k, 0, k);

  112.         // TODO: review of R2Y
  113.         //this.R2Y = t.transpose().mmul(t).mul(q[k][0]*q[k][0]).divS(ssqYcal)[0][0];

  114.         this.ssqYcal = sumOfSquaresY;
  115.         this.E = X;
  116.         this.F = Y;
  117.         this.T = T;
  118.         this.P = P;
  119.         this.U = U;
  120.         this.Q = Q;
  121.         this.W = W;
  122.         this.B = B;
  123.         this.PBQ = P.mmul(B).mmul(Q.transpose());
  124.         this.R2X = t.transpose().mmul(t).mmul(p.transpose().mmul(p)).div(ssqXcal)[0][0];
  125.     }

  126.     /**
  127.      * Predicts the behavior of the given dataset.
  128.      * @param dataset - data to be predicted.
  129.      * @returns {Matrix} - predictions of each element of the dataset.
  130.      */
  131.     predict(dataset) {
  132.         var X = Matrix.checkMatrix(dataset);
  133.         X = X.subRowVector(this.meanX).divRowVector(this.stdDevX);
  134.         var Y = X.mmul(this.PBQ);
  135.         Y = Y.mulRowVector(this.stdDevY).addRowVector(this.meanY);
  136.         return Y;
  137.     }

  138.     /**
  139.      * Returns the explained variance on training of the PLS model
  140.      * @return {number}
  141.      */
  142.     getExplainedVariance() {
  143.         return this.R2X;
  144.     }
  145.    
  146.     toJSON() {
  147.         return {
  148.             name: 'PLS',
  149.             R2X: this.R2X,
  150.             meanX: this.meanX,
  151.             stdDevX: this.stdDevX,
  152.             meanY: this.meanY,
  153.             stdDevY: this.stdDevY,
  154.             PBQ: this.PBQ,
  155.         };
  156.     }

  157.     /**
  158.      * Load a PLS model from a JSON Object
  159.      * @param model
  160.      * @return {PLS} - PLS object from the given model
  161.      */
  162.     static load(model) {
  163.         if (model.name !== 'PLS')
  164.             throw new RangeError('Invalid model: ' + model.name);
  165.         return new PLS(true, model);
  166.     }
  167. }

  168. module.exports = PLS;

  169. /**
  170. * Retrieves the sum at the column of the given matrix.
  171. * @param matrix
  172. * @param column
  173. * @returns {number}
  174. */
  175. function getColSum(matrix, column) {
  176.     var sum = 0;
  177.     for (var i = 0; i < matrix.rows; i++) {
  178.         sum += matrix[i][column];
  179.     }
  180.     return sum;
  181. }

  182. /**
  183. * Function that returns the index where the sum of each
  184. * column vector is maximum.
  185. * @param {Matrix} data
  186. * @returns {number} index of the maximum
  187. */
  188. function maxSumColIndex(data) {
  189.     var maxIndex = 0;
  190.     var maxSum = -Infinity;
  191.     for(var i = 0; i < data.columns; ++i) {
  192.         var currentSum = getColSum(data, i);
  193.         if(currentSum > maxSum) {
  194.             maxSum = currentSum;
  195.             maxIndex = i;
  196.         }
  197.     }
  198.     return maxIndex;
  199. }
复制代码

藤椅
reflets 发表于 2016-6-27 00:40:36
感谢分享

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
jg-xs1
拉您进交流群
GMT+8, 2026-1-8 13:01