RecurrentJava

#1 (OP) ReneeBK posted on 2016-9-3 08:18:39

RecurrentJava

RecurrentJava is a Java reimplementation of Andrej Karpathy's RecurrentJS.

It currently features:

  • Deep Recurrent Neural Networks
  • Long Short-Term Memory Networks
  • Gated Recurrent Unit Neural Networks
  • Backpropagation Through Time, handled via automatic differentiation

ExamplePaulGraham.java shows how to do character-by-character sentence prediction and generation.
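The temperatures in the sample output below scale the logits of the softmax over next characters before sampling: high temperatures flatten the distribution (more random text), low temperatures sharpen it toward the argmax. This is not repository code, just a minimal self-contained sketch of temperature-scaled softmax:

```java
import java.util.Arrays;

public class TemperatureSoftmax {
    // Turn raw logits into a probability distribution, dividing by a
    // temperature first: T > 1 flattens the distribution, T < 1 sharpens it.
    static double[] softmax(double[] logits, double temperature) {
        double[] p = new double[logits.length];
        double max = Double.NEGATIVE_INFINITY;
        for (double l : logits) max = Math.max(max, l / temperature);
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            // subtract the max for numerical stability before exponentiating
            p[i] = Math.exp(logits[i] / temperature - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) p[i] /= sum;
        return p;
    }

    public static void main(String[] args) {
        double[] logits = {2.0, 1.0, 0.1};
        System.out.println("T=1.0: " + Arrays.toString(softmax(logits, 1.0)));
        // at T=0.1 the distribution is nearly one-hot on the largest logit
        System.out.println("T=0.1: " + Arrays.toString(softmax(logits, 0.1)));
    }
}
```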

Sample output:

========================================
REPORT: calculating perplexity over entire data set...
Median Perplexity = 1.4959
Temperature 1.0 prediction:
    "there's a more kemmaces of meanness that hade? tagh o; mool"
    "it fart dect about twish i could see gve..."
Temperature 0.75 prediction:
    "that's not absolutely note a lot of the startup? path they'll should owt"
    "i realize how crazy all thi..."
Temperature 0.5 prediction:
    "the most stripiess to more here that happens never get them"
    "if you do that role kropate that's the w..."
Temperature 0.25 prediction:
    "the person who needs something making the same spignf befart"
    "the startup founders who never about wh..."
Temperature 0.1 prediction:
    "the startup founders who never about which in your expanding, it's a sign when idea way we don't the..."
Argmax prediction:
    "the problem is not that most towns kill startups"
    "the problem is not that most towns kill startups"
    "th..."
========================================

License

MIT

Hidden content in this post:

RecurrentJava-master.zip (11.5 MB)



#2 ReneeBK posted on 2016-9-3 08:20:35
package model;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import matrix.Matrix;
import autodiff.Graph;

public class FeedForwardLayer implements Model {

    private static final long serialVersionUID = 1L;
    Matrix W;
    Matrix b;
    Nonlinearity f;

    public FeedForwardLayer(int inputDimension, int outputDimension, Nonlinearity f, double initParamsStdDev, Random rng) {
        W = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        b = new Matrix(outputDimension);
        this.f = f;
    }

    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {
        Matrix sum = g.add(g.mul(W, input), b);
        Matrix out = g.nonlin(f, sum);
        return out;
    }

    @Override
    public void resetState() {
        //feedforward layers are stateless; nothing to reset
    }

    @Override
    public List<Matrix> getParameters() {
        List<Matrix> result = new ArrayList<>();
        result.add(W);
        result.add(b);
        return result;
    }
}
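FeedForwardLayer.forward computes y = f(Wx + b) on the autodiff graph. The same arithmetic written out with plain arrays, as a standalone illustration only (tanh standing in for the Nonlinearity f; none of this is repository code):

```java
public class DensePass {
    // y = f(W x + b) for a dense layer, with tanh as the nonlinearity f
    static double[] forward(double[][] W, double[] b, double[] x) {
        double[] y = new double[W.length];
        for (int i = 0; i < W.length; i++) {
            double sum = b[i];                                // start from the bias
            for (int j = 0; j < x.length; j++) sum += W[i][j] * x[j]; // row of W dot x
            y[i] = Math.tanh(sum);                            // elementwise nonlinearity
        }
        return y;
    }

    public static void main(String[] args) {
        double[][] W = {{1.0, 0.0}, {0.0, -1.0}};
        double[] b = {0.0, 0.5};
        double[] x = {0.5, 0.5};
        double[] y = forward(W, b, x);
        System.out.printf("y = [%.4f, %.4f]%n", y[0], y[1]);
    }
}
```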

#3 ReneeBK posted on 2016-9-3 08:21:15
package model;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import matrix.Matrix;
import autodiff.Graph;

/*
 * As described in:
 *     "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation"
 *     http://arxiv.org/abs/1406.1078
 */
public class GruLayer implements Model {

    private static final long serialVersionUID = 1L;
    int inputDimension;
    int outputDimension;

    Matrix IHmix, HHmix, Bmix;
    Matrix IHnew, HHnew, Bnew;
    Matrix IHreset, HHreset, Breset;

    Matrix context;

    Nonlinearity fMix = new SigmoidUnit();
    Nonlinearity fReset = new SigmoidUnit();
    Nonlinearity fNew = new TanhUnit();

    public GruLayer(int inputDimension, int outputDimension, double initParamsStdDev, Random rng) {
        this.inputDimension = inputDimension;
        this.outputDimension = outputDimension;
        IHmix = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        HHmix = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        Bmix = new Matrix(outputDimension);
        IHnew = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        HHnew = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        Bnew = new Matrix(outputDimension);
        IHreset = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        HHreset = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        Breset = new Matrix(outputDimension);
    }

    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {

        //update ("mix") gate
        Matrix sum0 = g.mul(IHmix, input);
        Matrix sum1 = g.mul(HHmix, context);
        Matrix actMix = g.nonlin(fMix, g.add(g.add(sum0, sum1), Bmix));

        //reset gate
        Matrix sum2 = g.mul(IHreset, input);
        Matrix sum3 = g.mul(HHreset, context);
        Matrix actReset = g.nonlin(fReset, g.add(g.add(sum2, sum3), Breset));

        //candidate activation, computed from the reset-gated context
        Matrix sum4 = g.mul(IHnew, input);
        Matrix gatedContext = g.elmul(actReset, context);
        Matrix sum5 = g.mul(HHnew, gatedContext);
        Matrix actNewPlusGatedContext = g.nonlin(fNew, g.add(g.add(sum4, sum5), Bnew));

        //blend old context and candidate according to the mix gate
        Matrix memvals = g.elmul(actMix, context);
        Matrix newvals = g.elmul(g.oneMinus(actMix), actNewPlusGatedContext);
        Matrix output = g.add(memvals, newvals);

        //rollover activations for next iteration
        context = output;

        return output;
    }

    @Override
    public void resetState() {
        context = new Matrix(outputDimension);
    }

    @Override
    public List<Matrix> getParameters() {
        List<Matrix> result = new ArrayList<>();
        result.add(IHmix);
        result.add(HHmix);
        result.add(Bmix);
        result.add(IHnew);
        result.add(HHnew);
        result.add(Bnew);
        result.add(IHreset);
        result.add(HHreset);
        result.add(Breset);
        return result;
    }
}
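In GruLayer's terms, actMix is the GRU update gate z, actReset is the reset gate r, and the new state is z*h + (1-z)*tanh(Wx + U(r*h) + b), following Cho et al.'s convention. A scalar, single-unit sketch of one step (illustrative only, not repository code; weight names are made up here):

```java
public class GruStep {
    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    // One GRU step for a single unit with scalar weights.
    // Each (w*, u*, b*) triple is the input weight, recurrent weight, and bias
    // of one gate; h is the previous hidden state, x the current input.
    static double step(double x, double h,
                       double wz, double uz, double bz,   // update ("mix") gate
                       double wr, double ur, double br,   // reset gate
                       double wh, double uh, double bh) { // candidate
        double z = sigmoid(wz * x + uz * h + bz);             // actMix
        double r = sigmoid(wr * x + ur * h + br);             // actReset
        double hCand = Math.tanh(wh * x + uh * (r * h) + bh); // candidate via reset-gated state
        return z * h + (1.0 - z) * hCand;                     // memvals + newvals
    }

    public static void main(String[] args) {
        // when the update gate saturates at 1, the old state passes through unchanged
        double kept = step(1.0, 0.7, 100, 0, 0, 0, 0, 0, 1, 1, 0);
        System.out.println("state kept: " + kept);
    }
}
```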

#4 ReneeBK posted on 2016-9-3 08:21:57
package model;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import matrix.Matrix;
import autodiff.Graph;

public class LinearLayer implements Model {

    private static final long serialVersionUID = 1L;
    Matrix W;
    //no biases

    public LinearLayer(int inputDimension, int outputDimension, double initParamsStdDev, Random rng) {
        W = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
    }

    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {
        Matrix out = g.mul(W, input);
        return out;
    }

    @Override
    public void resetState() {
        //linear layers are stateless; nothing to reset
    }

    @Override
    public List<Matrix> getParameters() {
        List<Matrix> result = new ArrayList<>();
        result.add(W);
        return result;
    }
}

#5 ReneeBK posted on 2016-9-3 08:23:22
package model;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import matrix.Matrix;
import autodiff.Graph;

public class LstmLayer implements Model {

    private static final long serialVersionUID = 1L;
    int inputDimension;
    int outputDimension;

    Matrix Wix, Wih, bi;
    Matrix Wfx, Wfh, bf;
    Matrix Wox, Woh, bo;
    Matrix Wcx, Wch, bc;

    Matrix hiddenContext;
    Matrix cellContext;

    Nonlinearity fInputGate = new SigmoidUnit();
    Nonlinearity fForgetGate = new SigmoidUnit();
    Nonlinearity fOutputGate = new SigmoidUnit();
    Nonlinearity fCellInput = new TanhUnit();
    Nonlinearity fCellOutput = new TanhUnit();

    public LstmLayer(int inputDimension, int outputDimension, double initParamsStdDev, Random rng) {
        this.inputDimension = inputDimension;
        this.outputDimension = outputDimension;
        Wix = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        Wih = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        bi = new Matrix(outputDimension);
        Wfx = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        Wfh = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        //set forget bias to 1.0, as described here: http://jmlr.org/proceedings/papers/v37/jozefowicz15.pdf
        bf = Matrix.ones(outputDimension, 1);
        Wox = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        Woh = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        bo = new Matrix(outputDimension);
        Wcx = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
        Wch = Matrix.rand(outputDimension, outputDimension, initParamsStdDev, rng);
        bc = new Matrix(outputDimension);
    }

    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {

        //input gate
        Matrix sum0 = g.mul(Wix, input);
        Matrix sum1 = g.mul(Wih, hiddenContext);
        Matrix inputGate = g.nonlin(fInputGate, g.add(g.add(sum0, sum1), bi));

        //forget gate
        Matrix sum2 = g.mul(Wfx, input);
        Matrix sum3 = g.mul(Wfh, hiddenContext);
        Matrix forgetGate = g.nonlin(fForgetGate, g.add(g.add(sum2, sum3), bf));

        //output gate
        Matrix sum4 = g.mul(Wox, input);
        Matrix sum5 = g.mul(Woh, hiddenContext);
        Matrix outputGate = g.nonlin(fOutputGate, g.add(g.add(sum4, sum5), bo));

        //write operation on cells
        Matrix sum6 = g.mul(Wcx, input);
        Matrix sum7 = g.mul(Wch, hiddenContext);
        Matrix cellInput = g.nonlin(fCellInput, g.add(g.add(sum6, sum7), bc));

        //compute new cell activation
        Matrix retainCell = g.elmul(forgetGate, cellContext);
        Matrix writeCell = g.elmul(inputGate, cellInput);
        Matrix cellAct = g.add(retainCell, writeCell);

        //compute hidden state as gated, saturated cell activations
        Matrix output = g.elmul(outputGate, g.nonlin(fCellOutput, cellAct));

        //rollover activations for next iteration
        hiddenContext = output;
        cellContext = cellAct;

        return output;
    }

    @Override
    public void resetState() {
        hiddenContext = new Matrix(outputDimension);
        cellContext = new Matrix(outputDimension);
    }

    @Override
    public List<Matrix> getParameters() {
        List<Matrix> result = new ArrayList<>();
        result.add(Wix);
        result.add(Wih);
        result.add(bi);
        result.add(Wfx);
        result.add(Wfh);
        result.add(bf);
        result.add(Wox);
        result.add(Woh);
        result.add(bo);
        result.add(Wcx);
        result.add(Wch);
        result.add(bc);
        return result;
    }
}
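The four gates above combine as c' = f*c + i*g (retain plus write) and h' = o*tanh(c'). A scalar, single-unit sketch of one step (illustrative only, not repository code; the weight layout is invented for this example):

```java
public class LstmStep {
    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    // One LSTM step for a single unit, returning {hNew, cNew}.
    // Each weight triple is {input weight, recurrent weight, bias} for one gate.
    static double[] step(double x, double h, double c,
                         double[] wi, double[] wf, double[] wo, double[] wc) {
        double i = sigmoid(wi[0] * x + wi[1] * h + wi[2]);   // input gate
        double f = sigmoid(wf[0] * x + wf[1] * h + wf[2]);   // forget gate
        double o = sigmoid(wo[0] * x + wo[1] * h + wo[2]);   // output gate
        double g = Math.tanh(wc[0] * x + wc[1] * h + wc[2]); // cell input
        double cNew = f * c + i * g;                         // retain + write
        double hNew = o * Math.tanh(cNew);                   // gated, saturated cell
        return new double[]{hNew, cNew};
    }

    public static void main(String[] args) {
        // forget gate saturated open, input gate shut: the cell value is preserved
        double[] s = step(1.0, 0.0, 0.42,
                new double[]{-100, 0, 0},  // input gate ~ 0
                new double[]{100, 0, 0},   // forget gate ~ 1
                new double[]{100, 0, 0},   // output gate ~ 1
                new double[]{1, 1, 0});
        System.out.println("h' = " + s[0] + ", c' = " + s[1]);
    }
}
```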

#6 ReneeBK posted on 2016-9-3 08:24:58
package model;

import java.util.ArrayList;
import java.util.List;

import matrix.Matrix;
import autodiff.Graph;

public class NeuralNetwork implements Model {

    private static final long serialVersionUID = 1L;
    List<Model> layers = new ArrayList<>();

    public NeuralNetwork(List<Model> layers) {
        this.layers = layers;
    }

    @Override
    public Matrix forward(Matrix input, Graph g) throws Exception {
        Matrix prev = input;
        for (Model layer : layers) {
            prev = layer.forward(prev, g);
        }
        return prev;
    }

    @Override
    public void resetState() {
        for (Model layer : layers) {
            layer.resetState();
        }
    }

    @Override
    public List<Matrix> getParameters() {
        List<Matrix> result = new ArrayList<>();
        for (Model layer : layers) {
            result.addAll(layer.getParameters());
        }
        return result;
    }
}
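NeuralNetwork.forward simply threads each layer's output into the next layer's input. The same composition pattern, sketched standalone with java.util.function rather than the repository's Model/Matrix types:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.UnaryOperator;

public class StackedForward {
    // Thread the input through each layer in turn, exactly as
    // NeuralNetwork.forward does with Model instances.
    static double[] forward(List<UnaryOperator<double[]>> layers, double[] input) {
        double[] prev = input;
        for (UnaryOperator<double[]> layer : layers) {
            prev = layer.apply(prev);
        }
        return prev;
    }

    public static void main(String[] args) {
        List<UnaryOperator<double[]>> layers = new ArrayList<>();
        layers.add(v -> { // a "linear" layer: scale every component by 2
            double[] out = new double[v.length];
            for (int i = 0; i < v.length; i++) out[i] = 2.0 * v[i];
            return out;
        });
        layers.add(v -> { // a nonlinearity layer: tanh on every component
            double[] out = new double[v.length];
            for (int i = 0; i < v.length; i++) out[i] = Math.tanh(v[i]);
            return out;
        });
        System.out.println(Arrays.toString(forward(layers, new double[]{0.0, 1.0})));
    }
}
```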
