当前位置: 首页 > 工具软件 > Masked Label > 使用案例 >

[Pytorch函数] .masked_fill_()

施阳曜
2023-12-01

masked_fill_(mask, value)
掩码操作
用value填充tensor中与mask中值为1位置相对应的元素。mask的形状必须与要填充的tensor形状一致。

a = torch.randn(5,6)

x = [5,4,3,2,1]
mask = torch.zeros(5,6,dtype=torch.float)
for e_id, src_len in enumerate(x):
    mask[e_id, src_len:] = 1
mask = mask.to(device = 'cpu')
print(mask)
a.data.masked_fill_(mask.byte(),-float('inf'))
print(a)
----------------------------输出
tensor([[0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 1.]])
tensor([[-0.1053, -0.0352,  1.4759,  0.8849, -0.7233,    -inf],
        [-0.0529,  0.6663, -0.1082, -0.7243,    -inf,    -inf],
        [-0.0364, -1.0657,  0.8359,    -inf,    -inf,    -inf],
        [ 1.4160,  1.1594,    -inf,    -inf,    -inf,    -inf],
        [ 0.4163,    -inf,    -inf,    -inf,    -inf,    -inf]])

 类似资料: