torch.repeat_interleave(input, repeats, dim=None) → Tensor
功能:沿着指定的维度重复数组的元素
输入:
input
:指定的数组repeats
:每个元素重复的次数,可以是张量或者是数组dim
:指定的维度注意:
如果不指定dim
,则默认将输入数组扁平化(维数是1,因此这时repeats
必须是一个数,不能是数组),并且返回一个扁平化的输出数组
返回的数组与输入数组维数相同,并且除了给定的维度dim
,其他维度大小与输入数组相应维度大小相同
repeats
:如果传入数组,则必须是tensor
格式。并且只能是一维数组,数组长度与输入数组input
的dim
维度大小相同,输入数组的具体意义如下:
如果
r
e
p
e
a
t
s
=
[
n
1
,
n
2
,
…
,
n
m
]
,
则输出
[
x
1
,
x
1
,
…
,
x
1
,
x
2
,
x
2
,
…
,
x
m
]
其中,
x
1
重复
n
1
次,
x
2
重复
n
2
次,
x
m
重复
n
m
次
如果repeats=[n_1,n_2,\dots,n_m],则输出[x_1,x_1,\dots,x_1,x_2,x_2,\dots,x_m]\\其中,x_1重复n_1次,x_2重复n_2次,x_m重复n_m次
如果repeats=[n1,n2,…,nm],则输出[x1,x1,…,x1,x2,x2,…,xm]其中,x1重复n1次,x2重复n2次,xm重复nm次
一般用法
import torch
a=torch.arange(10).view(2,5)
b=torch.repeat_interleave(a,3,dim=0)
c=torch.repeat_interleave(a,3,dim=1)
print(a)
print(b)
print(c)
print(a.shape)
print(b.shape)
print(c.shape)
输出
# 原数组
tensor([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
# 沿第一维度重复后的数组
tensor([[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[5, 6, 7, 8, 9],
[5, 6, 7, 8, 9]])
# 沿第二维度重复后的数组
tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4],
[5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]])
# 原数组形状
torch.Size([2, 5])
# 沿第一维度重复后的形状
torch.Size([6, 5])
# 沿第二维度重复后的形状
torch.Size([2, 15])
当不指定dim
时
import torch
a=torch.arange(10).view(2,5)
b=torch.repeat_interleave(a,2)
print(a)
print(b)
输出
# 原数组
tensor([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
# 不指定dim时重复两次
tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9])
当repeats
为数组格式时
import torch
a=torch.arange(10).view(2,5)
b=torch.repeat_interleave(a,torch.tensor([2,3]),dim=0)
print(a)
print(b)
输出
# 原数组
tensor([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
# 第一行重复两次,第二行重复三次
tensor([[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[5, 6, 7, 8, 9],
[5, 6, 7, 8, 9]])
如果repeats
为数组,但是大小和输入的dim
大小不匹配,则会报错
import torch
a=torch.arange(10).view(2,5)
b=torch.repeat_interleave(a,torch.tensor([2,3]),dim=1)
print(a)
print(b)
输出报错,RuntimeError:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-18-c8f2c85e38df> in <module>
1 import torch
2 a=torch.arange(10).view(2,5)
----> 3 b=torch.repeat_interleave(a,torch.tensor([2,3]),dim=1)
4 print(a)
5 print(b)
RuntimeError: repeats must have the same size as input along dim
Tensor.repeat(*sizes) → Tensor
功能:沿每个维度重复张量数组
输入:
sizes
:沿每个维度重复此张量的次数注意:
sizes
长度必须大于等于被重复数组tensor
的维数(如果tensor
的维数是2,则sizes
就必须大于等于2)import torch
a=torch.arange(10).view(2,5)
b=a.repeat(2,3,2)
print(a)
print(b)
print(a.shape)
print(b.shape)
输出
# 原数组
tensor([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
# 重复后的数组
tensor([[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
[5, 6, 7, 8, 9, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
[5, 6, 7, 8, 9, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
[5, 6, 7, 8, 9, 5, 6, 7, 8, 9]],
[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
[5, 6, 7, 8, 9, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
[5, 6, 7, 8, 9, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
[5, 6, 7, 8, 9, 5, 6, 7, 8, 9]]])
# 原数组形状
torch.Size([2, 5])
# 重复后的数组形状
torch.Size([2, 6, 10])
如果sizes
长度小于tensor
的维数,则会报错
import torch
a=torch.arange(10).view(2,5)
b=a.repeat(2)
输出报错
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-33-19a278098c7c> in <module>
1 import torch
2 a=torch.arange(10).view(2,5)
----> 3 b=a.repeat(2)
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
两个函数方法最大的区别就是repeat_interleave
是一个元素一个元素地重复,而repeat
是一组元素一组元素地重复
torch.repeat_interleave():https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html#torch.repeat_interleave
torch.tensor.repeat():https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html#torch.Tensor.repeat
点个赞支持一下吧