当前位置: 首页 > 知识库问答 >
问题:

零预测,尽管在keras中为零填充小批量LSTM培训提供了掩蔽支持

松兴邦
2023-03-14

问题陈述

我在带标记的文本序列上训练keras中的多对多LSTM,使用预训练的GloVe嵌入来预测序列中每个元素的标记。我的训练机制包括小批量随机梯度下降,每个小批量矩阵都在列上添加了零填充,以确保输入到网络的长度相等。

关键的是,由于任务和数据的性质对我的小批量的自定义限制,我没有使用keras嵌入层。我的目标是为我的零填充单元实现掩蔽机制,以确保损失计算不会错误地将这些单元视为真正的数据点。

接近

正如keras文档中所解释的,keras有三种设置掩蔽层的方法:

  1. 配置keras.layers.嵌入mask_zero设置为True
  2. 添加keras.layers.Masking
  3. 调用重复层时手动传递掩码参数

因为我没有使用嵌入层对数据进行编码以进行训练,所以带有屏蔽嵌入层的选项(1)对我不可用。因此,我选择了(2),并在初始化模型后添加了一个掩蔽层。然而,这一变化似乎没有产生效果。事实上,我的模型不仅精度没有提高,而且在预测阶段,模型仍然生成零预测。为什么我的屏蔽层不屏蔽零填充单元?这可能与我在密集层中指定3个类而不是2个(因此将0作为一个单独的类)有关吗?

现有资源的限制

类似的问题已经被问到和回答,但我无法用它们来解决我的问题。虽然这篇文章没有收到直接的回应,但在评论中提到的一篇链接文章主要关注如何预处理数据以分配掩码值,这在这里是没有争议的。但是,掩蔽层初始化与此处使用的相同。这篇文章提到了同样的问题——掩蔽层对性能没有影响——答案和我一样定义了掩蔽层,但再次关注将特定值转换为掩蔽值。最后,本文中的答案提供了相同的层初始化,无需进一步阐述。

玩具数据生成

为了重现我的问题,我生成了一个包含两个类(1,2)的toy 10批次数据集。批次是一个可变长度的序列,后填充有0,最大长度为20个嵌入,每个嵌入向量由5个单元组成,因此input_shape=(20,5)。这两个类的嵌入值由不同但部分重叠的截断正态分布生成,从而为网络创建一个可学习但不平凡的问题。我把玩具数据包括在下面,这样你就可以重现这个问题了。

import pandas as pd
from keras.models import Sequential
from keras.layers import LSTM, Dense, TimeDistributed, Bidirectional, Dropout, Masking
from keras import optimizers

# *** model initialization ***

model = Sequential()
model.add(Masking(mask_value=0., input_shape=(20, 5))) # <- masking layer here
model.add(Bidirectional(LSTM(20, return_sequences=True), input_shape=(20, 5)))
model.add(Dropout(0.2))
model.add(TimeDistributed(Dense(3, activation='sigmoid')))

sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='sparse_categorical_crossentropy', optimizer=sgd, metrics=['mse'])

# *** model training ***

for epoch in range(10):

    for X,y in data_train:

        X = X.reshape(1, 20, 5)
        y = y.reshape(1, 20, 1)

        history = model.fit(X, y, epochs=1, batch_size=20, verbose=0)

# *** model prediction ***

preds = pd.DataFrame(columns=['true', 'pred'])

for index, (X,y) in enumerate(data_test):
    X = X.reshape(1, 20, 5)
    y = y.reshape(1, 20, 1)

    y_pred = model.predict_classes(X, verbose=0)

    df = pd.DataFrame(columns=['true', 'pred'])

    df['true'] = [y[0, i][0] for i in range(20)]
    df['pred'] = [y_pred[0, i] for i in range(20)]

    preds = preds.append(df, ignore_index=True)

# convert true labels to int & drop padded rows (where y_true=0)
preds['true'] = [int(label) for label in preds['true']]
preds = preds[preds['true']!=0]

