1. torch.cat()
torch.cat()
기능 정의는 다음과 같다.
Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension)
기준 차원에 대해 두 텐서의 요소를 연결하며(list의 extend), 기준 차원을 제외한 shape은 동일해야 한다.
import torch
t1 = torch.rand((6, 32))
t2 = torch.rand((4, 32))
torch.cat((t1, t2), dim=0).shape
>>>> torch.Size([10, 32])
torch.cat((t1, t2), dim=1).shape
>>>> torch.Size([4, 64])
t3 = torch.rand((4, 31))
torch.stack((t2, t3), dim=0).shape
>>>> RuntimeError: stack expects each tensor to be equal size, but got [4, 32] at entry 0 and [4, 31] at entry 1
2. torch.stack()
torch.stack()
기능 정의는 다음과 같다.
Concatenates a sequence of tensors along a new dimension. All tensors need to be of the same size.
새로운 차원을 기준으로 두 텐서를 연결하며, 대상이 되는 텐서의 모양이 모두 같아야 한다.
t1 = torch.tensor([[1,2,3],[4,5,6]])
t2 = torch.tensor([[-1,-2,-3],[-4,-5, -6]])
print(t1, "\n", t2)
>>>> tensor([[1, 2, 3],
[4, 5, 6]])
tensor([[-1, -2, -3],
[-4, -5, -6]])
torch.stack([t1,t2], dim=0) # shape: [2, 2, 3]
>>>> tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[-1, -2, -3],
[-4, -5, -6]]])
torch.stack([t1,t2], dim=1) # shape: [2, 2, 3]
>>>> tensor([[[ 1, 2, 3],
[-1, -2, -3]],
[[ 4, 5, 6],
[-4, -5, -6]]])
728x90
'pytorch, tensorflow' 카테고리의 다른 글
[PyTorch] Embedding 추가하기 (0) | 2022.10.01 |
---|---|
[pytorch] torch.reshape에 관하여 (0) | 2022.07.11 |
[tensorflow 2.0] model.save, model.save_weights 차이 (0) | 2021.04.01 |
[pytorch] torch에서 parameter 접근하기 (0) | 2021.03.19 |