
Training XOR with a Neuroph Multilayer Perceptron

卢鸿博
2023-12-01

        In the previous experiment we used a simple perceptron with only an input and an output layer. Trained on a few input/output rows for logic operations, it can learn AND and OR, but it cannot learn XOR in any finite number of iterations, because XOR is not linearly separable. We need a multilayer structure; three layers are enough, i.e. one hidden layer added to the original network. As before, we implement it in Java using the Neuroph framework.
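To make the "cannot learn XOR" claim concrete, here is a minimal, Neuroph-free sketch (the class and method names are my own, for illustration only): it brute-forces every small integer weight setting of a single threshold unit and finds solutions for AND and OR, but none for XOR.

```java
// A minimal sketch showing why XOR needs a hidden layer: brute-force every
// small integer weight setting of one threshold unit y = step(w1*x1 + w2*x2 + b).
// AND and OR are realizable; XOR is not, since it is not linearly separable.
public class LinearSeparability {

    static final int[][] INPUTS = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};

    // Returns true if some (w1, w2, b) in [-2, 2] makes the single unit
    // reproduce all four target outputs.
    public static boolean separable(int[] targets) {
        for (int w1 = -2; w1 <= 2; w1++) {
            for (int w2 = -2; w2 <= 2; w2++) {
                for (int b = -2; b <= 2; b++) {
                    boolean ok = true;
                    for (int i = 0; i < 4; i++) {
                        int sum = w1 * INPUTS[i][0] + w2 * INPUTS[i][1] + b;
                        int out = sum > 0 ? 1 : 0; // hard threshold
                        if (out != targets[i]) { ok = false; break; }
                    }
                    if (ok) return true;
                }
            }
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println("AND separable: " + separable(new int[]{0, 0, 0, 1})); // true
        System.out.println("OR  separable: " + separable(new int[]{0, 1, 1, 1})); // true
        System.out.println("XOR separable: " + separable(new int[]{0, 1, 1, 0})); // false
    }
}
```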

import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.util.TransferFunctionType;

import java.util.Arrays;

/**
 * Created by 阿龙 on 2017/1/25.
 */
// Implementing LearningEventListener lets handleLearningEvent print the error after each iteration
public class CalculateXOR implements LearningEventListener{
    public static void main(String[] args) {
        new CalculateXOR().learnXorCalc();
    }

    public void learnXorCalc(){
        // Training set for the XOR operation
        DataSet trainingSet = new DataSet(2, 1);
        trainingSet.addRow(new DataSetRow(new double[]{0, 0}, new double[]{0}));
        trainingSet.addRow(new DataSetRow(new double[]{0, 1}, new double[]{1}));
        trainingSet.addRow(new DataSetRow(new double[]{1, 0}, new double[]{1}));
        trainingSet.addRow(new DataSetRow(new double[]{1, 1}, new double[]{0}));

        // Transfer function: sigmoid (tanh and the like would also work)
        // Three layers: 2 input units, 3 hidden units, 1 output unit
        MultiLayerPerceptron myMlPerceptron = new MultiLayerPerceptron(TransferFunctionType.SIGMOID, 2, 3, 1);
        // Backpropagation of error
        myMlPerceptron.setLearningRule(new BackPropagation());

        LearningRule learningRule = myMlPerceptron.getLearningRule();
        learningRule.addListener(this);

        // Train on the XOR set
        System.out.println("Training neural network on the XOR set...");
        myMlPerceptron.learn(trainingSet);

        // Test with the training set
        System.out.println("Test results");
        testNeuralNetwork(myMlPerceptron, trainingSet);

    }
    // Test helper: feed each row of the data set through the network and print its output
    public static void testNeuralNetwork(NeuralNetwork neuralNet, DataSet testSet) {

        for(DataSetRow testSetRow : testSet.getRows()) {
            neuralNet.setInput(testSetRow.getInput());
            neuralNet.calculate();
            double[] networkOutput = neuralNet.getOutput();

            System.out.print("Input: " + Arrays.toString( testSetRow.getInput() ) );
            System.out.println(" Output: " + Arrays.toString( networkOutput) );
        }
    }
    @Override
    public void handleLearningEvent(LearningEvent event) {
        BackPropagation bp = (BackPropagation)event.getSource();
        if (event.getEventType() != LearningEvent.Type.LEARNING_STOPPED)
            System.out.println(bp.getCurrentIteration() + ". iteration : "+ bp.getTotalNetworkError());
    }
}


Run output:

3319. iteration : 0.010150142293472076
3320. iteration : 0.010133211345210417
3321. iteration : 0.010116321172199697
3322. iteration : 0.010099471654575976
3323. iteration : 0.010082662672854254
3324. iteration : 0.010065894107927189
3325. iteration : 0.010049165841064125
3326. iteration : 0.010032477753910163
3327. iteration : 0.010015829728484926
3328. iteration : 0.009999221647181653
Test results
Input: [0.0, 0.0] Output: [0.14221921066664928]
Input: [0.0, 1.0] Output: [0.867173897293698]
Input: [1.0, 0.0] Output: [0.8634850747201129]
Input: [1.0, 1.0] Output: [0.14860130685438494]

Process finished with exit code 0


As you can see, after a little over three thousand iterations the total error drops below 0.01, and each test output leans clearly toward the correct 0 or 1.
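Incidentally, that 0.01 stopping threshold is Neuroph's default maximum error for supervised learning. If you want a tighter criterion or a different step size, the learning rule exposes setters for this. A configuration fragment, to be dropped in before calling learn() (method names as I recall them from the Neuroph 2.x SupervisedLearning/LMS/IterativeLearning API, so verify against the version you use):

```java
// Configuration sketch (assumes Neuroph 2.x API): tune the stopping
// criterion and learning rate before training starts.
BackPropagation learningRule = new BackPropagation();
learningRule.setMaxError(0.001);      // stop once total network error < 0.001
learningRule.setLearningRate(0.2);    // step size for the weight updates
learningRule.setMaxIterations(20000); // safety cap so training always terminates
myMlPerceptron.setLearningRule(learningRule);
```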

Above, I set the number of hidden-layer neurons to 3:

        // Transfer function: sigmoid (tanh and the like would also work)
        // Three layers: 2 input units, 3 hidden units, 1 output unit
        MultiLayerPerceptron myMlPerceptron = new MultiLayerPerceptron(TransferFunctionType.SIGMOID, 2, 3, 1);

A common rule of thumb says a hidden layer of roughly 2n+1 neurons works well, where n is the number of input units. With n = 2 that gives 5, so let me set the hidden-layer size to 5 and see.

        // Transfer function: sigmoid (tanh and the like would also work)
        // Three layers: 2 input units, 5 hidden units, 1 output unit
        MultiLayerPerceptron myMlPerceptron = new MultiLayerPerceptron(TransferFunctionType.SIGMOID, 2, 5, 1);

Run output:

3507. iteration : 0.010068572832861864
3508. iteration : 0.010052969077009219
3509. iteration : 0.010037397803050055
3510. iteration : 0.010021858931944008
3511. iteration : 0.010006352384852639
3512. iteration : 0.009990878083139006
Test results
Input: [0.0, 0.0] Output: [0.11724352230560249]
Input: [0.0, 1.0] Output: [0.8418010439963585]
Input: [1.0, 0.0] Output: [0.8766007891447581]
Input: [1.0, 1.0] Output: [0.1566270773180774]

It took a few more iterations, and the result barely changed, coming out only marginally more accurate.
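In fact, for XOR the hidden-layer size is not the bottleneck at all: two hidden units already suffice to represent it exactly. The following Neuroph-free sketch (weights picked by hand rather than learned; the class name is my own) runs a forward pass through a 2-2-1 sigmoid network whose hidden units approximate OR and AND, so the output computes (x1 OR x2) AND NOT (x1 AND x2), i.e. XOR.

```java
// Hand-weighted 2-2-1 sigmoid network for XOR (weights chosen by hand, not
// trained). The large weights push each sigmoid close to 0 or 1, so the
// network effectively computes (x1 OR x2) AND NOT (x1 AND x2) = x1 XOR x2.
public class XorByHand {

    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    public static double xor(double x1, double x2) {
        double h1 = sigmoid(20 * x1 + 20 * x2 - 10); // ≈ OR(x1, x2)
        double h2 = sigmoid(20 * x1 + 20 * x2 - 30); // ≈ AND(x1, x2)
        return sigmoid(20 * h1 - 20 * h2 - 10);      // ≈ h1 AND NOT h2
    }

    public static void main(String[] args) {
        for (int a = 0; a <= 1; a++)
            for (int b = 0; b <= 1; b++)
                System.out.printf("Input: [%d, %d] Output: [%.4f]%n", a, b, xor(a, b));
    }
}
```

The outputs sit much closer to 0/1 than the trained network's, which is expected: backpropagation stops as soon as the error target is met, while these saturating weights were chosen to be extreme.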



