实际上这是一个max返回值认知不清的误区
我遇到该报错原因是,torch.max函数返回值包含两个:第一个是max的tensor,第二个是max值的位置index.返回的最大值和索引各是一个tensor,分别表示该维度的最大值,以及该维度最大值的索引,一起构成元组(Tensor, LongTensor)
torch.max().values,
可以获取最大值tensor
torch.max().indices
可以获取最大值索引tensor
报错代码及修改如下
fmean = torch.mean(feats, dim=1, keepdim=True)
#原代码:fmax = torch.max(feats, dim=1, keepdim=True)
fmax = torch.max(feats, dim=1, keepdim=True).values#修改后
fsum = torch.cat((fmean, fmax), dim=1)#报错行