这是带有掩蔽的模型的摘要:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
masking_2 (Masking)          (None, 20, 5)             0         
_________________________________________________________________
bidirectional_4 (Bidirection (None, 20, 40)            4160      
_________________________________________________________________
dropout_4 (Dropout)          (None, 20, 40)            0         
_________________________________________________________________
time_distributed_4 (TimeDist (None, 20, 3)             123       
=================================================================
Total params: 4,283
Trainable params: 4,283
Non-trainable params: 0

我训练了一个模型,其中一个有屏蔽层,另一个没有屏蔽层,并使用以下方法计算精度:

np.round(sum(preds['true']==preds['pred'])/len(preds)*100,1)

对于没有掩蔽的模型,我得到了53.3%的准确率,对于有掩蔽的模型,我得到了33.3%的准确率。更令人惊讶的是,在这两个模型中,我一直将零作为预测标签。为什么掩蔽层不能忽略零填充单元?

复制问题的数据:

data_train = list(zip(X_batches_train, y_batches_train))
data_test = list(zip(X_batches_test, y_batches_test))

列车

[array([[-1.00612917,  1.47313952,  2.68021318,  1.54875809,  0.98385996,
          1.49465265,  0.60429106,  1.12396908, -0.24041602,  1.77266187,
          0.1961381 ,  1.28019637,  1.78803092,  2.05151245,  0.93606708,
          0.51554755,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.97596563,  2.04536053,  0.88367922,  1.013342  , -0.16605355,
          3.02994344,  2.04080806, -0.25153046, -0.5964068 ,  2.9607247 ,
         -0.49722121,  0.02734492,  2.16949987,  2.77367066,  0.15628842,
          2.19823207,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.31546283,  3.27420503,  3.23550769, -0.63724013,  0.89150128,
          0.69774266,  2.76627308, -0.58408384, -0.45681779,  1.98843041,
         -0.31850477,  0.83729882,  0.45471165,  3.61974147, -1.45610756,
          1.35217453,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.03329532,  1.97471646,  1.33949611,  1.22857243, -1.46890642,
          1.74105506,  1.40969261,  0.52465603, -0.18895266,  2.81025597,
          2.64901037, -0.83415186,  0.76956826,  1.48730868, -0.16190164,
          2.24389007,  0.        ,  0.        ,  0.        ,  0.        ],
        [-1.0676654 ,  3.08429323,  1.7601179 ,  0.85448051,  1.15537064,
          2.82487842,  0.27891413,  0.57842569, -0.62392063,  1.00343057,
          1.15348843, -0.37650332,  3.37355345,  2.22285473,  0.43444434,
          0.15743873,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 1.05258873, -0.17897376, -0.99932932, -1.02854121,  0.85159208,
          2.32349131,  1.96526709, -0.08398597, -0.69474809,  1.32820222,
          1.19514151,  1.56814867,  0.86013263,  1.48342922,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.1920635 , -0.48702788,  1.24353985, -1.3864121 ,  0.16713229,
          3.10134683,  0.61658271, -0.63360643,  0.86000807,  2.74876157,
          2.87604877,  0.16339724,  2.87595396,  3.2846962 ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.1380241 , -0.76783029,  0.18814436, -1.18165209, -0.02981728,
          1.49908113,  0.61521007, -0.98191097,  0.31250199,  1.39015803,
          3.16213211, -0.70891214,  3.83881766,  1.92683533,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.39080778, -0.59179216,  0.80348201,  0.64638205, -1.40144268,
          1.49751413,  3.0092166 ,  1.33099666,  1.43714841,  2.90734268,
          3.09688943,  0.32934884,  1.14592787,  1.58152023,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.77164353,  0.50293096,  0.0717377 ,  0.14487556, -0.90246591,
          2.32612179,  1.98628857,  1.29683166, -0.12399569,  2.60184685,
          3.20136653,  0.44056647,  0.98283455,  1.79026663,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[-0.93359914,  2.31840281,  0.55691601,  1.90930758, -1.58260431,
         -1.05801881,  3.28012523,  3.84105406, -1.2127093 ,  0.00490079,
          1.28149304,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-1.03105486,  2.7703693 ,  0.16751813,  1.12127987, -0.44070271,
         -0.0789227 ,  2.79008301,  1.11456745,  1.13982551, -1.10128658,
          0.87430834,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.69710668,  1.72702833, -2.62599502,  2.34730002,  0.77756661,
          0.16415884,  3.30712178,  1.67331828, -0.44022431,  0.56837829,
          1.1566811 ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.71845983,  1.79908544,  0.37385522,  1.3870915 , -1.48823234,
         -1.487419  ,  3.0879945 ,  1.74617784, -0.91538815, -0.24244522,
          0.81393954,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-1.38501563,  3.73330047, -0.52494265,  2.37133716, -0.24546709,
         -0.28360782,  2.89384717,  2.42891743,  0.40144022, -1.21850571,
          2.00370751,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 1.27989188,  1.16254538, -0.06889142,  1.84133355,  1.3234908 ,
          1.29611702,  2.0019294 , -0.03220116,  1.1085194 ,  1.96495985,
          1.68544302,  1.94503544,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.3004439 ,  2.48768923,  0.59809607,  2.38155155,  2.78705889,
          1.67018683,  0.21731778, -0.59277191,  2.87427207,  2.63950475,
          2.39211459,  0.93083423,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 2.39239371,  0.30900383, -0.97307155,  1.98100711,  0.30613735,
          1.12827171,  0.16987791,  0.31959096,  1.30366416,  1.45881023,
          2.45668401,  0.5218711 ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.0826574 ,  2.05100254,  0.013161  ,  2.95120798,  1.15730011,
          0.75537024,  0.13708569, -0.44922143,  0.64834001,  2.50640862,
          2.00349347,  3.35573624,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.47135124,  2.10258532,  0.70212032,  2.56063126,  1.62466971,
          2.64026892,  0.21309489, -0.57752813,  2.21335957,  0.20453233,
          0.03106993,  3.01167822,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[-0.42125521,  0.54016939,  1.63016057,  2.01555253, -0.10961255,
         -0.42549555,  1.55793753, -0.0998756 ,  0.36417335,  3.37126414,
          1.62151191,  2.84084192,  0.10831384,  0.89293054, -0.08671363,
          0.49340353,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.37615411,  2.00581062,  2.30426605,  2.02205839,  0.65871664,
          1.34478836, -0.55379752, -1.42787727,  0.59732227,  0.84969282,
          0.54345723,  0.95849568, -0.17131602, -0.70425277, -0.5337757 ,
          1.78207229,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.13863276,  1.71490034,  2.02677925,  2.60608619,  0.26916522,
          0.35928298, -1.26521844, -0.59859219,  1.19162219,  1.64565259,
          1.16787165,  2.95245196,  0.48681084,  1.66621053,  0.918077  ,
         -1.10583747,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.87763797,  2.38740754,  2.9111822 ,  2.21184069,  0.78091173,
         -0.53270909,  0.40100338, -0.83375593,  0.9860009 ,  2.43898437,
         -0.64499989,  2.95092003, -1.52360727,  0.44640918,  0.78131922,
         -0.24401283,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.92615066,  3.45437746,  3.28808981,  2.87207404, -1.60027223,
         -1.14164941, -1.63807699,  0.33084805,  2.92963629,  3.51170824,
         -0.3286093 ,  2.19108385,  0.97812366, -1.82565766, -0.34034678,
         -2.0485913 ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 1.96438618e+00,  1.88104784e-01,  1.61114494e+00,
          6.99567690e-04,  2.55271963e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 2.41578815e+00, -5.70625661e-01,  2.15545894e+00,
         -1.80948908e+00,  1.62049331e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 1.97017040e+00, -1.62556528e+00,  2.49469152e+00,
          4.18785985e-02,  2.61875866e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 3.14277819e+00,  3.01098398e-02,  7.40376369e-01,
          1.76517344e+00,  2.68922918e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 2.06250296e+00,  4.67605528e-01,  1.55927230e+00,
          1.85788889e-01,  1.30359922e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00]]),
 array([[ 1.22152427,  3.74926839,  0.64415552,  2.35268329,  1.98754653,
          2.89384829,  0.44589817,  3.94228743,  2.72405657,  0.86222004,
          0.68681903,  3.89952458,  1.43454512,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.02203262,  0.95065123,  0.71669023,  0.02919391,  2.30714524,
          1.91843002,  0.73611294,  1.20560482,  0.85206836, -0.74221506,
         -0.72886308,  2.39872927, -0.95841402,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.55775319,  0.33773314,  0.79932151,  1.94966883,  3.2113281 ,
          2.70768249, -0.69745554,  1.23208345,  1.66199957,  1.69894081,
          0.13124461,  1.93256147, -0.17787952,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.45089205,  2.62430534, -1.9517961 ,  2.24040577,  1.75642049,
          1.94962325,  0.26796497,  2.28418304,  1.44944487,  0.28723885,
         -0.81081633,  1.54840214,  0.82652939,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.27678173,  1.17204606, -0.24738322,  1.02761617,  1.81060444,
          2.37830861,  0.55260134,  2.50046334,  1.04652821,  0.03467176,
         -2.07336654,  1.2628897 ,  0.61604732,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 3.86138405,  2.35068317, -1.90187438,  0.600788  ,  0.18011722,
          1.3469559 , -0.54708828,  1.83798823, -0.01957845,  2.88713217,
          3.1724991 ,  2.90802072,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.26785642,  0.51076756,  0.32070756,  2.33758816,  2.08146669,
         -0.60796736,  0.93777509,  2.70474711,  0.44785738,  1.61720609,
          1.52890594,  3.03072971,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 3.30219394,  3.1515445 ,  1.16550716,  2.07489374,  0.66441859,
          0.97529244,  0.35176367,  1.22593639, -1.80698271,  1.19936482,
          3.34017172,  2.15960657,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 2.34839018,  2.24827352, -1.61070856,  2.81044265, -1.21423372,
          0.24633846, -0.82196609,  2.28616568,  0.033922  ,  2.7557593 ,
          1.16178372,  3.66959512,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.32913219,  1.63231852,  0.58642744,  1.55873546,  0.86354741,
          2.06654246, -0.44036504,  3.22723595,  1.33279468,  0.05975892,
          2.48518999,  3.44690602,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 0.61424344, -1.03068819, -1.47929328,  2.91514641,  2.06867196,
          1.90384921, -0.45835234,  1.22054782,  0.67931536,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 2.76480464,  1.12442631, -2.36004758,  2.91912726,  1.67891181,
          3.76873596, -0.93874096, -0.32397781, -0.55732374,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.39953353, -1.26828104,  0.44482517,  2.85604975,  3.08891062,
          2.60268725, -0.15785176,  1.58549879, -0.32948578,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.65156484, -1.56545168, -1.42771206,  2.74216475,  1.8758154 ,
          3.51169147,  0.18353058, -0.14704149,  0.00442783,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.27736372,  0.37407608, -1.25713475,  0.53171176,  1.53714914,
          0.21015523, -1.06850669, -0.09755327, -0.92373834,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[-1.39160433,  0.21014669, -0.89792475,  2.6702794 ,  1.54610601,
          0.84699037,  2.96726482,  1.84236946,  0.02211578,  0.32842575,
          1.02718924,  1.78447936, -1.20056829,  2.26699318, -0.23156537,
          2.50124959,  1.93372501,  0.10264369, -1.70813962,  0.        ],
        [ 0.38823591, -1.30348049, -0.31599117,  2.60044143,  2.32929389,
          1.40348483,  3.25758736,  1.92210728, -0.34150988, -1.22336921,
          2.3567069 ,  1.75456835,  0.28295694,  0.68114898, -0.457843  ,
          1.83372069,  2.10177851, -0.26664178, -0.26549595,  0.        ],
        [ 0.08540346,  0.71507504,  1.78164285,  3.04418137,  1.52975256,
          3.55159169,  3.21396003,  3.22720346,  0.68147142,  0.12466013,
         -0.4122895 ,  1.97986653,  1.51671949,  2.06096825, -0.6765908 ,
          2.00145086,  1.73723014,  0.50186043, -2.27525744,  0.        ],
        [ 0.00632717,  0.3050794 , -0.33167875,  1.48109172,  0.19653696,
          1.97504239,  2.51595821,  1.74499313, -1.65198805, -1.04424953,
         -0.23786945,  1.18639347, -0.03568057,  3.82541131,  2.84039446,
          2.88325909,  1.79827675, -0.80230291,  0.08165052,  0.        ],
        [ 0.89980086,  0.34690991, -0.60806566,  1.69472308,  1.38043417,
          0.97139487,  0.21977176,  1.01340944, -1.69946943, -0.01775586,
         -0.35851919,  1.81115864,  1.15105661,  1.21410373,  1.50667558,
          1.70155313,  3.1410754 , -0.54806167, -0.51879299,  0.        ]])]

y___火车

[array([1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 1., 1., 2., 2., 1., 2., 0.,
        0., 0., 0.]),
 array([1., 1., 1., 1., 1., 2., 2., 1., 1., 2., 2., 1., 2., 2., 0., 0., 0.,
        0., 0., 0.]),
 array([1., 2., 1., 2., 1., 1., 2., 2., 1., 1., 2., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.]),
 array([2., 2., 1., 2., 2., 2., 1., 1., 2., 2., 2., 2., 0., 0., 0., 0., 0.,
        0., 0., 0.]),
 array([1., 2., 2., 2., 1., 1., 1., 1., 2., 2., 1., 2., 1., 1., 1., 1., 0.,
        0., 0., 0.]),
 array([2., 1., 2., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.]),
 array([1., 2., 1., 2., 2., 2., 1., 2., 2., 1., 1., 2., 1., 0., 0., 0., 0.,
        0., 0., 0.]),
 array([2., 2., 1., 2., 1., 1., 1., 2., 1., 2., 2., 2., 0., 0., 0., 0., 0.,
        0., 0., 0.]),
 array([2., 1., 1., 2., 2., 2., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.]),
 array([1., 1., 1., 2., 2., 2., 2., 2., 1., 1., 1., 2., 1., 2., 1., 2., 2.,
        1., 1., 0.])]

X_检验

[array([[ 0.74119496,  1.97273418,  1.76675805,  0.51484268,  1.39422086,
          2.97184667, -1.35274514,  2.08825434, -1.2521965 ,  1.11556387,
          0.19776789,  2.38259223, -0.57140597, -0.79010112,  0.17038974,
          1.28075761,  0.696398  ,  3.0920007 , -0.41138503,  0.        ],
        [-1.39081797,  0.41079718,  3.03698894, -2.07333633,  2.05575621,
          2.73222939, -0.98182787,  1.06741172, -1.36310914,  0.20174856,
          0.35323654,  2.70305775,  0.52549713, -0.7786237 ,  1.80857093,
          0.96830907, -0.23610863,  1.28160768,  0.7026651 ,  0.        ],
        [ 1.16357113,  0.43907935,  3.40158623, -0.73923043,  1.484668  ,
          1.52809569, -0.02347205,  1.65349967,  1.79635118, -0.46647772,
         -0.78400883,  0.82695404, -1.34932627, -0.3200281 ,  2.84417045,
          0.01534261,  0.10047148,  2.70769609, -1.42669461,  0.        ],
        [-1.05475682,  3.45578027,  1.58589338, -0.55515227,  2.13477478,
          1.86777473,  0.61550335,  1.05781415, -0.45297406, -0.04317595,
         -0.15255388,  0.74669395, -1.43621979,  1.06229278,  0.99792794,
          1.24391783, -1.86484584,  1.92802343,  0.56148011,  0.        ],
        [-0.0835337 ,  1.89593955,  1.65769335, -0.93622246,  1.05002869,
          1.49675624, -0.00821712,  1.71541053,  2.02408452,  0.59011484,
          0.72719784,  3.44801858, -0.00957537,  0.37176007,  1.93481168,
          2.23125062,  1.67910471,  2.80923862,  0.34516993,  0.        ]]),
 array([[ 0.40691415,  2.31873444, -0.83458005, -0.17018249, -0.39177831,
          1.90353251,  2.98241467,  0.32808584,  3.09429553,  2.27183083,
          3.09576659,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.6862473 ,  1.0690102 , -0.07415598, -0.09846767,  1.14562424,
          2.52211963,  1.71911351,  0.41879894,  1.62787544,  3.50533394,
          2.69963456,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 3.27824216,  2.25067953,  0.40017321, -1.36011162, -1.41010106,
          0.98956203,  2.30881584, -0.29496046,  2.29748247,  3.24940966,
          1.06431776,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 2.80167214,  3.88324559, -0.6984172 ,  0.81889567,  1.86945352,
          3.07554419,  3.10357189,  1.31426767,  0.28163147,  2.75559628,
          2.00866885,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.54574419,  1.00720596, -1.55418837,  0.70823839,  0.14715209,
          1.03747262,  0.82988672, -0.54006372,  1.4960777 ,  0.34578788,
          1.10558132,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]])]

y__试验

[array([1., 2., 2., 1., 2., 2., 1., 2., 1., 1., 1., 2., 1., 1., 2., 2., 1.,
        2., 1., 0.]),
 array([2., 2., 1., 1., 1., 2., 2., 1., 2., 2., 2., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])]

共有1个答案

柯栋
2023-03-14

第一个问题:你重塑后的X数据不是你所期望的。如果你看重塑后的第一个样本,它是:

array([[[-1.00612917,  1.47313952,  2.68021318,  1.54875809,
          0.98385996],
        [ 1.49465265,  0.60429106,  1.12396908, -0.24041602,
          1.77266187],
        [ 0.1961381 ,  1.28019637,  1.78803092,  2.05151245,
          0.93606708],
        [ 0.51554755,  0.        ,  0.        ,  0.        ,
          0.        ],
        [-0.97596563,  2.04536053,  0.88367922,  1.013342  ,
         -0.16605355],
        [ 3.02994344,  2.04080806, -0.25153046, -0.5964068 ,
          2.9607247 ],
        [-0.49722121,  0.02734492,  2.16949987,  2.77367066,
          0.15628842],
        [ 2.19823207,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.31546283,  3.27420503,  3.23550769, -0.63724013,
          0.89150128],
        [ 0.69774266,  2.76627308, -0.58408384, -0.45681779,
          1.98843041],
        [-0.31850477,  0.83729882,  0.45471165,  3.61974147,
         -1.45610756],
        [ 1.35217453,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 1.03329532,  1.97471646,  1.33949611,  1.22857243,
         -1.46890642],
        [ 1.74105506,  1.40969261,  0.52465603, -0.18895266,
          2.81025597],
        [ 2.64901037, -0.83415186,  0.76956826,  1.48730868,
         -0.16190164],
        [ 2.24389007,  0.        ,  0.        ,  0.        ,
          0.        ],
        [-1.0676654 ,  3.08429323,  1.7601179 ,  0.85448051,
          1.15537064],
        [ 2.82487842,  0.27891413,  0.57842569, -0.62392063,
          1.00343057],
        [ 1.15348843, -0.37650332,  3.37355345,  2.22285473,
          0.43444434],
        [ 0.15743873,  0.        ,  0.        ,  0.        ,
          0.        ]]])

所以实际上没有时间步被掩蔽,因为掩蔽层只掩蔽所有特征为0的时间步,所以上面的20个时间步没有被掩蔽,因为它们都不完全为0。

对于掩码层,为了确保成功地将掩码发送到输出层,您可以执行以下操作:

for i, l in enumerate(model.layers):
    print(f'layer {i}: {l}')
    print(f'has input mask: {l.input_mask}')
    print(f'has output mask: {l.output_mask}')

layer 0: <tensorflow.python.keras.layers.core.Masking object at 0x6417b7f60>
has input mask: None
has output mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
layer 1: <tensorflow.python.keras.layers.wrappers.Bidirectional object at 0x641e25cf8>
has input mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
has output mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
layer 2: <tensorflow.python.keras.layers.core.Dropout object at 0x641814128>
has input mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
has output mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
layer 3: <tensorflow.python.keras.layers.wrappers.TimeDistributed object at 0x6433b6ba8>
has input mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
has output mask: Tensor("time_distributed/Reshape_3:0", shape=(None, 20), dtype=bool)

因此,您可以看到最后一层还具有输出_掩码,这意味着成功地传播了掩码。您似乎误解了掩蔽在Keras中的工作原理,它实际上会生成一个掩蔽,这是一个布尔数组,掩蔽的形状是(无,时间步),因为在您的模型定义中,时间步维度始终保持不变,因此掩蔽将被传播到最后,而不会发生任何更改。然后,当Keras计算损失时(当然,当它计算梯度时),掩码值为False的时间步长将被忽略。屏蔽层不会更改输出值,当然,您的模型仍然会预测类0,它只会生成一个布尔数组,指示应跳过哪个时间步并将其传递到末尾(如果所有层都接受屏蔽)。

因此,您可以做的是更改模型定义的一行,如下所示,并使您的y_labels移位1,这意味着您当前的类:

0 -

1 -

2 -

# I would prefer softmax if doing classification
# here we only need to specify 2 classes
# and actually TimeDistributed can be thrown away (at least in recent Keras versions)
model.add(TimeDistributed(Dense(2, activation='softmax')))

你也可以在这里看到我的答案https://stackoverflow.com/a/59313862/11819266 了解如何在有/无掩蔽的情况下计算损耗。

 类似资料:
  • 我正在Keras培训一名LSTM: 每个单元的输入是一个已知的2048向量,不需要学习(如果您愿意,它们是输入句子中单词的ELMo嵌入)。因此,这里没有嵌入层。 由于输入序列具有可变长度,因此使用

  • 抱歉,如果这个问题已经提出,我已经做了深入的搜索,什么都没有。 现在,我知道: 会在左边用0填充我的价格,所以25的价格会导致00025 如果我想把它们垫在右边,结果是25000呢?如何仅使用String.format模式?

  • 问题内容: 抱歉,如果已经提出此问题,我已经进行了深入搜索,什么也没有。 现在,我知道: 会将我的价格填充到左侧的零,因此价格为25将得出 00025 如果我想将它们向右填充,结果是 25000 怎么办?我该如何 仅 使用 String.format 模式呢? 问题答案: 您可以使用: 我可以只使用格式模式吗? 就像方法一样使用。从这篇文章中,您将看到输出空间是硬编码的,因此必须使用。

  • 问题内容: 我对PostgreSQL相对较新,并且我知道如何在SQL Server中用左数零填充数字,但是我在PostgreSQL中努力解决这个问题。 我有一个数字列,其中最大位数为3,最小位数为1:如果是一位,它的左边有两个零,如果是两位,则有1,例如001、058、123。 在SQL Server中,我可以使用以下命令: 在PostgreSQL中不存在。任何帮助,将不胜感激。 问题答案: 您可

  • 问题内容: 在数组末尾加零的更Python方式是什么? 在我的实际用例中,实际上我想将数组填充到最接近的1024倍数。例如:1342 => 2048,3000 => 3072 问题答案: 使用mode可以满足您的需要,在这里我们可以传递一个元组作为第二个参数来告诉每个大小要填充多少个零,例如a将在左边填充 2个 零,在右边填充 3个 零: 给出为: 也可以通过将元组的元组作为填充宽度来填充2D n

  • 我正在尝试使用keras对图像进行二值分类。 我的CNN模型对训练数据进行了良好的训练(训练准确率约为90%,验证准确率约为93%)。但在培训期间,如果我将批量大小设置为15000,则得到图I输出,如果我将批量大小设置为50000,则得到图II输出。有人能告诉我怎么了吗?预测不应该取决于批量大小,对吗? 我用于预测的代码: 我的型号:-