
Deep Neural Networks with GPU support


OP: ReneeBK, posted on 2016-9-3 08:29:54

Deep Neural Networks with GPU Support

This is a Java implementation of some of the algorithms for training deep neural networks. GPU support is provided via OpenCL and Aparapi. The architecture is designed with modularity, extensibility and pluggability in mind.

Git structure

I'm using the git-flow model. The most stable (but older) sources are available in the master branch, while the latest ones are in the develop branch.

If you want to use the previous Java 7-compatible version, you can check out this release.

Neural network types
  • Multilayer perceptron
  • Restricted Boltzmann Machine
  • Autoencoder
  • Deep belief network
  • Stacked autoencoder
  • Convolutional networks with max pooling, average pooling and stochastic pooling.
  • Maxout networks (work-in-progress)
Training algorithms
  • Backpropagation - supports multilayer perceptrons, convolutional networks and dropout.
  • Contrastive divergence and persistent contrastive divergence implemented using these and these guidelines.
  • Greedy layer-wise training for deep networks - works for stacked autoencoders and DBNs, but supports any kind of training for the individual layers.

All the algorithms support GPU execution.

Out of the box supported datasets are MNIST, CIFAR-10/CIFAR-100 (experimental, not much testing), IRIS and XOR, but you can easily implement your own.

Experimental support of RGB image preprocessing operations - affine transformations, cropping, and color scaling (see Generaltest.java -> testImageInputProvider).

Activation functions
  • Logistic
  • Tanh
  • Rectifiers
  • Softplus
  • Softmax
  • Weighted sum

All the functions support GPU execution. They can be applied to all types of networks and all training algorithms. You can also implement new activations.
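
For reference, here is a minimal plain-Java sketch of the elementwise formulas behind these activations. It is deliberately independent of the library's own plumbing; the class and method names are illustrative only, not part of the library:

// Illustrative elementwise activation formulas (not the library's API).
public final class Activations {

    // Logistic (sigmoid): 1 / (1 + e^-x)
    public static float logistic(float x) {
        return (float) (1.0 / (1.0 + Math.exp(-x)));
    }

    // Hyperbolic tangent
    public static float tanh(float x) {
        return (float) Math.tanh(x);
    }

    // Rectifier (ReLU): max(0, x)
    public static float rectifier(float x) {
        return Math.max(0f, x);
    }

    // Softplus: ln(1 + e^x), a smooth approximation of the rectifier
    public static float softplus(float x) {
        return (float) Math.log(1.0 + Math.exp(x));
    }

    // Softmax over a whole vector: e^x_i / sum_j e^x_j
    public static float[] softmax(float[] x) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : x) {
            max = Math.max(max, v); // subtract the max for numerical stability
        }
        float sum = 0f;
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = (float) Math.exp(x[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < x.length; i++) {
            out[i] /= sum;
        }
        return out;
    }
}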

How to build the library
  • Java 8 is required.
  • To build the project you need gradle or maven. If you don't use either of these, you can go to the project folder and execute the gradlew console command, which will automatically set up the gradle environment for you.
  • I'm also uploading the latest jar file (with bundled dependencies and sources) here.
  • Depending on your environment you might need to download the relevant aparapi .dll or .so file (located in the root of each archive) from here and add its location to the system PATH variable. [This](https://code.google.com/p/aparapi/wiki/DevelopersGuideLinux) is a guide on how to set up OpenCL in a linux environment.
How to run the samples

The samples are organized as unit tests. If you want to see examples on various popular datasets, go to nn-samples/src/test/java/com/github/neuralnetworks/samples/.

Library structure

There are two projects:

  • nn-core - contains the full implementation.
  • nn-samples - contains implementations of popular datasets and the sample networks built on them.

The software design is tiered, each tier depending on the previous ones.

Network architecture

This is the first "tier". Each network is defined by a list of layers. Each layer has a set of connections that link it to the other layers of the network, making the network a directed acyclic graph. This structure can accommodate simple feedforward nets, but also more complex architectures like the one described in http://www.cs.toronto.edu/~hinton/absps/imagenet.pdf. You can also build your own specific network, as sketched below.
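
As a rough sketch of building your own network, the NeuralNetworkImpl class (posted in reply #3 below) exposes addLayer/addConnections and derives the input/output layers from the graph. The Layer and FullyConnected constructors used here are assumptions for illustration; check the architecture package for the actual connection types:

import com.github.neuralnetworks.architecture.FullyConnected;
import com.github.neuralnetworks.architecture.Layer;
import com.github.neuralnetworks.architecture.NeuralNetworkImpl;

// Sketch: a 2-4-1 feedforward net assembled as a small DAG.
// The FullyConnected(input, output, inputUnits, outputUnits) signature
// is an assumption for this sketch.
public class CustomNetworkSketch {

    public static NeuralNetworkImpl build() {
        NeuralNetworkImpl nn = new NeuralNetworkImpl();

        Layer input = new Layer();
        Layer hidden = new Layer();
        Layer output = new Layer();

        // addConnections also adds both layers of each connection to the network
        nn.addConnections(new FullyConnected(input, hidden, 2, 4));
        nn.addConnections(new FullyConnected(hidden, output, 4, 1));

        // the input layer is then derived as the layer with no inbound
        // connections, the output layer as the one with no outbound ones
        return nn;
    }
}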

Data propagation

This tier propagates data through the network, taking advantage of its graph structure. There are two main base components:

  • LayerCalculator - propagates data through the graph. It receives a target layer and input data clamped to a given layer (considered an input layer). It ensures that the data is propagated through the layers in the correct order and that all the connections in the graph are calculated. For example, during the feedforward phase of backpropagation the training data is clamped to the input layer and is propagated to the target layer (the output layer of the network). In the backpropagation phase the output error derivative is clamped as "input" to the output layer and the weights are updated using breadth-first graph traversal starting from the output layer. Essentially, the role of the LayerCalculator is to provide the order in which the network layers are calculated.
  • ConnectionCalculator - base class for all neuron types (sigmoid, rectifiers, convolutional, etc.). After the order of calculation of the layers is determined by the LayerCalculator, the list of input connections for each layer is calculated by the ConnectionCalculator (see the sketch after this list).
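
To make the division of labor concrete, below is a minimal sketch of the computation a ConnectionCalculator-style neuron performs for one fully connected inbound connection: a weighted sum followed by the activation. The method signature is illustrative only; the library's real ConnectionCalculator operates on its own connection and Matrix types and supports GPU execution:

// Illustrative only: weighted sum + logistic activation for one fully
// connected inbound connection, using a flat one-dimensional weight array
// in the spirit of the library's data layout.
public final class SigmoidConnectionSketch {

    /**
     * @param input       activations of the input layer, length = inputUnits
     * @param weights     row-major weights, length = outputUnits * inputUnits
     * @param outputUnits number of neurons in the output layer
     * @return activations of the output layer
     */
    public static float[] calculate(float[] input, float[] weights, int outputUnits) {
        int inputUnits = input.length;
        float[] output = new float[outputUnits];
        for (int j = 0; j < outputUnits; j++) {
            float sum = 0f;
            for (int i = 0; i < inputUnits; i++) {
                sum += weights[j * inputUnits + i] * input[i];
            }
            output[j] = (float) (1.0 / (1.0 + Math.exp(-sum))); // logistic neuron
        }
        return output;
    }
}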
GPU

Most of the ConnectionCalculator implementations are optimized for GPU execution. Aparapi imposes some important restrictions on the code that can be executed on the GPU. The most significant are:

  • only one-dimensional arrays (and variables) of primitive data types are allowed. It is not possible to use complex objects.
  • only member-methods of the Aparapi Kernel class itself are allowed to be called from the GPU executable code.

Therefore, before each GPU calculation all the data is converted to one-dimensional arrays and primitive type variables. Because of this, all Aparapi neuron types use either AparapiWeightedSum (for fully connected layers and weighted sum input functions), AparapiSubsampling2D (for subsampling layers) or AparapiConv2D (for convolutional layers). Most of the data is represented as a one-dimensional array by default (for example Matrix).
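
To illustrate these restrictions, here is a minimal, generic Aparapi kernel (generic Aparapi usage, not code from this library): it touches only one-dimensional primitive arrays and calls only Kernel member methods such as getGlobalId():

import com.amd.aparapi.Kernel;

// Generic Aparapi sketch: elementwise vector addition on the GPU.
// Only one-dimensional primitive arrays are used, per the restrictions
// above; Aparapi translates run() to OpenCL when a device is available.
public class VectorAddKernelSketch {
    public static void main(String[] args) {
        final int n = 1024;
        final float[] a = new float[n];
        final float[] b = new float[n];
        final float[] result = new float[n];

        for (int i = 0; i < n; i++) {
            a[i] = i;
            b[i] = 2 * i;
        }

        Kernel kernel = new Kernel() {
            @Override
            public void run() {
                int id = getGlobalId(); // index of this GPU work item
                result[id] = a[id] + b[id];
            }
        };

        kernel.execute(n); // falls back to a Java thread pool without OpenCL
        kernel.dispose();
    }
}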

Training

All the trainers extend the Trainer base class. They are optimized to run on the GPU, but you can plug in other implementations and new training algorithms. The training procedure has training and testing phases. Each Trainer receives parameters (for example learning rate, momentum, etc.) via Properties (a HashMap). For the supported properties of each trainer, check the TrainerFactory class.
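
A hedged sketch of how such parameters are assembled: Properties.setParameter appears in the NeuralNetworkImpl code in reply #3 below; the LEARNING_RATE and MOMENTUM key names here are assumptions for illustration, so consult the Constants and TrainerFactory classes for the real ones:

import com.github.neuralnetworks.util.Constants;
import com.github.neuralnetworks.util.Properties;

// Sketch: trainers are parameterized through a Properties map.
// LEARNING_RATE / MOMENTUM are assumed key names for illustration only.
public class TrainerPropertiesSketch {

    public static Properties defaultTrainingProperties() {
        Properties p = new Properties();
        p.setParameter(Constants.LEARNING_RATE, 0.01f);
        p.setParameter(Constants.MOMENTUM, 0.9f);
        // TrainerFactory methods wire properties like these into a Trainer.
        return p;
    }
}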

Input data

Input is provided to the neural network by the trainers via the TrainingInputProvider interface. Each TrainingInputProvider provides training samples in the form of TrainingInputData (default implementation is TrainingInputDataImpl). The input can be modified by a list of modifiers - for example MeanInputFunction (for subtracting the mean value) and ScalingInputFunction (for scaling within a range). Currently MnistInputProvider and IrisInputProvider are implemented.
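
A custom provider only has to serve inputs and targets. Below is a hedged XOR sketch modeled on the method set visible in the CIFAR provider in reply #1 (getNextInput, getNextTarget, getInputSize, beforeSample, reset); the exact abstract method set of TrainingInputProvider may differ, so treat this as a shape, not a drop-in class:

// Hedged sketch of a custom input provider for the XOR problem, modeled
// on the methods of the CIFAR provider below; the real
// TrainingInputProvider interface may require additional methods.
public class XorInputProviderSketch {

    private static final float[][] INPUTS = { {0, 0}, {0, 1}, {1, 0}, {1, 1} };
    private static final float[][] TARGETS = { {0}, {1}, {1}, {0} };

    private int current = -1;

    public int getInputSize() {
        return INPUTS.length;
    }

    public void beforeSample() {
        current = (current + 1) % INPUTS.length; // cycle through the four cases
    }

    public float[] getNextInput() {
        return INPUTS[current];
    }

    public float[] getNextTarget() {
        return TARGETS[current];
    }

    public void reset() {
        current = -1;
    }
}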

Author

Ivan Vasilev (ivanvasilev [at] gmail (dot) com)

License

MIT License

Hidden content in this post:

neuralnetworks-master (1).zip (5.36 MB)



Reply #1 — ReneeBK, posted on 2016-9-3 08:32:15
package com.github.neuralnetworks.samples.cifar;

import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.util.ArrayList;
import java.util.List;

import com.github.neuralnetworks.input.ImageInputProvider;
import com.github.neuralnetworks.util.Util;

/**
 * Input provider for the CIFAR-10 and CIFAR-100 datasets. Requires the location
 * of the CIFAR image files (not included in the library). Do not use this class
 * directly; use the subclasses instead.
 * Experimental.
 */
public abstract class CIFARInputProvider extends ImageInputProvider {

    private static final long serialVersionUID = 1L;

    protected RandomAccessFile[] files;
    protected int labelSize;        // bytes of label data preceding each image
    protected int inputSize;        // total number of images across all files
    protected byte[] nextInputRaw;  // one raw 32x32x3 image (3072 bytes)
    protected float[] nextTarget;   // one-hot target vector
    private List<Integer> elementsOrder;

    private CIFARInputProvider() {
        super();
        this.elementsOrder = new ArrayList<>();
        this.nextInputRaw = new byte[3072];
    }

    @Override
    public int getInputSize() {
        return inputSize;
    }

    @Override
    public float[] getNextTarget() {
        return nextTarget;
    }

    @Override
    public float[] getNextInput() {
        // if no transformations are required and the data is grouped by color
        // channel the code can be optimized
        if (!requireAugmentation() && getProperties().getGroupByChannel()) {
            if (nextInput == null) {
                nextInput = new float[3072];
            }

            float scaleColors = getProperties().getScaleColors() ? 255 : 1;
            for (int i = 0; i < nextInput.length; i++) {
                nextInput[i] = (nextInputRaw[i] & 0xFF) / scaleColors;
            }

            return nextInput;
        }

        return super.getNextInput();
    }

    @Override
    protected BufferedImage getNextImage() {
        BufferedImage image = new BufferedImage(32, 32, BufferedImage.TYPE_3BYTE_BGR);
        byte[] pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();

        // CIFAR stores each image channel by channel (R, G, B);
        // interleave the channels into the BGR pixel buffer
        for (int i = 0; i < 1024; i++) {
            pixels[i * 3] = nextInputRaw[1024 * 2 + i];
            pixels[i * 3 + 1] = nextInputRaw[1024 + i];
            pixels[i * 3 + 2] = nextInputRaw[i];
        }

        return image;
    }

    @Override
    public void beforeSample() {
        if (elementsOrder.size() == 0) {
            resetOrder();
        }

        int currentEl = elementsOrder.remove(getProperties().getUseRandomOrder() ? getProperties().getRandom().nextInt(elementsOrder.size()) : 0);
        int id = currentEl % (getInputSize() / files.length);

        RandomAccessFile f = files[currentEl / (getInputSize() / files.length)];

        try {
            f.seek(id * (3072 + labelSize));
            if (labelSize > 1) {
                // CIFAR-100 records start with a coarse label byte; skip it
                f.readUnsignedByte();
            }

            Util.fillArray(nextTarget, 0);
            nextTarget[f.readUnsignedByte()] = 1;

            f.readFully(nextInputRaw);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    @Override
    public void reset() {
        super.reset();
        resetOrder();
    }

    public void resetOrder() {
        elementsOrder = new ArrayList<Integer>(getInputSize());
        for (int i = 0; i < getInputSize(); i++) {
            elementsOrder.add(i);
        }
    }

    public static class CIFAR10TrainingInputProvider extends CIFARInputProvider {

        private static final long serialVersionUID = 1L;

        /**
         * @param directory - the folder where the CIFAR files are located
         */
        public CIFAR10TrainingInputProvider(String directory) {
            super();

            this.labelSize = 1;
            this.inputSize = 50000;
            this.nextTarget = new float[10];
            this.files = new RandomAccessFile[5];

            try {
                if (!directory.endsWith(File.separator)) {
                    directory += File.separator;
                }

                files[0] = new RandomAccessFile(directory + "data_batch_1.bin", "r");
                files[1] = new RandomAccessFile(directory + "data_batch_2.bin", "r");
                files[2] = new RandomAccessFile(directory + "data_batch_3.bin", "r");
                files[3] = new RandomAccessFile(directory + "data_batch_4.bin", "r");
                files[4] = new RandomAccessFile(directory + "data_batch_5.bin", "r");
            } catch (FileNotFoundException e) {
                e.printStackTrace();
            }
        }
    }

    public static class CIFAR10TestingInputProvider extends CIFARInputProvider {

        private static final long serialVersionUID = 1L;

        /**
         * @param directory - the folder where the CIFAR files are located
         */
        public CIFAR10TestingInputProvider(String directory) {
            super();

            this.labelSize = 1;
            this.inputSize = 10000;
            this.nextTarget = new float[10];
            this.files = new RandomAccessFile[1];

            try {
                if (!directory.endsWith(File.separator)) {
                    directory += File.separator;
                }

                files[0] = new RandomAccessFile(directory + "test_batch.bin", "r");
            } catch (FileNotFoundException e) {
                e.printStackTrace();
            }
        }
    }

    public static class CIFAR100TrainingInputProvider extends CIFARInputProvider {

        private static final long serialVersionUID = 1L;

        /**
         * @param directory - the folder where the CIFAR files are located
         */
        public CIFAR100TrainingInputProvider(String directory) {
            super();

            this.labelSize = 2;
            this.inputSize = 50000;
            this.nextTarget = new float[100];
            this.files = new RandomAccessFile[1];

            try {
                if (!directory.endsWith(File.separator)) {
                    directory += File.separator;
                }

                // the CIFAR-100 binary distribution ships as a single train.bin
                // (the original post reused the CIFAR-10 data_batch_*.bin names here)
                files[0] = new RandomAccessFile(directory + "train.bin", "r");
            } catch (FileNotFoundException e) {
                e.printStackTrace();
            }
        }
    }

    public static class CIFAR100TestingInputProvider extends CIFARInputProvider {

        private static final long serialVersionUID = 1L;

        /**
         * @param directory - the folder where the CIFAR files are located
         */
        public CIFAR100TestingInputProvider(String directory) {
            super();

            this.labelSize = 2;
            this.inputSize = 10000;
            this.nextTarget = new float[100];
            this.files = new RandomAccessFile[1];

            try {
                if (!directory.endsWith(File.separator)) {
                    directory += File.separator;
                }

                // CIFAR-100 uses test.bin rather than CIFAR-10's test_batch.bin
                files[0] = new RandomAccessFile(directory + "test.bin", "r");
            } catch (FileNotFoundException e) {
                e.printStackTrace();
            }
        }
    }
}
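
Usage is just a matter of pointing a provider at the directory holding the binary batches; a small illustration (the dataset path is a placeholder):

import com.github.neuralnetworks.samples.cifar.CIFARInputProvider;

// Illustrative usage of the provider above; the path is a placeholder.
public class CifarUsageSketch {
    public static void main(String[] args) {
        CIFARInputProvider.CIFAR10TrainingInputProvider train =
                new CIFARInputProvider.CIFAR10TrainingInputProvider("/data/cifar-10-batches-bin");
        System.out.println("training images: " + train.getInputSize()); // 50000
    }
}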

Reply #2 — ReneeBK, posted on 2016-9-3 08:34:44
package com.github.neuralnetworks.architecture;

import java.io.Serializable;
import java.util.List;
import java.util.Set;

import com.github.neuralnetworks.calculation.LayerCalculator;

/**
 * This interface is implemented by everything that wants to present itself as a
 * black box with a list of input/output layers - for example a whole neural
 * network taking part in a committee of machines, a single
 * convolutional/subsampling layer, or even a single connection between layers.
 */
public interface NeuralNetwork extends Serializable {

    /**
     * @return input layer
     */
    public Layer getInputLayer();

    /**
     * @return output layer
     */
    public Layer getOutputLayer();

    /**
     * @return all the layers in this network
     */
    public Set<Layer> getLayers();

    /**
     * @return all the connections in this network
     */
    public List<Connections> getConnections();

    /**
     * @return the LayerCalculator associated with this network
     */
    public LayerCalculator getLayerCalculator();
}

Reply #3 — ReneeBK, posted on 2016-9-3 08:36:14
package com.github.neuralnetworks.architecture;

import java.util.Collection;
import java.util.List;
import java.util.Set;

import com.github.neuralnetworks.calculation.LayerCalculator;
import com.github.neuralnetworks.util.Constants;
import com.github.neuralnetworks.util.Properties;
import com.github.neuralnetworks.util.UniqueList;
import com.github.neuralnetworks.util.Util;

/**
 * Base class for all types of neural networks. A neural network is defined only
 * by the layers it contains. The layers themselves contain the connections with
 * the other layers.
 */
public class NeuralNetworkImpl implements NeuralNetwork {

    private static final long serialVersionUID = 1L;

    private Set<Layer> layers;
    private Properties properties;

    public NeuralNetworkImpl() {
        super();
        this.layers = new UniqueList<Layer>();
    }

    @Override
    public LayerCalculator getLayerCalculator() {
        return properties != null ? properties.getParameter(Constants.LAYER_CALCULATOR) : null;
    }

    public void setLayerCalculator(LayerCalculator layerCalculator) {
        if (properties == null) {
            properties = new Properties();
        }

        properties.setParameter(Constants.LAYER_CALCULATOR, layerCalculator);
    }

    @Override
    public Set<Layer> getLayers() {
        return layers;
    }

    public void setLayers(Set<Layer> layers) {
        this.layers = layers;
    }

    public Properties getProperties() {
        return properties;
    }

    public void setProperties(Properties properties) {
        this.properties = properties;
    }

    /*
     * (non-Javadoc)
     *
     * @see com.github.neuralnetworks.architecture.NeuralNetwork#getInputLayer()
     * Default implementation - the input layer is the layer which doesn't
     * have any inbound connections
     */
    @Override
    public Layer getInputLayer() {
        return layers.stream().filter(l -> l.getConnections(this).stream().noneMatch(c -> l == c.getOutputLayer() && !Util.isBias(c.getInputLayer()))).findFirst().orElse(null);
    }

    @Override
    public Layer getOutputLayer() {
        return getNoOutboundConnectionsLayer();
    }

    protected Layer getNoOutboundConnectionsLayer() {
        return layers.stream().filter(l -> l.getConnections(this).stream().noneMatch(c -> l == c.getInputLayer())).findFirst().orElse(null);
    }

    /*
     * (non-Javadoc)
     *
     * @see
     * com.github.neuralnetworks.architecture.NeuralNetwork#getConnections()
     * Returns list of all the connections within the network. The list is
     * retrieved by iterating over all the layers. Only connections that have
     * both layers in this network are returned.
     */
    @Override
    public List<Connections> getConnections() {
        List<Connections> result = new UniqueList<>();
        if (layers != null) {
            layers.forEach(l -> result.addAll(l.getConnections(this)));
        }

        return result;
    }

    /**
     * @param inputLayer
     * @param outputLayer
     * @return Connection between the two layers if it exists
     */
    public Connections getConnection(Layer inputLayer, Layer outputLayer) {
        return getConnections().stream().filter(c -> (c.getInputLayer() == inputLayer && c.getOutputLayer() == outputLayer) || (c.getInputLayer() == outputLayer && c.getOutputLayer() == inputLayer)).findFirst().orElse(null);
    }

    /**
     * Add layer to the network
     *
     * @param layer
     * @return whether the layer was added successfully
     */
    public boolean addLayer(Layer layer) {
        if (layer != null) {
            if (layers == null) {
                layers = new UniqueList<>();
            }

            if (!layers.contains(layer)) {
                layers.add(layer);
                return true;
            }
        }

        return false;
    }

    /**
     * Remove layer from the network
     *
     * @param layer
     */
    public void removeLayer(Layer layer) {
        if (layer != null) {
            if (layers != null) {
                // remove layer and bias layers
                layers.remove(layer);
                layer.getConnections(this).stream().map(Connections::getInputLayer).filter(l -> Util.isBias(l)).forEach(l -> layers.remove(l));
            }
        }
    }

    /**
     * Add layers to the network
     *
     * @param newLayers
     */
    public void addLayers(Collection<Layer> newLayers) {
        if (newLayers != null) {
            if (layers == null) {
                layers = new UniqueList<>();
            }

            newLayers.stream().filter(l -> !layers.contains(l)).forEach(l -> layers.add(l));
        }
    }

    /**
     * Add connections to the network - this means adding both the input and
     * output layers of each connection to the network
     *
     * @param connections
     */
    public void addConnections(Connections... connections) {
        if (connections != null) {
            for (Connections c : connections) {
                addLayer(c.getInputLayer());
                addLayer(c.getOutputLayer());
            }
        }
    }
}

Reply #4 — bbslover, posted on 2016-9-3 20:34:16
Thanks for sharing.

Reply #5 — lm972, posted on 2016-9-12 08:32:28
Thanks for sharing.
