当前位置: 首页 > 工具软件 > joone > 使用案例 >

神经网络JOONE的实践

暨修洁
2023-12-01

什么是joone

  • Joone是一个免费的神经网络框架来创建,训练和测试人造神经网络。目标是为最热门的Java技术创造一个强大的环境,为热情和专业的用户。

  • Joone由一个中央引擎组成,这是Joone开发的所有应用程序的支点。Joone的神经网络可以建立在本地机器上,在分布式环境中进行培训,并在任何设备上运行。

  • 每个人都可以编写新的模块来实现从核心引擎分发的简单组件开始的新算法或新架构。主要思想是为围绕核心框架推出数百万个人工智能应用的基础。

    一些功能...

  • [x] 监督学习:
  • [x] 前馈神经网络(FFNN)
  • [x] 递归神经网络(Elman,Jordan,...)
  • [x] 时间延迟神经网络(TDNN)
  • [x] 标准背支(梯度下降,在线和批量)
  • [x] 弹性支撑(RPROP)

    无监督学习:

  • [x] Kohonen SOM(用WTA或高斯输出图)
  • [x] 主成分分析(PCA)
  • [x] 模块化神经网络(即混合所有上述架构的可能性)
  • [x] 强大的内置数据预处理机制
  • [x] 脚本功能(JavaScript),以便向NN添加自定义行为
  • 已有资料
  • [x] http://blog.csdn.net/u010223750/article/details/51334365 与官网资料大相迳庭,但是没有解决对照关系.
  • 资源链接
  • [x] http://www.jooneworld.com/download.html

基础例子异或运算

  • 可能大家会有疑问为什么还是异或运算,为什么不能换点其他的呢,进行异或运算数据量小,可以取得很好的学习效果.

  • 几个关键点:
  1. LinearLayer 单组神经网络 线性
  2. SigmoidLayer 单组神经网络 线性
  3. FullSynapse 连接神经网络的突触
  4. NeuralNet 神经网络容器
  5. Monitor 监视器 类似管理器

  • 基本运用流程
  1. 构造适用于模型的神经网络
  2. 根据模型控制的数据的流动控制i/o
  3. moniter开启学习过程
  4. 不得不注意: 单个类必须实现 NeuralNetListener

  • 展现编码:
    private NeuralNet neuralNet;
    private Monitor monitor;
    private SigmoidLayer out, hidden;
    private LinearLayer in;

    public static void main(String[] args) {
        Xor xor = new Xor();
        xor.init_nulnetwork();
        xor.dualData("res/xor.txt", "res/result.txt");
        // xor.name("res/xor.txt", "res/result.txt");
        try { Thread.sleep(1000); } catch (InterruptedException doNothing) { }
        xor.interrogate();
    }

    private void interrogate() {

        double[][] inputArray = new double[][] { { 0.0, 1.0 } };
        // set the inputs
        neuralNet.getMonitor().setLearning(false);

        MemoryInputSynapse inputSynapse = new MemoryInputSynapse();
        inputSynapse.setInputArray(inputArray);
        inputSynapse.setAdvancedColumnSelector("1,2");

        neuralNet.removeAllInputs();
        neuralNet.removeAllOutputs();
        neuralNet.addInputSynapse(inputSynapse);
        MemoryOutputSynapse memOut = new MemoryOutputSynapse();
        neuralNet.addOutputSynapse(memOut);
        if (neuralNet != null) {
            neuralNet.getMonitor().setSingleThreadMode(false);
            neuralNet.go();
            for (int i = 0; i < 4; i++) {
                double[] nextPattern = memOut.getNextPattern();
                System.out.println(nextPattern[0]);
            }
            System.exit(0);
        }
    }

    public void dualData(String inpath, String outpath) {

        // 输入数据突触
        FileInputSynapse inputSynapse = new FileInputSynapse();
        inputSynapse.setInputFile(new File(inpath));
        inputSynapse.setAdvancedColumnSelector("1,2");
        // 传入数据突触
        in.addInputSynapse(inputSynapse);

        // 训练突触
        TeachingSynapse Teaching = new TeachingSynapse();
        // 结果对应
        FileInputSynapse in_resultsynapse = new FileInputSynapse();
        in_resultsynapse.setInputFile(new File(inpath));
        in_resultsynapse.setAdvancedColumnSelector("3");

        // 期望结果
        Teaching.setDesired(in_resultsynapse);

        // 输出数据突触
        out.addOutputSynapse(Teaching);
        /* Creates the error output file */
        FileOutputSynapse error = new FileOutputSynapse();
        error.setFileName(outpath);
        // error.setBuffered(false);
        Teaching.addResultSynapse(error);
        neuralNet.setTeacher(Teaching);

        monitor.setLearning(true);
        monitor.setTrainingPatterns(4);
        monitor.setTotCicles(2000);
        neuralNet.go();
    }

    public void init_nulnetwork() {
        // 构造三个神经网络
        in = new LinearLayer("in");
        out = new SigmoidLayer("out");
        hidden = new SigmoidLayer("hidden");

        // 定义每个网络的神经数
        in.setRows(2);
        hidden.setRows(3);
        out.setRows(1);

        // 创建神经突触
        FullSynapse synapseone = new FullSynapse();
        FullSynapse synapsetwo = new FullSynapse();

        // 连接突触 in->hidden
        in.addOutputSynapse(synapseone);
        hidden.addInputSynapse(synapseone);
        // hidden>out
        out.addInputSynapse(synapsetwo);
        hidden.addOutputSynapse(synapsetwo);

        // 创建容器
        neuralNet = new NeuralNet();
        neuralNet.addLayer(in, NeuralNet.INPUT_LAYER);
        neuralNet.addLayer(out, NeuralNet.OUTPUT_LAYER);
        neuralNet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);
        neuralNet.addNeuralNetListener(this);

        monitor = neuralNet.getMonitor();
        monitor.addNeuralNetListener(this);
        // 学习速度
        monitor.setLearningRate(0.8);
        // 学习梯度
        monitor.setMomentum(0.9);
    }

    @Override
    public void cicleTerminated(NeuralNetEvent arg0) {

    }

    @Override
    public void errorChanged(NeuralNetEvent arg0) {
        Monitor source = (Monitor) arg0.getSource();
        if (source.getCurrentCicle() % 100 == 0)
            System.out.println(source.getCurrentCicle() + " epochs remaining - RMSE = " + source.getGlobalError());
    }

    @Override
    public void netStarted(NeuralNetEvent arg0) {
        System.out.println("star ..............");
    }

    @Override
    public void netStopped(NeuralNetEvent arg0) {
    }

    @Override
    public void netStoppedError(NeuralNetEvent arg0, String arg1) {

    }

运行效果展示

  • 控制台打印内容
star ..............
1900 epochs remaining - RMSE = 0.47600682941789274
1800 epochs remaining - RMSE = 0.4193120542959407
1700 epochs remaining - RMSE = 0.3615907922936926
1600 epochs remaining - RMSE = 0.043037310196610556
1500 epochs remaining - RMSE = 0.029465401358726946
1400 epochs remaining - RMSE = 0.023733896256020553
1300 epochs remaining - RMSE = 0.020388084386553255
1200 epochs remaining - RMSE = 0.01813347388867849
1100 epochs remaining - RMSE = 0.0164837281607376
1000 epochs remaining - RMSE = 0.015209973541172709
900 epochs remaining - RMSE = 0.014188572240857322
800 epochs remaining - RMSE = 0.013346142747488585
700 epochs remaining - RMSE = 0.012636035607807258
600 epochs remaining - RMSE = 0.012026998316013436
500 epochs remaining - RMSE = 0.011497207806982808
400 epochs remaining - RMSE = 0.011030904955155142
300 epochs remaining - RMSE = 0.010616388884900461
200 epochs remaining - RMSE = 0.010244766308025021
100 epochs remaining - RMSE = 0.009909141567578732
  • 可以看出,上面列表中的第一行中的数字合理地接近于零。这是很好的,因为输入训练文件的第一行,如列表1所示,被预期结果是为零的。同样,第二行合理地接近于1,这也很好,因为训练文件的第二行被预期结果也是为1的。
  • 训练文件格式
0;0;0
1;0;1
1;1;0
0;1;1

设置的 setAdvancedColumnSelector 可以挑出来的需要的数据内容

存储问题

如果你在大量数据的训练,在每一次训练结束时,当前的结果不会在下一次启动学习时进行结果保存的.那么解决的办法是什么呢?

那肯定的java的序列化解决 你只需要的实现Serializable接口.

  • 对于joonetools工具类 只需要 使用load 方法载入即可.

  • 下面给出一个工具类例子:

package extend.java.Seriobject;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

/**
 * 
 * @author DGW-PC
 * 对象的序列化与反序列化
 */
public class TestSerializable {
    
    private final static String filepath="rouseces/temp.xor";
    public static void main(String[] args) {
        /*User user = new User(123, "dgw", "男");
        TestSerializable serializable = new TestSerializable();
        try {
            serializable.writeObject(user);
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }*/
        TestSerializable serializable = new TestSerializable();
        try {
            User object = (User)serializable.readObject();
            System.out.println(object.toString());
        } catch (ClassNotFoundException | IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }

    }
    public void writeObject(Object o) throws IOException {
        File file = new File(filepath);
        if (file.exists()) {
            file.delete();
        }
        FileOutputStream outputStream = new FileOutputStream(file);
        ObjectOutputStream objectWrite = new ObjectOutputStream(outputStream);
        
        objectWrite.writeObject(o);
        objectWrite.close();
        outputStream.close();
    }
    
    @SuppressWarnings("resource")
    public Object  readObject() throws IOException, ClassNotFoundException {
        File file = new File(filepath);
        if (!file.exists()) {
            throw new FileNotFoundException();
        }
        FileInputStream inputStream = new FileInputStream(file);
        ObjectInputStream objectInputStream = new ObjectInputStream(inputStream);
        return objectInputStream.readObject();
    }
}

其他问题

实现NeuralNetListener 神经网络训练过程中的一些监听函数:

  1. icleTerminated:每个循环结束后输出的信息
  2. errorChanged:神经网络错误率变化时候输出的信息
  3. netStarted:神经网络开始运行的时候输出的信息
  4. netStopped:神经网络停止的时候输出的信息

验证代码:

private void interrogate() {

        double[][] inputArray = new double[][] { { 0.0, 1.0 } };
        // 设置输入
        neuralNet.getMonitor().setLearning(false);

        MemoryInputSynapse inputSynapse = new MemoryInputSynapse();
        inputSynapse.setInputArray(inputArray);
        inputSynapse.setAdvancedColumnSelector("1,2");

        neuralNet.removeAllInputs();
        neuralNet.removeAllOutputs();
        neuralNet.addInputSynapse(inputSynapse);
        MemoryOutputSynapse memOut = new MemoryOutputSynapse();
        neuralNet.addOutputSynapse(memOut);
        if (neuralNet != null) {
            neuralNet.getMonitor().setSingleThreadMode(false);
            neuralNet.go();
            for (int i = 0; i < 4; i++) {
                double[] nextPattern = memOut.getNextPattern();
                System.out.println(nextPattern[0]);
            }
            System.exit(0);
        }
    }
  • 当输入0,1输出内容
100 epochs remaining - RMSE = 0.009909141567578732
[main] [WARN] - org.joone.net.NeuralNet - Termination requested but net appears not to be running.
star ..............
0.9908896675722858
  • 当输入0,0 是输出的内容
400 epochs remaining - RMSE = 0.010399344820751158
300 epochs remaining - RMSE = 0.010045322272865264
200 epochs remaining - RMSE = 0.009724485002708847
100 epochs remaining - RMSE = 0.009431977897288895
[main] [WARN] - org.joone.net.NeuralNet - Termination requested but net appears not to be running.
star ..............
0.008214361744075686

由上面的两个数据内容比较可以知道 在进行4个数据比较过后 ,达到了学习的预期的目的,0.0表示不成立结果为0 ,0.99 表示结果成立,结果为1


封装joonetools工具类的例子

  • 更多的验证器例子:请参照官方api文档
package org.basics.code;

import org.joone.helpers.factory.JooneTools;
import org.joone.net.NeuralNet;

public class Tools {
      // XOR 输入值
    private static double[][]   inputArray = new double[][] {
        {0.0, 0.0},
        {0.0, 1.0},
        {1.0, 0.0},
        {1.0, 1.0}
    };
    
    // XOR 期望值
    private static double[][]   desiredArray = new double[][] {
        {0.0},
        {1.0},
        {1.0},
        {0.0}
    };
    
    public static void main(String[] args) {
        try {
            // 创建3个sigmoid神经网络,序列为2 2 1 
            NeuralNet nnet = JooneTools.create_standard(new int[]{ 2, 2, 1 }, JooneTools.LOGISTIC);
                nnet.getMonitor().setSingleThreadMode(true);
                /**
                 * 参数意义:
                 * 1. 网络容器
                 * 2. 输入,期末数组
                 * 3.5000 个递归梯度
                 * 4. 步进为0.01
                 * 5. 形式为输出
                 * 6.异步模式
                 */
                double rmse = JooneTools.train(nnet, inputArray, desiredArray,
                        5000, 0.01,
                        200, System.out, false);
                
                // 避免的时间片重叠
                try { Thread.sleep(50); } catch (InterruptedException doNothing) { }
                
                // 更直观的验证列表
                System.out.println("Last RMSE = "+rmse);
                System.out.println("\nResults:");
                System.out.println("|Inp 1\t|Inp 2\t|Output");
                for (int i=0; i < 4; ++i) {
                    double[] output = JooneTools.interrogate(nnet, inputArray[i]);
                    System.out.print("| "+inputArray[i][0]+"\t| "+inputArray[i][1]+"\t| ");
                    System.out.println(output[0]);
                }
                
                // 测试结果
                double testRMSE = JooneTools.test(nnet, inputArray, desiredArray);
                System.out.println("\nTest error = "+testRMSE);
            } catch (Exception exc) { exc.printStackTrace(); }
    }

}

转载于:https://www.cnblogs.com/dgwblog/p/7711141.html

 类似资料: