当前位置: 首页 > 工具软件 > Interleave > 使用案例 >

【Pytorch】张量复制方法repeat、repeat_interleave和tile

嵇财
2023-12-01
import torch

一、repeat

以整个tensor作为基础元素进行复制操作。

1. 示例1:向量复制

x = torch.LongTensor(range(0,3))
print(x)
tensor([0, 1, 2])
print(x.repeat(2))
tensor([0, 1, 2, 0, 1, 2])
print(x.repeat(2,3)) # 0维上复制成2倍,1维上复制成3倍
tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2],
        [0, 1, 2, 0, 1, 2, 0, 1, 2]])
print(x.repeat(2,3,2)) # 0维上复制成2倍,1维上复制成3倍,2维上也复制成2倍
tensor([[[0, 1, 2, 0, 1, 2],
         [0, 1, 2, 0, 1, 2],
         [0, 1, 2, 0, 1, 2]],

        [[0, 1, 2, 0, 1, 2],
         [0, 1, 2, 0, 1, 2],
         [0, 1, 2, 0, 1, 2]]])

2. 示例2:矩阵复制

x = torch.LongTensor(range(0,6)).reshape(2,3)
print(x)
# print(x.repeat(2)) # 会报错,因为x是2维的矩阵
tensor([[0, 1, 2],
        [3, 4, 5]])
print(x.repeat(1,2)) # 0维复制成1倍,1维复制成2倍
tensor([[0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5]])
print(x.repeat(2,2)) # 0维复制成2倍,1维复制成2倍
tensor([[0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5],
        [0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5]])

二、repeat_interleave

以tensor中的元素作为基础进行复制操作

1. 示例1:向量复制

x = torch.LongTensor(range(0,3))
print(x)
tensor([0, 1, 2])
print(x.repeat_interleave(2))
# print(x.repeat_interleave(2,3)) # 会报错
tensor([0, 0, 1, 1, 2, 2])

2. 示例2:矩阵复制

x = torch.LongTensor(range(0,6)).reshape(2,3)
print(x)
tensor([[0, 1, 2],
        [3, 4, 5]])
print(x.repeat_interleave(2)) # 将矩阵拉平后再复制
tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5])
print(x.repeat_interleave(2, dim=0)) # 在0维上复制成2倍
tensor([[0, 1, 2],
        [0, 1, 2],
        [3, 4, 5],
        [3, 4, 5]])
print(x.repeat_interleave(2, dim=1)) # 在1维上复制成2倍
tensor([[0, 0, 1, 1, 2, 2],
        [3, 3, 4, 4, 5, 5]])
print(x.repeat_interleave(torch.tensor([1,2]), dim=0)) # 第0行元素不复制,第2行元素复制成2倍
tensor([[0, 1, 2],
        [3, 4, 5],
        [3, 4, 5]])
print(x.repeat_interleave(torch.tensor([3,2,1]), dim=1)) # 第0列元素复制成3倍,第1列元素复制成2倍,第2列元素不复制
tensor([[0, 0, 0, 1, 1, 2],
        [3, 3, 3, 4, 4, 5]])

三、tile

大多数场景中与repeat相同,但是能够处理复制维度参数小于输入维度的情况

x = torch.LongTensor(range(0,6)).reshape(2,3)
print(x)
tensor([[0, 1, 2],
        [3, 4, 5]])
print(x.tile(2)) # 与repeat不同,tile不会报错,而是将其转换为x.tile(1,2)
tensor([[0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5]])
print(x.tile(1,2)) # 0维复制成1倍,1维复制成2倍
tensor([[0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5]])
print(x.tile(2,2)) # 0维复制成2倍,1维复制成2倍
tensor([[0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5],
        [0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5]])
 类似资料: