当前位置: 首页 > 工具软件 > Interleave > 使用案例 >

torch.repeat_interleave()与tensor.repeat()——数组的重复

谷光誉
2023-12-01

torch.repeat_interleave()与 tensor.repeat()——数组的重复

torch.repeat_interleave()

torch.repeat_interleave(input, repeats, dim=None) → Tensor

功能:沿着指定的维度重复数组的元素

输入:

  • input:指定的数组
  • repeats:每个元素重复的次数,可以是张量或者是数组
  • dim:指定的维度

注意:

  • 如果不指定dim,则默认将输入数组扁平化(维数是1,因此这时repeats必须是一个数,不能是数组),并且返回一个扁平化的输出数组

  • 返回的数组与输入数组维数相同,并且除了给定的维度dim,其他维度大小与输入数组相应维度大小相同

  • repeats:如果传入数组,则必须是tensor格式。并且只能是一维数组,数组长度与输入数组inputdim维度大小相同,输入数组的具体意义如下:
    如果 r e p e a t s = [ n 1 , n 2 , … , n m ] , 则输出 [ x 1 , x 1 , … , x 1 , x 2 , x 2 , … , x m ] 其中, x 1 重复 n 1 次, x 2 重复 n 2 次, x m 重复 n m 次 如果repeats=[n_1,n_2,\dots,n_m],则输出[x_1,x_1,\dots,x_1,x_2,x_2,\dots,x_m]\\其中,x_1重复n_1次,x_2重复n_2次,x_m重复n_m次 如果repeats=[n1,n2,,nm],则输出[x1,x1,,x1,x2,x2,,xm]其中,x1重复n1次,x2重复n2次,xm重复nm

案例代码

一般用法

import torch
a=torch.arange(10).view(2,5)
b=torch.repeat_interleave(a,3,dim=0)
c=torch.repeat_interleave(a,3,dim=1)
print(a)
print(b)
print(c)
print(a.shape)
print(b.shape)
print(c.shape)

输出

# 原数组
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
# 沿第一维度重复后的数组
tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9],
        [5, 6, 7, 8, 9],
        [5, 6, 7, 8, 9]])
# 沿第二维度重复后的数组
tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4],
        [5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]])
# 原数组形状
torch.Size([2, 5])
# 沿第一维度重复后的形状
torch.Size([6, 5])
# 沿第二维度重复后的形状
torch.Size([2, 15])

当不指定dim

import torch
a=torch.arange(10).view(2,5)
b=torch.repeat_interleave(a,2)
print(a)
print(b)

输出

# 原数组
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
# 不指定dim时重复两次
tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9])

repeats为数组格式时

import torch
a=torch.arange(10).view(2,5)
b=torch.repeat_interleave(a,torch.tensor([2,3]),dim=0)
print(a)
print(b)

输出

# 原数组
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
# 第一行重复两次,第二行重复三次
tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9],
        [5, 6, 7, 8, 9],
        [5, 6, 7, 8, 9]])

如果repeats为数组,但是大小和输入的dim大小不匹配,则会报错

import torch
a=torch.arange(10).view(2,5)
b=torch.repeat_interleave(a,torch.tensor([2,3]),dim=1)
print(a)
print(b)

输出报错,RuntimeError:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-18-c8f2c85e38df> in <module>
      1 import torch
      2 a=torch.arange(10).view(2,5)
----> 3 b=torch.repeat_interleave(a,torch.tensor([2,3]),dim=1)
      4 print(a)
      5 print(b)

RuntimeError: repeats must have the same size as input along dim

torch.Tensor.repeat()

Tensor.repeat(*sizes) → Tensor

功能:沿每个维度重复张量数组

输入:

  • sizes:沿每个维度重复此张量的次数

注意:

  • sizes长度必须大于等于被重复数组tensor的维数(如果tensor的维数是2,则sizes就必须大于等于2)

代码案例

import torch
a=torch.arange(10).view(2,5)
b=a.repeat(2,3,2)
print(a)
print(b)
print(a.shape)
print(b.shape)

输出

# 原数组
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
# 重复后的数组
tensor([[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
         [5, 6, 7, 8, 9, 5, 6, 7, 8, 9],
         [0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
         [5, 6, 7, 8, 9, 5, 6, 7, 8, 9],
         [0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
         [5, 6, 7, 8, 9, 5, 6, 7, 8, 9]],

        [[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
         [5, 6, 7, 8, 9, 5, 6, 7, 8, 9],
         [0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
         [5, 6, 7, 8, 9, 5, 6, 7, 8, 9],
         [0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
         [5, 6, 7, 8, 9, 5, 6, 7, 8, 9]]])
# 原数组形状
torch.Size([2, 5])
# 重复后的数组形状
torch.Size([2, 6, 10])

如果sizes长度小于tensor的维数,则会报错

import torch
a=torch.arange(10).view(2,5)
b=a.repeat(2)

输出报错

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-33-19a278098c7c> in <module>
      1 import torch
      2 a=torch.arange(10).view(2,5)
----> 3 b=a.repeat(2)

RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor

区别

两个函数方法最大的区别就是repeat_interleave是一个元素一个元素地重复,而repeat是一组元素一组元素地重复

官方文档

torch.repeat_interleave():https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html#torch.repeat_interleave

torch.tensor.repeat():https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html#torch.Tensor.repeat

点个赞支持一下吧

 类似资料: