楼主: Lisrelchen
1782 2

【独家发布】NeuralNetwork using Weka(Java) [推广有奖]

  • 0关注
  • 62粉丝

VIP

已卖:4194份资源

院士

67%

还不是VIP/贵宾

-

TA的文库  其他...

Bayesian NewOccidental

Spatial Data Analysis

东西方数据挖掘

威望
0
论坛币
50288 个
通用积分
83.6306
学术水平
253 点
热心指数
300 点
信用等级
208 点
经验
41518 点
帖子
3256
精华
14
在线时间
766 小时
注册时间
2006-5-4
最后登录
2022-11-6

楼主
Lisrelchen 发表于 2016-9-3 02:39:04 |AI写论文

+2 论坛币
k人 参与回答

经管之家送您一份

应届毕业生专属福利!

求职就业群
赵安豆老师微信:zhaoandou666

经管之家联合CDA

送您一个全额奖学金名额~ !

感谢您参与论坛问题回答

经管之家送您两个论坛币!

+2 论坛币
  1. package amten.ml.examples;

  2. import amten.ml.NNParams;
  3. import amten.ml.matrix.Matrix;
  4. import amten.ml.matrix.MatrixUtils;

  5. /**
  6. * Examples of using NeuralNetwork for classification.
  7. *
  8. * @author Johannes Amtén
  9. */
  10. public class NNClassificationExample {

  11.     /**
  12.      * Performs classification of Handwritten digits,
  13.      * using a subset (1000 rows) from the Kaggle Digits competition.
  14.      * <br></br>
  15.      * Uses file /example_data/Kaggle_Digits_1000.csv
  16.      *
  17.      * @see <a href="http://www.kaggle.com/c/digit-recognizer">http://www.kaggle.com/c/digit-recognizer</a></a>
  18.      */
  19.     public static void runKaggleDigitsClassification(boolean useConvolution) throws Exception {
  20.         if (useConvolution) {
  21.             System.out.println("Running classification on Kaggle Digits dataset, with convolution...\n");
  22.         } else {
  23.             System.out.println("Running classification on Kaggle Digits dataset...\n");
  24.         }
  25.         // Read data from CSV-file
  26.         int headerRows = 1;
  27.         char separator = ',';
  28.         Matrix data = MatrixUtils.readCSV("example_data/Kaggle_Digits_1000.csv", separator, headerRows);

  29.         // Split data into training set and crossvalidation set.
  30.         float crossValidationPercent = 33;
  31.         Matrix[] split = MatrixUtils.split(data, crossValidationPercent, 0);
  32.         Matrix dataTrain = split[0];
  33.         Matrix dataCV = split[1];

  34.         // First column contains the classification label. The rest are the indata.
  35.         Matrix xTrain = dataTrain.getColumns(1, -1);
  36.         Matrix yTrain = dataTrain.getColumns(0, 0);
  37.         Matrix xCV = dataCV.getColumns(1, -1);
  38.         Matrix yCV = dataCV.getColumns(0, 0);

  39.         NNParams params = new NNParams();
  40.         params.numClasses = 10; // 10 digits to classify
  41.         params.hiddenLayerParams = useConvolution ? new NNParams.NNLayerParams[]{ new NNParams.NNLayerParams(20, 5, 5, 2, 2) , new NNParams.NNLayerParams(100, 5, 5, 2, 2) } :
  42.                                                     new NNParams.NNLayerParams[] { new NNParams.NNLayerParams(100) };
  43.         params.maxIterations = useConvolution ? 10 : 200;
  44.         params.learningRate = useConvolution ? 1E-2 : 0;

  45.         long startTime = System.currentTimeMillis();
  46.         amten.ml.NeuralNetwork nn = new amten.ml.NeuralNetwork(params);
  47.         nn.train(xTrain, yTrain);
  48.         System.out.println("\nTraining time: " + String.format("%.3g", (System.currentTimeMillis() - startTime) / 1000.0) + "s");

  49.         int[] predictedClasses = nn.getPredictedClasses(xTrain);
  50.         int correct = 0;
  51.         for (int i = 0; i < predictedClasses.length; i++) {
  52.             if (predictedClasses[i] == yTrain.get(i, 0)) {
  53.                 correct++;
  54.             }
  55.         }
  56.         System.out.println("Training set accuracy: " + String.format("%.3g", (double) correct/predictedClasses.length*100) + "%");

  57.         predictedClasses = nn.getPredictedClasses(xCV);
  58.         correct = 0;
  59.         for (int i = 0; i < predictedClasses.length; i++) {
  60.             if (predictedClasses[i] == yCV.get(i, 0)) {
  61.                 correct++;
  62.             }
  63.         }
  64.         System.out.println("Crossvalidation set accuracy: " + String.format("%.3g", (double) correct/predictedClasses.length*100) + "%");
  65.     }

  66.     /**
  67.      * Performs classification of titanic survivors/casualties,
  68.      * using a cleaned dataset from the Kaggle Digits competition.
  69.      * <br></br>
  70.      * Dataset have been cleaned by removing some string attributes,
  71.      * converting some string attributes to nominal (replacing string values with numeric indexes)
  72.      * and by filling in missing values with mean/mode values.
  73.      * <br></br>
  74.      * Uses file /example_data/Kaggle_Titanic_cleaned.csv
  75.      *
  76.      * @see <a href="http://www.kaggle.com/c/titanic-gettingStarted">http://www.kaggle.com/c/titanic-gettingStarted</a></a>
  77.      */
  78.     public static void runKaggleTitanicClassification() throws Exception {
  79.         System.out.println("Running classification on Kaggle Titanic dataset...\n");
  80.         // Read data from CSV-file
  81.         int headerRows = 1;
  82.         char separator = ',';
  83.         Matrix data = MatrixUtils.readCSV("example_data/Kaggle_Titanic_Cleaned.csv", separator, headerRows);

  84.         // Split data into training set and crossvalidation set.
  85.         float crossValidationPercent = 33;
  86.         Matrix[] split = MatrixUtils.split(data, crossValidationPercent, 0);
  87.         Matrix dataTrain = split[0];
  88.         Matrix dataCV = split[1];

  89.         // First column contains the classification label. The rest are the indata.
  90.         Matrix xTrain = dataTrain.getColumns(1, -1);
  91.         Matrix yTrain = dataTrain.getColumns(0, 0);
  92.         Matrix xCV = dataCV.getColumns(1, -1);
  93.         Matrix yCV = dataCV.getColumns(0, 0);

  94.         NNParams params = new NNParams();
  95.         // Pclass has 3 categories
  96.         // Sex has 2 categories
  97.         // Embarked has 3 categories
  98.         // The rest of the attributes are numeric (as indicated with "1").
  99.         params.numCategories = new int[]  {3, 2, 1, 1, 1, 1, 3};
  100.         params.numClasses = 2; // 2 classes, survived/not

  101.         long startTime = System.currentTimeMillis();
  102.         amten.ml.NeuralNetwork nn = new amten.ml.NeuralNetwork(params);
  103.         nn.train(xTrain, yTrain);
  104.         System.out.println("\nTraining time: " + String.format("%.3g", (System.currentTimeMillis() - startTime) / 1000.0) + "s");

  105.         int[] predictedClasses = nn.getPredictedClasses(xTrain);
  106.         int correct = 0;
  107.         for (int i = 0; i < predictedClasses.length; i++) {
  108.             if (predictedClasses[i] == yTrain.get(i, 0)) {
  109.                 correct++;
  110.             }
  111.         }
  112.         System.out.println("Training set accuracy: " + String.format("%.3g", (double) correct/predictedClasses.length*100) + "%");

  113.         predictedClasses = nn.getPredictedClasses(xCV);
  114.         correct = 0;
  115.         for (int i = 0; i < predictedClasses.length; i++) {
  116.             if (predictedClasses[i] == yCV.get(i, 0)) {
  117.                 correct++;
  118.             }
  119.         }
  120.         System.out.println("Crossvalidation set accuracy: " + String.format("%.3g", (double) correct/predictedClasses.length*100) + "%");
  121.     }

  122.     public static void main(String[] args) throws Exception {
  123.         runKaggleDigitsClassification(false);
  124.         System.out.println("\n\n\n");
  125.         runKaggleDigitsClassification(true);
  126.         System.out.println("\n\n\n");
  127.         runKaggleTitanicClassification();
  128.     }
  129. }
复制代码
What's this?

Java (convolutional or fully-connected) neural network implementation with plugin for Weka. Uses dropout and rectified linear units. Implementation is multithreaded and uses MTJ matrix library with native libs for performance.

Installation

Weka

Go to https://github.com/amten/NeuralNetwork/releases/latest to find the latest release. Download the files NeuralNetwork.zip and BLAS-dlls.zip. In Weka, go to Tools/Package Manager and press the "File/URL" button. Browse to the NeuralNetwork.zip file and press "ok".

Important! For optimal performance, you need to install native matrix library files.
Windows: Unzip the BLAS-dlls.zip file to Weka's install dir (".../Program Files/Weka-3-7").
Ubuntu: Run "sudo apt-get install libatlas3-base libopenblas-base" in a terminal window.

Standalone

This package was made mainly to be used from the Weka UI, but it can be used in your own java code as well.

Go to https://github.com/amten/NeuralNetwork/releases/latest to find the latest release. Download the file NeuralNetwork.zip and unzip.

Include the files NeuralNetwork.jar, lib/mtj-1.0-snapshot.jar, lib/opencsv-2.3.jar in your classpath.

Important! For optimal performance, you need to install native matrix library files.
Windows: Unzip the BLAS-dlls.zip file to the directory where you execute your application, or any other directory in the PATH.
Ubuntu: Run "sudo apt-get install libatlas3-base libopenblas-base" in a terminal window.

