当前位置: 首页 > 知识库问答 >
问题:

dl4j lstm不成功

安经纶
2023-03-14

我试图在这个链接的页面的中间部分复制exrcise:https://d2l.ai/chapple_recurrent-neurner-networks/sequence.html

练习使用一个正弦函数在-1到1之间创建1000个数据点,并使用一个递归网络来近似函数。

下面是我使用的代码。我要回去更多地研究为什么这不起作用,因为它对我来说没有太大意义,现在我可以很容易地使用前馈网络来近似这个函数。

      //get data
    ArrayList<DataSet> list = new ArrayList();
   
    DataSet dss = DataSetFetch.getDataSet(Constants.DataTypes.math, "sine", 20, 500, 0, 0);

    DataSet dsMain = dss.copy();

    if (!dss.isEmpty()){
        list.add(dss);
    }

   
    if (list.isEmpty()){

        return;
    }

    //format dataset
   list = DataSetFormatter.formatReccurnent(list, 0);

    //get network
    int history = 10;
    ArrayList<LayerDescription> ldlist = new ArrayList<>();
    LayerDescription l = new LayerDescription(1,history, Activation.RELU);
    ldlist.add(l);     
    LayerDescription ll = new LayerDescription(history, 1, Activation.IDENTITY, LossFunctions.LossFunction.MSE);
    ldlist.add(ll);

    ListenerDescription ld = new ListenerDescription(20, true, false);

    MultiLayerNetwork network = Reccurent.getLstm(ldlist, 123, WeightInit.XAVIER, new RmsProp(), ld);


    //train network
    final List<DataSet> lister = list.get(0).asList();
    DataSetIterator iter = new ListDataSetIterator<>(lister, 50);
    network.fit(iter, 50);
    network.rnnClearPreviousState();


    //test network
    ArrayList<DataSet> resList = new ArrayList<>();
    DataSet result = new DataSet();
    INDArray arr = Nd4j.zeros(lister.size()+1);     
    INDArray holder;

    if (list.size() > 1){
        //test on training data
        System.err.println("oops");

    }else{
        //test on original or scaled data
        for (int i = 0; i < lister.size(); i++) {

            holder = network.rnnTimeStep(lister.get(i).getFeatures());
            arr.putScalar(i,holder.getFloat(0));

        }
    }


    //add originaldata
    resList.add(dsMain);
    //result       
    result.setFeatures(dsMain.getFeatures());
  
    result.setLabels(arr);
    resList.add(result);

    //display
    DisplayData.plot2DScatterGraph(resList);

下面是一些我的结果的例子。蓝色为数据红色为结果

共有1个答案

高正初
2023-03-14

这是其中一次,你会从奇怪为什么这不起作用,到我最初的结果怎么会像他们一样好。

我的缺点是没有清楚地理解文档,也没有理解BPTT。

使用前馈网络,每个迭代存储为一行,每个输入存储为列。一个例子是[dataset.size,network inputs.size]

 类似资料:
  • 关于警告: 我尝试在不使用的情况下构建它。它构建了相同的

  • 问题内容: 我无法使用jquery的ajax功能成功发布。 运行页面的URL为,目标(Web服务)的URL为。没有端口是不一样的,它们分别是9999和8080。 下面是请求和jquery ajax代码。 请求: jQuery ajax代码: 问题答案: 这是跨域ajax调用的问题。基本上(至少在Firefox中),出于安全原因,POST请求会转换为OPTIONS请求。我昨晚碰到了同样 _的_事情,

  • 我正在学习LiquiBase。我正在尝试从changelog生成SQL。出于某种原因,它生成的唯一SQL是对表的锁定。 我期待drop、create和update表的SQL,但没有看到任何东西。 LiquiBase:3.4.1版 数据库:MS SQL Server

  • 我试图建立APK,以张贴我的请求在离子在游戏商店。但是当涉及到下面的命令时,我得到了错误: PS c:\projetos\xxx>jarsigner-verbose-sigalg sha1withrsa-digestalg sha1-keystore android.keystore platforms/android/app/build/outputs/apk/release/app-relea

  • 问题内容: 我有此代码: 在我的本地环境中,工作正常,但在服务器中,返回此错误: TypeError:$ http.get(…)。成功不是函数 有任何想法吗?谢谢 问题答案: 该语法是正确的高达角1.4.3。 对于Angular v.1.6以下的版本,必须使用method。该方法带有两个参数:a 和将与响应对象一起调用的回调。 使用该方法,将函数附加到返回的。 像这样: 请参阅此处的参考。 方法也

  • 我已经编写了一个php用户登录脚本,虽然我已经设法使注册页面正常工作(从而排除了common.php文件的内容是一个问题),并在mySQL中检查了数据库正在填充,但我似乎无法使登录本身发布任何内容,而不是不成功。 我肯定是在数据库中输入用户名和密码。有人能看出我哪里出了问题,或者建议我如何检查哪里出了问题吗? 表jmp_users的结构如下: 我的login.php页面是: