Question:

Deeplearning4J: how to iterate over multiple DataSets for big data?

钱振
2023-03-14

I am learning Deeplearning4j (ver. 1.0.0-M1.1) for building neural networks.

I am using Deeplearning4j's IrisClassifier as an example, and it works fine:

//First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
int numLinesToSkip = 0;
char delimiter = ',';
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
recordReader.initialize(new FileSplit(new File(DownloaderUtility.IRISDATA.Download(),"iris.txt")));

//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
int batchSize = 150;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = iterator.next();
allData.shuffle();
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training

DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();

//We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.transform(trainingData);     //Apply normalization to the training data
normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set

final int numInputs = 4;
int outputNum = 3;
long seed = 6;

log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(seed)
    .activation(Activation.TANH)
    .weightInit(WeightInit.XAVIER)
    .updater(new Sgd(0.1))
    .l2(1e-4)
    .list()
    .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
        .build())
    .layer(new DenseLayer.Builder().nIn(3).nOut(3)
        .build())
    .layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .activation(Activation.SOFTMAX) //Override the global TANH activation with softmax for this layer
        .nIn(3).nOut(outputNum).build())
    .build();

//run the model
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
//record score once every 100 iterations
model.setListeners(new ScoreIterationListener(100));

for(int i=0; i<1000; i++ ) {
    model.fit(trainingData);
}

//evaluate the model on the test set
Evaluation eval = new Evaluation(3);
INDArray output = model.output(testData.getFeatures());

eval.eval(testData.getLabels(), output);
log.info(eval.stats());

For my project, I have about 30,000 input records (vs. 150 in the iris example), and each record is a vector of size ~7,000 (vs. 4 in the iris example).

Obviously I can't process the whole data in a single DataSet: that produces an OOM in the JVM.
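A rough back-of-the-envelope shows why, using the approximate sizes above (the exact counts are assumptions from the question): the feature matrix alone is on the order of a gigabyte, before labels, gradients, and layer activations are counted.

```java
public class FeatureMatrixFootprint {
    public static void main(String[] args) {
        long records = 30_000L;           // approximate row count from the question
        long featuresPerRecord = 7_000L;  // approximate vector size from the question
        long bytesPerFloat = 4L;          // ND4J's default data type is 32-bit float
        long featureBytes = records * featuresPerRecord * bytesPerFloat;
        System.out.println(featureBytes + " bytes"); // 840000000 bytes, ~0.78 GiB
    }
}
```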

How do I process the data in multiple DataSets? I tried this:

...
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
    List<DataSet> trainingData = new ArrayList<>();
    List<DataSet> testData = new ArrayList<>();

    while (iterator.hasNext()) {
        DataSet allData = iterator.next();
        allData.shuffle();
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training
        trainingData.add(testAndTrain.getTrain());
        testData.add(testAndTrain.getTest());
    }
    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    for (DataSet dataSetTraining : trainingData) {
        normalizer.fit(dataSetTraining);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
        normalizer.transform(dataSetTraining);     //Apply normalization to the training data
    }
    for (DataSet dataSetTest : testData) {
        normalizer.transform(dataSetTest);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
    }

...

    for(int i=0; i<1000; i++ ) {
        for (DataSet dataSetTraining : trainingData) {
            model.fit(dataSetTraining);
        }
    }
But when evaluating the network I get an exception:

    Evaluation eval = new Evaluation(26);

    INDArray output = new NDArray();
    for (DataSet dataSetTest : testData) {
        output.add(model.output(dataSetTest.getFeatures())); // ERROR HERE
    }

    System.out.println("--- Output ---");
    System.out.println(output);

    INDArray labels = new NDArray();
    for (DataSet dataSetTest : testData) {
        labels.add(dataSetTest.getLabels());
    }

    System.out.println("--- Labels ---");
    System.out.println(labels);

    eval.eval(labels, output);
    log.info(eval.stats());

Exception in thread "main" java.lang.NullPointerException: Cannot read field "javaShapeInformation" because "this.jvmShapeInfo" is null
    at org.nd4j.linalg.api.ndarray.BaseNDArray.dataType(BaseNDArray.java:5507)
    at org.nd4j.linalg.api.ndarray.BaseNDArray.validateNumericalArray(BaseNDArray.java:5575)
    at org.nd4j.linalg.api.ndarray.BaseNDArray.add(BaseNDArray.java:3087)
    at com.aarcapital.aarmlclassifier.classification.FAClassifierLearning.main(FAClassifierLearning.java:117)

What is the correct way to train the network by iterating over multiple DataSets?

Thanks!

1 Answer

金承嗣
2023-03-14

First: always use Nd4j.create(..) to create NDArrays. Never instantiate the implementation classes directly. This lets you create ndarrays safely that will work whether you run on CPU or GPU.
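For instance, the evaluation snippet in the question could collect the per-batch outputs in a list and stack them with factory methods instead of calling `add(..)` on an uninitialized `new NDArray()`. A sketch, reusing `model`, `eval` and `testData` from the question (`Nd4j.vstack` concatenates arrays row-wise):

```java
// Collect per-batch outputs and labels, then stack them row-wise.
List<INDArray> outputs = new ArrayList<>();
List<INDArray> labelList = new ArrayList<>();
for (DataSet dataSetTest : testData) {
    outputs.add(model.output(dataSetTest.getFeatures()));
    labelList.add(dataSetTest.getLabels());
}
INDArray output = Nd4j.vstack(outputs);    // shape: [totalTestExamples, numClasses]
INDArray labels = Nd4j.vstack(labelList);

eval.eval(labels, output);
```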

Second: always use the RecordReaderDataSetIterator builder, not the constructor. The constructor is long and error-prone.

That's why we made the builder in the first place.
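A minimal builder form, using the variable names from the question's first snippet:

```java
// Classification setup: label column index, number of classes, batch streaming.
DataSetIterator iterator = new RecordReaderDataSetIterator.Builder(recordReader, batchSize)
        .classification(labelIndex, numClasses)
        .build();
```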

Beyond that, the way you are doing things is correct. The record reader handles batching for you.
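Putting that together, a streaming pipeline might look like the sketch below. The names `trainReader`, `testReader`, `batchSize`, and `numEpochs` are hypothetical; the point is that `NormalizerStandardize.fit(DataSetIterator)`, `setPreProcessor(..)`, `MultiLayerNetwork.fit(DataSetIterator, int)`, and `evaluate(DataSetIterator)` all stream one minibatch at a time, so the full 30,000 x 7,000 matrix is never held in memory at once:

```java
// Stream minibatches from the record readers; no full-data DataSet is ever built.
DataSetIterator trainIter = new RecordReaderDataSetIterator.Builder(trainReader, batchSize)
        .classification(labelIndex, numClasses)
        .build();
DataSetIterator testIter = new RecordReaderDataSetIterator.Builder(testReader, batchSize)
        .classification(labelIndex, numClasses)
        .build();

// Fit normalization statistics by streaming over the training iterator once,
// then attach the normalizer so every batch is standardized on the fly.
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainIter);
trainIter.reset();
trainIter.setPreProcessor(normalizer);
testIter.setPreProcessor(normalizer);  // test uses the *training* statistics

model.fit(trainIter, numEpochs);       // resets and re-iterates the data each epoch

Evaluation eval = model.evaluate(testIter);  // evaluates batch-by-batch
log.info(eval.stats());
```

Note this splits train and test at the file level (two record readers) rather than per batch with splitTestAndTrain; it also avoids the loop in the question that called normalizer.fit(..) once per batch, which overwrites the statistics on every call.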
