将两个torch tensor拼接,并且维度的数量保持不变,比如原来是二维的两个tensor拼接完了之后仍然是二维的,不会变成三维的。
import torch
if __name__ == '__main__':
a = torch.tensor([[2,3,4],
[5,6,7]])
b = torch.tensor(list(range(3))*2).reshape(2,3)
c = torch.cat((a,b),dim=1)
print(a, b, c)
tensor([[2, 3, 4],
[5, 6, 7]]) tensor([[0, 1, 2],
[0, 1, 2]]) tensor([[2, 3, 4, 0, 1, 2],
[5, 6, 7, 0, 1, 2]])