LoRa 논문은 큰 언어모델을 fine-tuning하기 위한 여러가지 방법 중 효과가 꽤 괜찮은 방법을 제시했는데 그 내용을 정리하고자 한다.
1. Introduction
매우 큰 언어모델을 만들고 이를 도메인에 맞게 fine-tuning하는 일은 매우 빈번하다. GPT3 계열의 모델은 크기가 매우 크기 때문에 원래의 모델은 freeze 해두고 태스크별로 별도 layer를 두는 방식을 사용한다. 하지만 모델 크기가 매우 크기 때문에 제대로 학습이 안될 수도 있고 inference 속도가 충분하지 않을 수 있다.
논문에서 제시하는 아이디어는 간단하다. 딥러닝 layer에서 hidden dimension의 차원이 아무리 커도 그 공간의 rank는 낮을 수 있다고 생각하는데, 이를 활용하는 것이다.
We take inspiration from Li et al. (2018a); Aghajanyan et al. (2020) which show that the learned over-parametrized models in fact reside on a low intrinsic dimension. We hypothesize that the change in weights during model adaptation also has a low “intrinsic rank”, leading to our proposed Low-Rank Adaptation (LoRA) approach
이렇게 구현한 fine-tuning 방법은 학습 속도가 빠르며 inference 속도저하도 없다.
2. LoRa method
MLP는 행렬 곱 연산을 하는 많은 layer로 구성되어 있다. 이 layer들은 보통 full-rank인데, pre-trained LM의 경우 full-rank가 아닌 낮은 'instrisic dimension'을 가지고 있어 더 낮은 차원으로 project 할 수 있다고 한다. 논문 저자들은 이에 착안하여 MLP layer들도 domain adaptaion시 낮은 'instrisic dimension'을 가지고 있다고 가정하여 실험을 진행하였다.
어떤 weight matrix $W_{0} \in \mathbb{R}^{d \times k}$가 있을 때, matrix decomposition을 사용하면 위 아이디어를 $W_{0} + \Delta W = W_{0} + BA, \ (B \in \mathbb{R}^{d \times r}, B \in \mathbb{R}^{r \times k}, r \ll min(d,k))$와 같이 표현할 수 있다. 모델 학습시에는 $W_{0}$는 update하지 않으며 $A,B$만 trainable parameter이다. Forward pass를 예로 들면,
$$h = W_{0}x + \Delta Wx = W_{0}x + BAx$$
와 같이 수행된다. $A$는 random Gaussian으로 초기화하고 $B$는 0으로 초기화화여 학습 시작 시에는 $\Delta W = BA = 0$이 되도록 한다. 또한 $\Delta Wx$를 $\frac{\alpha}{r}$으로 scaling 해준다.
참고로 Transformer에는 self-attention에서의 wieght matrix 4개 ($W_{q}, W_{k}, W_{v}, W_{o}$)와 2개의 MLP layer가 있다. 연구에서는 LoRa를 attention weight에만 적용하여 실험하였다.
3. Experiments
우선 학습 속도와 메모리 사용량이 크게 개선되었다.
On GPT-3 175B, we reduce the VRAM consumption during training from 1.2TB to 350GB. With r = 4 and only the query and value projection matrices being adapted, the checkpoint size is reduced by roughly 10,000× (from 350GB to 35MB)
또한 GPT3 전체를 fine-tuning하는 것 보다 더 좋은 성능이 나온다고 한다.
개인적인 경험으로는 P-tuning과 속도 차이도 크게 없고, 태스크별로 성능이 더 좋은 경우들도 많다. 무엇보다 P-tuning처럼 학습데이터를 매번 변경해야하는 번거로움이 없어서 좋다.