Usage

Weka

In Weka, you will find the classifier under classifiers/functions/NeuralNetwork. For explanations of the settings, click the "more" button.

Note 1: If you start Weka with console (alternative available in the windows start menu), you will get printouts of cost during each iteration of training and you can press enter in the console window to halt the training.

Note 2: When using dropout as regularization, it might still be a good idea to keep a small weight penalty. This keeps weights from exploding and causing overflows.

Note 3: When using convolutional layers, it seems to be most efficient to use batch-size=1 (i.e. Stochastic Gradient Descent)

Standalone

Example code showing classification and regression can be found here: https://github.com/amten/NeuralNetwork/tree/master/src/amten/ml/examples

License

Free to copy and modify. Please include author name if you copy code.


二维码

扫码加我 拉你入群

请注明:姓名-公司-职位

以便审核进群资格,未注明则拒绝

关键词:network Neural Using WEKA Java Johannes package public import Java

本帖被以下文库推荐

沙发
Lisrelchen 发表于 2016-9-3 02:42:05
  1. package amten.ml.examples;

  2. import amten.ml.NNParams;
  3. import amten.ml.matrix.Matrix;
  4. import amten.ml.matrix.MatrixUtils;

  5. /**
  6. * Examples of using NeuralNetwork for classification.
  7. *
  8. * @author Johannes Amtén
  9. */
  10. public class NNClassificationExample {

  11.     /**
  12.      * Performs classification of Handwritten digits,
  13.      * using a subset (1000 rows) from the Kaggle Digits competition.
  14.      * <br></br>
  15.      * Uses file /example_data/Kaggle_Digits_1000.csv
  16.      *
  17.      * @see <a href="http://www.kaggle.com/c/digit-recognizer">http://www.kaggle.com/c/digit-recognizer</a></a>
  18.      */
  19.     public static void runKaggleDigitsClassification(boolean useConvolution) throws Exception {
  20.         if (useConvolution) {
  21.             System.out.println("Running classification on Kaggle Digits dataset, with convolution...\n");
  22.         } else {
  23.             System.out.println("Running classification on Kaggle Digits dataset...\n");
  24.         }
  25.         // Read data from CSV-file
  26.         int headerRows = 1;
  27.         char separator = ',';
  28.         Matrix data = MatrixUtils.readCSV("example_data/Kaggle_Digits_1000.csv", separator, headerRows);

  29.         // Split data into training set and crossvalidation set.
  30.         float crossValidationPercent = 33;
  31.         Matrix[] split = MatrixUtils.split(data, crossValidationPercent, 0);
  32.         Matrix dataTrain = split[0];
  33.         Matrix dataCV = split[1];

  34.         // First column contains the classification label. The rest are the indata.
  35.         Matrix xTrain = dataTrain.getColumns(1, -1);
  36.         Matrix yTrain = dataTrain.getColumns(0, 0);
  37.         Matrix xCV = dataCV.getColumns(1, -1);
  38.         Matrix yCV = dataCV.getColumns(0, 0);

  39.         NNParams params = new NNParams();
  40.         params.numClasses = 10; // 10 digits to classify
  41.         params.hiddenLayerParams = useConvolution ? new NNParams.NNLayerParams[]{ new NNParams.NNLayerParams(20, 5, 5, 2, 2) , new NNParams.NNLayerParams(100, 5, 5, 2, 2) } :
  42.                                                     new NNParams.NNLayerParams[] { new NNParams.NNLayerParams(100) };
  43.         params.maxIterations = useConvolution ? 10 : 200;
  44.         params.learningRate = useConvolution ? 1E-2 : 0;

  45.         long startTime = System.currentTimeMillis();
  46.         amten.ml.NeuralNetwork nn = new amten.ml.NeuralNetwork(params);
  47.         nn.train(xTrain, yTrain);
  48.         System.out.println("\nTraining time: " + String.format("%.3g", (System.currentTimeMillis() - startTime) / 1000.0) + "s");

  49.         int[] predictedClasses = nn.getPredictedClasses(xTrain);
  50.         int correct = 0;
  51.         for (int i = 0; i < predictedClasses.length; i++) {
  52.             if (predictedClasses[i] == yTrain.get(i, 0)) {
  53.                 correct++;
  54.             }
  55.         }
  56.         System.out.println("Training set accuracy: " + String.format("%.3g", (double) correct/predictedClasses.length*100) + "%");

  57.         predictedClasses = nn.getPredictedClasses(xCV);
  58.         correct = 0;
  59.         for (int i = 0; i < predictedClasses.length; i++) {
  60.             if (predictedClasses[i] == yCV.get(i, 0)) {
  61.                 correct++;
  62.             }
  63.         }
  64.         System.out.println("Crossvalidation set accuracy: " + String.format("%.3g", (double) correct/predictedClasses.length*100) + "%");
  65.     }

  66.     /**
  67.      * Performs classification of titanic survivors/casualties,
  68.      * using a cleaned dataset from the Kaggle Digits competition.
  69.      * <br></br>
  70.      * Dataset have been cleaned by removing some string attributes,
  71.      * converting some string attributes to nominal (replacing string values with numeric indexes)
  72.      * and by filling in missing values with mean/mode values.
  73.      * <br></br>
  74.      * Uses file /example_data/Kaggle_Titanic_cleaned.csv
  75.      *
  76.      * @see <a href="http://www.kaggle.com/c/titanic-gettingStarted">http://www.kaggle.com/c/titanic-gettingStarted</a></a>
  77.      */
  78.     public static void runKaggleTitanicClassification() throws Exception {
  79.         System.out.println("Running classification on Kaggle Titanic dataset...\n");
  80.         // Read data from CSV-file
  81.         int headerRows = 1;
  82.         char separator = ',';
  83.         Matrix data = MatrixUtils.readCSV("example_data/Kaggle_Titanic_Cleaned.csv", separator, headerRows);

  84.         // Split data into training set and crossvalidation set.
  85.         float crossValidationPercent = 33;
  86.         Matrix[] split = MatrixUtils.split(data, crossValidationPercent, 0);
  87.         Matrix dataTrain = split[0];
  88.         Matrix dataCV = split[1];

  89.         // First column contains the classification label. The rest are the indata.
  90.         Matrix xTrain = dataTrain.getColumns(1, -1);
  91.         Matrix yTrain = dataTrain.getColumns(0, 0);
  92.         Matrix xCV = dataCV.getColumns(1, -1);
  93.         Matrix yCV = dataCV.getColumns(0, 0);

  94.         NNParams params = new NNParams();
  95.         // Pclass has 3 categories
  96.         // Sex has 2 categories
  97.         // Embarked has 3 categories
  98.         // The rest of the attributes are numeric (as indicated with "1").
  99.         params.numCategories = new int[]  {3, 2, 1, 1, 1, 1, 3};
  100.         params.numClasses = 2; // 2 classes, survived/not

  101.         long startTime = System.currentTimeMillis();
  102.         amten.ml.NeuralNetwork nn = new amten.ml.NeuralNetwork(params);
  103.         nn.train(xTrain, yTrain);
  104.         System.out.println("\nTraining time: " + String.format("%.3g", (System.currentTimeMillis() - startTime) / 1000.0) + "s");

  105.         int[] predictedClasses = nn.getPredictedClasses(xTrain);
  106.         int correct = 0;
  107.         for (int i = 0; i < predictedClasses.length; i++) {
  108.             if (predictedClasses[i] == yTrain.get(i, 0)) {
  109.                 correct++;
  110.             }
  111.         }
  112.         System.out.println("Training set accuracy: " + String.format("%.3g", (double) correct/predictedClasses.length*100) + "%");

  113.         predictedClasses = nn.getPredictedClasses(xCV);
  114.         correct = 0;
  115.         for (int i = 0; i < predictedClasses.length; i++) {
  116.             if (predictedClasses[i] == yCV.get(i, 0)) {
  117.                 correct++;
  118.             }
  119.         }
  120.         System.out.println("Crossvalidation set accuracy: " + String.format("%.3g", (double) correct/predictedClasses.length*100) + "%");
  121.     }

  122.     public static void main(String[] args) throws Exception {
  123.         runKaggleDigitsClassification(false);
  124.         System.out.println("\n\n\n");
  125.         runKaggleDigitsClassification(true);
  126.         System.out.println("\n\n\n");
  127.         runKaggleTitanicClassification();
  128.     }
  129. }
