Big Bird(내용은 여기 참고)의 핵심인 sparse attention은 GPU, TPU에서는 바로 적용하기 어렵습니다. Sliding window나 random element query 등으로 인해 때문에 parallel하게 attention을 적용하기 어렵기 때문인데, 본 연구에서는 'blockifying the lookups'이라는 방법으로 parallel attention을 구현했습니다.
Blockifying attention
핵심 아이디어는 block 단위의 attention입니다. query vector, key vector가 각각 12개씩 있다고 해보겠습니다. Block size가 2일때 query matrix를 12/2=6개의 block으로, key matrix도 12/2=6개의 block으로 구성합니다. 그리고 BigBird의 3개의 attention mechanism을 이 block 단위로 수행합니다. 이때 query block과 key block의 수는 동일해야 합니다.
i. Random Attention
각 query block이 임의로 선택된 $r$개의 key block과 attention 계산
ii. Window Local Attention
$j$index query block이 $[j - (w-1)/2, j+(w-1)/2]$ index의 key block들과 attention 계산
iii. Global Attention
Global block과 모든 block들과의 attention 계산
그런데 이를 따로따로 계산하게 된다면 gather 연산이 필요합니다. 이는 연산량을 늘려 비효율적인데 논문에서는 더 똑똑하게 이를 계산했습니다.
(0) Full Attention
query matrix, key matrix를 $Q,K \in R^{n \times d}$라 하고 $n$개의 token에 대해서 $Q_{i}=x_{i}W_{Q}, K_{i}=x_{i}W_{K}$라 하겠습니다. 그러면 full dense attention score(query, key 연산)은 아래 그림과 같이 표현될 수 있습니다.
(1) block attention
이제 block 단위 attention을 위해 reshape연산으로 $Q,K$를 각각 $[n/b] \times b \times d$ 텐서 $Q^{'}, K^{'}$로 만들어줍니다. 그리고 행렬 곱을 계산해주면
$$\begin{align} A_{jst} = \sum_{u}Q^{'}_{jsu}K^{'}_{jtu} \end{align}$$
tensor에서의 행렬곱 연산에 의해 $[n/b] \times b \times b$ 차원의 tensor $A$가 만들어지는데, 이는 axis=0 기준으로 $[n/b]$개 만큼 연산된 행렬곱 결과입니다. 즉 $Q,K$를 잘라서 $[n/b]$개의 block으로 만들고 각 block끼리 행렬곱을 한 것이며(block 단위 pairwise) 이는 [그림1] full attention score에서의 대각성분에 해당합니다.
(2) window attention
이 아이디어를 활용해 window attention도 구해보겠습니다. window attention은 $j$번째 query block이 $[j - (w-1)2, j+(w-1)/2]$범위의 key block들과 attention을 구하는 과정입니다. 이를 위해 reshape된 key tensor $K^{'} \in R^{[n/b] \times b \times d}$의 복사본(copy)을 $w$개 만큼 만듭니다. 그리고 이들을 각각 (★)roll 해주는데요, $j$번째 copy의 block들을 $j$개 만큼 밀어서 이동시켜줍니다(순환고리 처럼요). window 왼쪽에 있는 $[j-(w-1)/2, j-1]$개의 copy들의 block들은 각각 오른쪽으로 이동시키는데, 왼쪽에서 1번째 copy는 1칸씩, 왼쪽에서 2번째 copy는 2칸씩 이동시킵니다. 마찬가지로 window 오른쪽에 있는 $[j+1, j+(w-1)/2]$개의 copy들의 block들은 각각 왼쪽으로 이동시키는데 오른쪽에서 1번째 copy는 1칸씩 이동시킵니다. 그림으로 표현하면 아래와 같습니다. 그림으로 표현하면 아래와 같습니다.
좀 더 자세히 설명해보겠습니다. 1)block attention을 활용하면 full attention에서의 대각 성분들을 만들어 낼 수 있습니다. 그렇다면 key block 복사본을 만들고 각각 오른쪽 또는 왼쪽으로 roll해서 query block과 block attention을 계산하면 대각 성분으로만 채워진 tensor가 key block 복사본 수 만큼 나오며, 이 연산결과가 곧 window attention입니다.
[그림4]는 rolling된 key block copy들과 query block과의 attention 결과를 나타내고 있으며, [그림4] 하단 부분은 실제 attention score의 위치를 바꾸는 연산이 있는 것이 아니고 attention score block이 개념적으로 어느 곳에 위치하는 지를 나타낸 것입니다. 그리고 [그림4], [그림5]에서 모두 확인할 수 있듯 모서리 끝 쪽 attention block은 실제 window attention에 해당하지 않지만 rolling 결과물로 계산이 됩니다.
(3) random attention
random attention을 위해 $r$이라는 parameter를 두는데 각 query가 $r$개의 random key와 attention을 계산하게 하기 위한 parameter입니다. 이 과정은 random attention 계산 후 gather 연산을 통해 모아줍니다.
(4) global attention
global token에 해당하는 $g$개의 block들을 추가하여 attention 계산하는 것은 어렵지 않게 수행할 수 있습니다.
(5) dense attention 계산
위 과정의 핵심 목적은 sparse attention을 GPU/TPU에서 병렬 연산으로 효율적으로 계산하기 위해 dense attention 구조를 만드는 것입니다. dense attention을 위해 위 과정들을 compact tensor로 표현하면 다음과 같습니다.
(2),(3),(4)를 한꺼번에 모아 크기가 $[n/b] \times (g+w+r)b \times d$인 compact dense tensor $K^{''}$를 만들 수 있습니다. 즉, dense attention 계산은 $Q^{'}$(size: $[n/b] \times b \times d)$와 $K^{''}$(size:$[n/b] \times (g+w+r)b \times d$)의 행렬곱을 통해 attention score tensor(size:$[n/b] \times b \times (g+w+r)b$)를 만들게 됩니다(계산비용:$O(n(g+w+r)bd)$). 그리고 이 과정은 GPU/TPU에서 병렬 연산을 통해 빠르게 계산됩니다.
[그림7]에 대해서 조금 더 설명하겠습니다. 먼저 $Q1, K1$은 global token이기 때문에 모든 token들에 대해서 attention이 계산됩니다. 때문에 $Q1$에 대한 dense matrix를 gather하지 않아도 되기 때문에 그림7 왼쪽에서 $Q1$에 대해서 표현이 되어있지 않습니다. 그리고 그림7 왼쪽에서 파란색 block들은 $K1$을 제외한 key block들을 rolling하여 window attention을 적용한 것인데요, 앞에서도 언급했듯 $Q2, K6$와 $Q6, K2$는 window attention에 해당하지 않지만 dense attention 계산을 위해 어쩔 수 없이 계산되게 됩니다.
최종적으로는 attention score와 각 position에 해당하는 key vector들의 weighted sum을 통해 position별 output vector를 얻게 됩니다.
'논문 및 개념 정리' 카테고리의 다른 글
[2018] Deep contextualized word representations(ELMo) (0) | 2021.03.15 |
---|---|
[2017] Attention is All you Need (0) | 2021.03.15 |
[2019] Big Bird: Transformers for Longer Sequences (0) | 2021.02.15 |
Exploring Transfer Learning with T5 : the Text-To-Text Transfer Transformer (2) (0) | 2020.06.04 |
Exploring Transfer Learning with T5 : the Text-To-Text Transfer Transformer (1) (0) | 2020.06.01 |