torch.flip(input,dim):第一个参数是输入,第二个参数是输入的第几维度,按照维度对输入进行翻转
import torch
x = torch.arange(16).view(2, 2, 2,2)
print('x=\n',x)
a = torch.flip(x, [2])
print('a=\n',a)
x=
tensor([[[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]]],
[[[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15]]]])
a=
tensor([[[[ 2, 3],
[ 0, 1]],
[[ 6, 7],
[ 4, 5]]],
[[[10, 11],
[ 8, 9]],
[[14, 15],
[12, 13]]]])
Process finished with exit code 0
import torch
x = torch.arange(16).view(2, 2, 2,2)
print('x=\n',x)
a = torch.flip(x, [3])
print('a=\n',a)
x=
tensor([[[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]]],
[[[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15]]]])
a=
tensor([[[[ 1, 0],
[ 3, 2]],
[[ 5, 4],
[ 7, 6]]],
[[[ 9, 8],
[11, 10]],
[[13, 12],
[15, 14]]]])
Process finished with exit code 0