当前位置: 首页 > 工具软件 > jSelect > 使用案例 >

pytorch 代码:yield features.index_select(0,j), labels.index_select(0,j)

吕向荣
2023-12-01

pytorch 代码:yield features.index_select(0,j), labels.index_select(0,j)

yield features.index_select(0,j), labels.index_select(0,j)

yield 首先作用理解为return,它也可以返回一个或多个值,要想调用返回值就必须在循环中

index_select() 中第一个参数 0 表示以行为标准选择,例如j = tensor([1,2]),结果为选取features 第1,第2行数据

例子

import torch

a = torch.randn(5, 5, dtype=torch.float32)
print(a)
i=1
indices = list(range(5))
j = torch.LongTensor(indices[i:min(i+2,5)])
print(j)
b = a.index_select(0,j)
print(b)

结果及其说明

# a 的值
tensor([[-1.2528, -0.3235,  0.2825, -0.5463,  0.0053],
        [-0.3129,  0.4375,  0.4789, -0.3872,  0.1995],
        [-0.9480,  0.3840,  1.1145,  0.8569, -0.8921],
        [ 1.0255,  0.0352, -0.0806, -1.2422,  0.4661],
        [-0.5092, -1.2760, -0.1923,  0.2986, -0.7680]])
# j 的值
tensor([1, 2])

# b 的值,可以看到是来自 a 的 1,2 行
tensor([[-0.3129,  0.4375,  0.4789, -0.3872,  0.1995],
        [-0.9480,  0.3840,  1.1145,  0.8569, -0.8921]])
 类似资料: