input:输入的张量(类型为Tensor)
repeats(类型:int或torch.Tensor):每个元素的重复次数。repeats参数会被广播来适应输入张量的维度
dim(类型:int)需要重复的维度。默认情况下,将把输入张量展平(flatten)为向量,然后将每个元素重复repeats次,并返回重复后的张量。
x = torch.arange(4).repeat_interleave(2)
print(x)
>>>
tensor([0, 0, 1, 1, 2, 2, 3, 3])
---------------------------------------
y = torch.tensor([[1, 2, 3], [4, 5, 6]])
torch.repeat_interleave(y, 2)
>>>
tensor([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6])
-------------------------------------------
torch.repeat_interleave(y, 3, 0)#每一行重复3次
>>>
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
[4, 5, 6],
[4, 5, 6]])
----------------------------------------
torch.repeat_interleave(y, 3, 1)#每一列中的每个元素重复3次
>>>
tensor([[1, 1, 1, 2, 2, 2, 3, 3, 3],
[4, 4, 4, 5, 5, 5, 6, 6, 6]])
---------------------------------------
torch.repeat_interleave(y, torch.tensor([3, 1]), dim=0)#指定元素重复不同次数
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3],
[4, 5, 6]])