// Official JOONE website example: implementing a neural network in Java with JOONE
import org.joone.engine.*;
import org.joone.engine.learning.*;
import org.joone.io.*;
import org.joone.net.*;
/**
 * Trains a small feed-forward JOONE network on the XOR truth table and then
 * interrogates it, printing progress and results to standard output.
 *
 * <p>The network is built from a 2-row linear input layer, a 3-row sigmoid
 * hidden layer and a 1-row sigmoid output layer, joined by two full synapses.
 * The instance registers itself as the network's {@link NeuralNetListener} so
 * that it can report start, stop and error events.
 */
public class XOR_using_NeuralNet implements NeuralNetListener{
    /** Column selector for the two input columns (no padding, so every token is a plain integer). */
    private static final String INPUT_COLUMNS = "1,2";
    /** Column selector for the single desired-output column. */
    private static final String DESIRED_COLUMNS = "1";

    /** The assembled network; created by {@link #initNeuralNet()}. */
    private NeuralNet nnet = null;
    /** In-memory data sources: one attached to the input layer, one used by the teacher. */
    private MemoryInputSynapse inputSynapse, desiredOutputSynapse;
    LinearLayer input;
    SigmoidLayer hidden, output;
    /** Value handed to {@code Monitor.setSingleThreadMode} before each run. */
    boolean singleThreadMode = true;

    /**
     * XOR input patterns, one row per pattern.
     */
    private double[][] inputArray = new double[][]{
        { 0.0, 0.0 },
        { 0.0, 1.0 },
        { 1.0, 0.0 },
        { 1.0, 1.0 }
    };

    /**
     * XOR desired output, row-aligned with {@link #inputArray}.
     */
    private double[][] desiredOutputArray = new double[][]{
        { 0.0 },
        { 1.0 },
        { 1.0 },
        { 0.0 } // 1 XOR 1 is 0; a 1.0 here would teach OR instead of XOR
    };

    /**
     * Builds the network, trains it, then interrogates it.
     *
     * @param args the command line arguments (unused)
     */
    public static void main(String[] args){
        XOR_using_NeuralNet xor = new XOR_using_NeuralNet();
        xor.initNeuralNet();
        xor.train();
        xor.interrogate();
    }

    /**
     * Feeds the XOR table to the network and runs a training session.
     *
     * <p>Requires {@link #initNeuralNet()} to have been called first, because
     * it uses the synapses and the network created there. Prints the elapsed
     * wall-clock time once {@code nnet.go(true)} returns.
     */
    public void train(){
        // set the inputs
        inputSynapse.setInputArray(inputArray);
        inputSynapse.setAdvancedColumnSelector(INPUT_COLUMNS);
        // set the desired outputs
        desiredOutputSynapse.setInputArray(desiredOutputArray);
        desiredOutputSynapse.setAdvancedColumnSelector(DESIRED_COLUMNS);
        // get the monitor object and set the learning parameters
        Monitor monitor = nnet.getMonitor();
        monitor.setLearningRate(0.8);
        monitor.setMomentum(0.3);
        monitor.setTrainingPatterns(inputArray.length);
        monitor.setTotCicles(5000);
        monitor.setLearning(true);
        long initms = System.currentTimeMillis();
        // Run the network in single-thread, synchronized mode
        monitor.setSingleThreadMode(singleThreadMode);
        nnet.go(true);
        System.out.println(" Total time= "
                + (System.currentTimeMillis() - initms) + " ms ");
    }

    /**
     * Runs one non-learning pass over the XOR input table and prints the
     * first output value produced for each pattern.
     *
     * <p>The number of patterns announced to the monitor and the number of
     * results read back are both taken from {@link #inputArray}, so they
     * cannot drift apart from the data actually supplied. Does nothing if the
     * network has not been created.
     */
    private void interrogate(){
        // Guard before the first use of nnet; checking later would be too late.
        if (nnet == null){
            return;
        }
        int patternCount = inputArray.length;
        // set the inputs
        inputSynapse.setInputArray(inputArray);
        inputSynapse.setAdvancedColumnSelector(INPUT_COLUMNS);
        Monitor monitor = nnet.getMonitor();
        monitor.setTrainingPatterns(patternCount);
        monitor.setTotCicles(1);
        monitor.setLearning(false);
        // set the output synapse to write the output of the net
        MemoryOutputSynapse memOut = new MemoryOutputSynapse();
        nnet.addOutputSynapse(memOut);
        System.out.println(nnet.check());
        monitor.setSingleThreadMode(singleThreadMode);
        nnet.go();
        for (int i = 0; i < patternCount; i++){
            double[] pattern = memOut.getNextPattern();
            System.out.println(" Output pattern # " + (i + 1) + " = " + pattern[0]);
        }
        System.out.println(" Interrogating Finished ");
    }

