# Motivation
- Continual learning 시나리오에서 데이터가 들어올 때마다 기존의 network를 fine-tuning 하기에는 너무 비효율적이며, 전체를 new task를 위해 rerain 하게되면 기존의 task 학습 능력도 저하되는 문제가 발생
- 기존의 Regularization을 주는 Continual learning 방법들은(a) 전체 network를 previous task에 의존해 retrain해 original task와 새로운 task의 파라미터를 가깝게 위치하도록 만들어 주는 방법이었지만, 굳이 현재 task에 도움이 되는 파라미터까지 entire retrain할 필요는 없다.
- 따라서 DEN(c)은 selectively하게 retrain하고, 기존의 학습으로 해결할 수 없는 새로운 task는 network의 capacity를 선택적으로 늘리는 방향으로 dynamic하게 학습을 진행한다.
- 여기서 DEN의 과제는 (1) growing capacity에 대한 효율성 잡기 / (2) 언제 얼마나 network를 expand할지 결정하기 / (3) Continual learning의 과제인 semantic drift와 catastrophic forgetting을 막기 정도로 나뉠 수 있겠다.
# Methodology : Dynamically Expandable Network
- Online continual learning setting에 맞게 current time t 시점의 데이터만 access 할 수 있으며, 과거 t-1 시점까지 데이터는 접근이 제한된다. (하지만 과거 모델 파라미터 접근은 가능함) 따라서 여기서 목표는 아래와 같이 $W^t$(L개의 layer들의 파라미터가 들어있는 weight tensor)의 모델 파라미터를 과거 t-1시점까지 학습된 모델 파라미터들을 이용해 learning 하는 것이다. (뒤에 붙어있는 람다항은 regularization 항이다)
- 크게 네트워크는 (1) Selective retraining / (2) Dynamic network expansion / (3) Network split/duplication으로 나뉜다. (그림 순서대로) 목표는 과거 학습으로 해결하지 못하는 새로운 태스크를 만났을 때 network capacity를 dynamically하게 확장하면서, 과거의 유용한 정보는 최대로 이용하는 것이다. 이러한 과정을 task 시점별로 반복한다고 보면 된다.
[1] Selective Retraining
- 전체 데이터를 retraining 하는 것보다 selective retraining하는 것이 효율적이다. 따라서 selective retraining할 subnetwork에 초점을 맞춘다.
- 일단 t=1일 때는 l1-regularization을 사용해 sparse한 network를 만든다. (l1은 여기서 weight가 정확히 0으로 떨어지도록 유도해 현재 task에 중요한 weight를 판별하는 역할. sparse한 nework는 computation overhead를 막기 위함)
- 그리고 L-1 layer까지의 파라미터들을 모두 $W^{t-1}$로 고정시켜놓고, L-1의 hidden units들과 task t의 output unit $o_t$의 connection을 얻어 어떤 unit과 weight가 task t의 training 과정에 영향을 주고 있는지 브루트포스(너비우선탐색, BFS)로 노드를 내려가며(L-1 ~.. 1) 확인한다. 그리고 output unit $o_t$와 직접적으로 connected 되어있지 않은(=weight가 0인) network는 버려버린다.
- 이러한 과정으로 관련이 있는 weights만 남겨 subnetwork $S$를 구성한다. 그리고 이렇게 남겨진 weight들을 $W_S^t$라고 하고, 이 weight만 부분적으로 learning한다.
[2] Dynamic Network Expansion
- 기존 학습된 내용으로 새로운 task를 학습하기 어려울 경우 expansion이 필요하다.
- 학습하기 어려울 경우? 는 loss $L$이 특정 threshold보다 높을 때로 정의한다. 따라서 이 expansion과정은 loss 조건문이 포함된다.
- 이러한 expansion을 항상 적용하는게 아니라 group sparsity regularization을 이용해 결정한다. 여기서 group은 각 뉴런의 incoming weight로 결정된다. 이러한 regularization 과정을 통해 모든 레이어에 더해졌던 k units 중 쓸데 없는 unit은 제거한다.
[3] Network split / duplication
- 전통적인 CL에서 semantic drift, catastrophic forgetting은 중요한 문제이며, 전통적으로 다음과 같은 규제로 해결해왔다. 과거 t-1시점까지 학습된 weight를 현재 t시점의 weight와 l2항으로 가깝게 만드는 것이다. (기존의 EWC 등의 방법)
- 하지만 위 방법은 task가 많아질 때 최적화가 어려울 수 있으므로, 뉴런을 split하는 방법을 제안한다. 이는 각 hidden unit에서 semantic drift가 일어난 양을 l2 distance로 계산해 $p_i^t$로 정의하고, 이러한 변화가 특정 threshold보다 크면 semantic drift가 일어났다고 가정하고 뉴런을 copy하는 것이다. copy된 후에는 전체 network가 다시 train된다.
반응형