复制代码

藤椅
Lisrelchen 发表于 2016-9-3 02:42:30
  1. package amten.ml.examples;

  2. import amten.ml.NNParams;
  3. import amten.ml.matrix.Matrix;
  4. import amten.ml.matrix.MatrixUtils;

  5. /**
  6. * Examples of using NeuralNetwork for regression.
  7. *
  8. * @author Johannes Amtén
  9. */
  10. public class NNRegressionExample {

  11.     /**
  12.      * Performs regression on a dataset of car prices for cars with different features.
  13.      * <br></br>
  14.      * Uses file /example_data/Car_Prices.csv
  15.      */
  16.     public static void runCarPricesRegression() throws Exception {
  17.         System.out.println("Running regression on Car Prices dataset...\n");
  18.         // Read data from CSV-file
  19.         int headerRows = 1;
  20.         char separator = ',';
  21.         Matrix data = MatrixUtils.readCSV("example_data/Car_Prices.csv", separator, headerRows);

  22.         // Split data into training set and crossvalidation set.
  23.         float crossValidationPercent = 33;
  24.         Matrix[] split = MatrixUtils.split(data, crossValidationPercent, 0);
  25.         Matrix dataTrain = split[0];
  26.         Matrix dataCV = split[1];

  27.         // 15:th column contains the correct price. The rest are the indata.
  28.         Matrix xTrain = dataTrain.getColumns(0, 13);
  29.         Matrix yTrain = dataTrain.getColumns(14, 14);
  30.         Matrix xCV = dataCV.getColumns(0, 13);
  31.         Matrix yCV = dataCV.getColumns(14, 14);

  32.         // Use default parameters; single hidden layer with 100 units.
  33.         NNParams params = new NNParams();

  34.         long startTime = System.currentTimeMillis();
  35.         amten.ml.NeuralNetwork nn = new amten.ml.NeuralNetwork(params);
  36.         nn.train(xTrain, yTrain);
  37.         System.out.println("\nTraining time: " + String.format("%.3g", (System.currentTimeMillis() - startTime) / 1000.0) + "s");

  38.         Matrix predictions = nn.getPredictions(xTrain);
  39.         double error = 0;
  40.         for (int i = 0; i < predictions.numRows(); i++) {
  41.             error += Math.pow(predictions.get(i, 0) - yTrain.get(i, 0), 2);
  42.         }
  43.         error = Math.sqrt(error / predictions.numRows());
  44.         System.out.println("Training set root mean squared error: " + String.format("%.4g", error));

  45.         predictions = nn.getPredictions(xCV);
  46.         error = 0;
  47.         for (int i = 0; i < predictions.numRows(); i++) {
  48.             error += Math.pow(predictions.get(i, 0) - yCV.get(i, 0), 2);
  49.         }
  50.         error = Math.sqrt(error / predictions.numRows());
  51.         System.out.println("Crossvalidation set root mean squared error: " + String.format("%.4g", error));
  52.     }


  53.     public static void main(String[] args) throws Exception {
  54.         runCarPricesRegression();
  55.     }
  56. }
复制代码

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
jg-xs1
拉您进交流群
GMT+8, 2025-12-31 14:36