오늘은 Multitask Learning(MTL) 분야의 논문인 GradNorm(2018, ICML)에 대해 리뷰해보려고합니다. 다소 오래 전에 발표되었지만 아직까지 MTL 분야에서 밴치마크 성능으로 자주 등장하는 논문입니다. 제가 이해한 바를 정리한 것이니 잘못된 점은 댓글 부탁드립니다! :D
1. Introduction
# MTL과 Task Balancing
우리가 흔히 알고있는 딥러닝의 학습 체계는 Single-task learning(STL)이라고 할 수 있습니다. Multi-task learning의 방식은 딥러닝의 등장 전부터 존재했지만, 딥러닝이 등장하며 Single-task learning 학습체계는 Computer Vision 분야에서 인간을 능가하는 퍼포먼스를 보여주고 있습니다.
하지만 웨어러블 디바이스나, 로봇, 드론 등에는 여러가지 다양한 센서의 task를 동시에 효율적으로 수행해야 합니다. 따라서 이러한 상황에서 하나의 모델이 weight를 share하고, 하나의 forward pass로 multiple inferences를 생성하는 아래와 같은 형식의 Multi-task learning(MTL)이 점차 다시 각광받게 됩니다. 이러한 MTL 모델은 보다 robust하고 결과적으로 더 나은 성능을 보여준다는 장점이 있었습니다.
하지만 이러한 MTL의 단점은 Training이 어렵다는 것입니다. 여러개의 Imbalanced한 Task 들을 balancing 하는 것이 어렵기 때문이죠. 또한 shared feature는 모든 task에 유용하도록 robust하게 수렴되어야 합니다.
따라서 MTL에서 "Task Balancing"은 중요한 키워드가 되고, 이러한 여러개의 task 들을 balancing 하기 위해 여러 테크닉(Long & Wang, 2015) 등이 고안되게 됩니다.
GradNorm도 이러한 Balancing에 포커싱을 맞춘 paper이지만, 과거의 related works와 달리 backpropagated 되는 gradient 들의 Imbalance에 주목합니다. Gradient의 크기를 Multitask loss function의 튜닝으로 완화시켜 task 간의 imbalance를 막는 것이죠. 따라서 GradNorm의 주요 contribution은 다음과 같습니다.
# Contributions
- 직접적으로 gradient의 크기를 balancing 하도록 loss를 튜닝
- Grid search와 달리 단 하나의 hyperparameter로 작동한다 ($\alpha$)
- 서로 다른 task들의 gradient norms를 common scale에 놓아, 서로 상대적인 크기를 비교 가능하도록 함
- 서로 다른 task들이 비슷한 rates로 training 되도록 함
그럼 지금부터는 자세한 아키텍처에 대해 살펴보겠습니다.
2. Algorithms
# Preliminaries
알고리즘 소개에 앞서 간단한 notation과 preliminaries를 정리하고 넘어가겠습니다.
✔️ Loss Function
Multitask Learning에서는 주로 다음과 같이 각각 task의 loss에 weight를 매긴 $L(t) = \Sigma w_i(t)L_i(t)$ 이와 같은 형태의 loss를 사용합니다. (그냥 단순히 각각 loss에 가중치를 곱해 더한 값이며, 여기서 $i$는 각 task, $t$는 training time을 의미합니다.)
따라서 여기서 우리는 GradNorm을 통해 각 task loss의 weight인 $w_i(t)$의 function을 학습하는 것을 목표로 합니다. 이를 통해 우리의 목표인 서로 다른 task들의 학습을 similar rate로 조절할 수 있습니다. 아래 그림을 보시면 우측이 GN 적용 후이며, weight를 GradNorm Loss를 통해 학습하고 있음을 알 수 있습니다.
✔️ Notation
이제 GradNorm 알고리즘을 이해하기 위한 몇가지 Notation을 정리하겠습니다.
- $W$ : Full Network weights의 subset(부분집합)이라고 정의합니다. 주로 $W$는 마지막 shared feature layer의 weight에서 선택됩니다.
- $G_w^{(i)}(t)$ : 가중치가 매겨진 Singletask Loss $w_i(t)L_i(t)$의 gradient의 L2 norm 입니다. 수식은 다음과 같을 것입니다.
- $\bar{G}_W(t)$ : 특정 training time $t$에서 모든 task의 위 gradient norm의 average 값입니다. 이는 논문의 표현으로는 gradient 들의 common scale에 해당한다고 말합니다. 수식은 다음과 같습니다.
- $\tilde{L}_i(t)$ : 특정 training time $t$와 task $i$에서의 loss ratio 입니다. 이는 task $i$의 inverse training rate가 되며, training이 빠를 수록 이 값은 낮은 값을 갖게 됩니다. 역시 수식은 다음과 같습니다.
- $r_i(t)$ : 특정 task $i$의 상대 inverse trainig rate입니다. 이는 Gradient를 Balancing 하는데 사용하며, 이 값이 클수록 train을 encourage 할 수 있어야 할 것입니다.
이제 이러한 notation을 가지고 본격적으로 알고리즘을 정의해보겠습니다.
# Balancing Gradients with GradNorm
GradNorm은 앞서 말씀드린 각 task loss의 weight를 학습하기 위해 고안되었습니다. 따라서 GN loss를 새로 재정의하는데요, 이는 다음과 같습니다. GN paper에서 가장 중요한 수식이라고 할 수 있겠습니다. 그리고 이 loss는 오직 $w_i$에 대해서만 미분됩니다.
수식을 자세히 들여다보면 저희가 앞서 정의한 notation이 숨어있음을 알 수 있습니다. 기본적으로 MAE loss (L1 loss)이며, 아래와 같은 목표를 가진 loss라고 할 수 있습니다. 이를 해석해보면 아마 앞서 정의한 task별 gradient $G_w^{(i)}(t)$를 모든 task의 gradient의 average 값(즉, common training scale)과 inverse training rate를 곱한 값에 가까이 만들고 싶어하는 듯 합니다.
여기서 수식에 들어있는 $\alpha$는 GN loss의 유일한 하이퍼파라미터입니다. 만약 이 $\alpha$가 크다면, training rate balancing을 강하게 부과하는 효과가 나올 것입니다. ($\alpha$가 inverse training rate에 붙어있기 때문입니다) 따라서 다소 대칭적인 (symmetric) task들일수록, 이 하이퍼파라미터는 작게 설정해도 됩니다.
결론적으로 이러한 일련의 알고리즘을 아래와 같이 정리할 수 있겠습니다.