    /**
     * Creates the layers, synapses, teacher and {@link NeuralNet} container
     * and wires them together. Must be called before {@link #train()}.
     */
    protected void initNeuralNet(){
        // First create the three layers
        input = new LinearLayer();
        hidden = new SigmoidLayer();
        output = new SigmoidLayer();
        // set the dimensions of the layers: the row count is the number of
        // neurons in that layer
        input.setRows(2);
        hidden.setRows(3);
        output.setRows(1);
        // give every layer a name
        input.setLayerName(" L.input ");
        hidden.setLayerName(" L.hidden ");
        output.setLayerName(" L.output ");
        // Now create the two Synapses
        FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden connection */
        FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output connection */
        // Connect the input layer with the hidden layer
        input.addOutputSynapse(synapse_IH);
        hidden.addInputSynapse(synapse_IH);
        // Connect the hidden layer with the output layer
        hidden.addOutputSynapse(synapse_HO);
        output.addInputSynapse(synapse_HO);
        // the input to the neural net
        inputSynapse = new MemoryInputSynapse();
        input.addInputSynapse(inputSynapse);
        // The Trainer and its desired output
        desiredOutputSynapse = new MemoryInputSynapse();
        TeachingSynapse trainer = new TeachingSynapse();
        trainer.setDesired(desiredOutputSynapse);
        // Now we add this structure to a NeuralNet object
        nnet = new NeuralNet();
        nnet.addLayer(input, NeuralNet.INPUT_LAYER);
        nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);
        nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);
        nnet.setTeacher(trainer);
        output.addOutputSynapse(trainer);
        nnet.addNeuralNetListener(this);
    }

    /**
     * Listener callback; intentionally does nothing.
     *
     * @param e the event delivered by the network
     */
    public void cicleTerminated(NeuralNetEvent e){
    }

    /**
     * Prints an epoch number and the monitor's global error whenever the
     * monitor's current cycle is a multiple of 100.
     *
     * @param e the event whose source is the network's {@link Monitor}
     */
    public void errorChanged(NeuralNetEvent e){
        Monitor mon = (Monitor) e.getSource();
        if (mon.getCurrentCicle() % 100 == 0){
            // NOTE(review): the epoch is printed as total minus current, which
            // presumably means the current cycle counts down - confirm.
            System.out.println(" Epoch: "
                    + (mon.getTotCicles() - mon.getCurrentCicle()) + " RMSE: "
                    + mon.getGlobalError());
        }
    }

    /**
     * Announces whether the run that just started is a training or an
     * interrogation run, based on the monitor's learning flag.
     *
     * @param e the event whose source is the network's {@link Monitor}
     */
    public void netStarted(NeuralNetEvent e){
        Monitor mon = (Monitor) e.getSource();
        System.out.print(" Network started for ");
        if (mon.isLearning()){
            System.out.println(" training. ");
        } else {
            System.out.println(" interrogation. ");
        }
    }

    /**
     * Prints the monitor's global error when the network stops.
     *
     * @param e the event whose source is the network's {@link Monitor}
     */
    public void netStopped(NeuralNetEvent e){
        Monitor mon = (Monitor) e.getSource();
        System.out.println(" Network stopped. Last RMSE= "
                + mon.getGlobalError());
    }

    /**
     * Prints the error message reported when the network stops abnormally.
     *
     * @param e     the event delivered by the network
     * @param error the error description supplied with the event
     */
    public void netStoppedError(NeuralNetEvent e, String error){
        System.out.println(" Network stopped due the following error: "
                + error);
    }
}
// SEO external links (page footer residue)