官方说明:torch.cat
torch.cat(tensors, dim=0, *, out=None) → Tensor
连接给定维数的给定序列的序列张量。所有张量要么具有相同的形状(除了连接维度),要么为空。
import torch
# 总结:
# 1. torch.cat((x,y),dim=0) :张量 X,Y按照列堆起来
# 2. torch.cat((x,y),dim=1) :张量 X,Y按照行并排起来
x = torch.ones(3, 4)
y = torch.zeros(3, 4)
z = torch.cat((x, y), dim=0) # dim = 0 ;按列堆起来
m = torch.cat((x, y), dim=1) # dim = 1 :按行并排
print(f'x={x}')
print(f'y={y}')
print(f'z={z}')
print(f'm={m}')
x=tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
y=tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
z=tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
m=tensor([[1., 1., 1., 1., 0., 0., 0., 0.],
[1., 1., 1., 1., 0., 0., 0., 0.],
[1., 1., 1., 1., 0., 0., 0., 0.]])