torch.reshape()
에 대해 궁금했던 부분들을 정리하고자 한다. Emil Bogomolov의 글을 참고하였다.
1. n-차원 텐서는 메모리 공간에서 어떻게 존재하는가?(Contiguous)
다수의 데이터에 접근할 때 데이터가 가까운 메모리 공간에 모여있으면 읽는 효율이 좋아진다. 때문에 배열 데이터도 연속된 공간안에 위치시키는게 좋다.
다차원 공간에서 배열 데이터를 저장하는 방법은 크게 row-major order와 column-major order가 있다(위키). C계열 언어는 row-major order를 따르며 이는 같은 행(row)에 데이터들이 연속된 메모리 공간에 위치해있는 것을 의미한다.
PyTorch도 row-major order로 다차원 데이터를 저장하는데 이를 contiguous라는 개념으로 지칭하고 있다.
Remark. By “contiguous” we understand “C contiguous”, i.e. the way arrays are stored in language C.
때문에 2차원 이상의 배열이라도 contiguous 하다면 메모리 공간상에서 중간에 빈 공간 없이 연속된 공간에 위치하게 된다.
2. torch.reshape의 데이터 copy 여부
reshape은 원본 데이터를 copy를 할 때도 있고, 안할 때도 있다.
Returns a tensor with the same data and number of elements as input, but with the specified shape. When possible, the returned tensor will be a view of input. Otherwise, it will be a copy. Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior.
1) view
view는 메모리 공간상에서 데이터의 변화나 copy 없이 다차원 배열을 다루기 위해 만들어졌다. 때문에 view 결과는 항상 contiguous하다.
t = torch.rand(4, 4)
b = t.view(2, 8)
t.is_contiguous()
>>>> True
t.storage().data_ptr() == b.storage().data_ptr()
>>>> True
2) reshape
reshape는 텐서가 contiguous하면 자기 자신을 반환하나 non-contiguous하다면 데이터를 copy한다. 참고로 데이터 copy는 연산이나 메모리 측면에서 비효율적이다.
tens_A = torch.rand((2,3,4)) # 3-dimensional tensor of shape (2,3,4)
transp_A = tens_A.t()
print(transp_A.is_contiguous())
>>>> False
# contiguous한 텐서로 다시 만들어주기
tens_B = transp_A.contiguous() # triggers copying because wasn't contiguous
print(tens_B.is_contiguous())
>>>> True
# Reshaping
resh_transp_A = transp_A.reshape(shape=(2,2,2,3)) # triggers copying
resh_tens_B = tens_B.reshape(shape=(2,2,2,3)) # won't trigger copying
3. tranpose시 tensor 데이터 변화
3차원 이상의 배열을 transpose(또는 permute)할 때 데이터의 변화가 눈에 잘 안그려지기 마련이다. 예시만 정리하고자 한다.
te = torch.tensor(range(24))
te.view(3,2,4)
>>>>
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23]]])
te.view(3,2,4).transpose(0,1)
>>>>
tensor([[[ 0, 1, 2, 3],
[ 8, 9, 10, 11],
[16, 17, 18, 19]],
[[ 4, 5, 6, 7],
[12, 13, 14, 15],
[20, 21, 22, 23]]])
te.reshape(3,2,4).transpose(0,1).reshape(-1)
>>>>
tensor([ 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 4, 5, 6, 7, 12, 13,
14, 15, 20, 21, 22, 23])
'pytorch, tensorflow' 카테고리의 다른 글
[PyTorch] Embedding 추가하기 (0) | 2022.10.01 |
---|---|
[PyTorch] torch.cat(), torch.stack() 비교 (0) | 2022.10.01 |
[tensorflow 2.0] model.save, model.save_weights 차이 (0) | 2021.04.01 |
[pytorch] torch에서 parameter 접근하기 (0) | 2021.03.19 |