[Incremental Learning] Rehearsal-based 방법론을 훑어보자(ER-MIR, OCS)

2023. 3. 8. 01:35·Computer Vision💖/Continual Learning
반응형

0. Overview 

  • Replay-based 방법은 과거의 샘플들을 replay buffer나 generative model에 저장해놓고, current task 학습에 사용하는 방법이다. 이러한 과거 샘플들을 이용해 retraining을 하기도 하고, 현재 학습의 constraints로 사용하기도 한다. 목표는 여전히 classifier $f$의 파라미터 $\theta$를 학습하는 것이다. 
  • Online learning에서는 task가 주어지지 않고, input data의 iid도 보장되지 않는다. (=single-pass through the data) 
  • 아래 방법론들은 랜덤이 아니라 샘플링을 "잘"해야한다고 주장하지만, 그 샘플링의 단위는 각각 다르니 유의해야할 것 같다. 

 

1. ER-MIR 

Online Continual Learning with Maximally Interfered Retrieval (NeurIPS'19)

# Motivation 

  • replay buffer나 generative model을 가정할 때, 과거의 데이터 중 어떤 샘플을 replay 해야할지 결정하는 것은 중요하다. 
  • 따라서 이 논문에서는 M(memory buffer)에서 랜덤하게 샘플을 추출하는 것이 아니라, loss를 증가시키는(=maximally interfered) samples를 추출해 사용해야한다고 주장한다. 

 

# Methodology : Maximally Interfered Retrieval 

  • 크게 (1) Replay Memory 에서의 방법과 (2) Generative Model 에서의 방법으로 나뉜다. Main idea는 과거의 샘플에서 랜덤하게 뽑지 말고, new incoming sample들에게 maximally interfered 되는 샘플을 Memory buffer $M$에서 뽑자는 것이다. 

CIL 상황을 가정할 때, 기존의 방법(랜덤추출, 좌)과 MIR 방법 비교

[1] MIR from a Replay Memory 

  • $M$에서 top-k values를 뽑을 때 다음과 같은 score를 구해 $M$ 안에 있는 previous data를 선택한다. 여기서 $l(f_{\theta^*}(x), y))$는 해당 sample에 대해 best loss를 따로 저장해 놓은 값이다. 

 

 

[2] MIR from a Generative Model 

  • 파라미터 추정 전 후 loss의 차이가 maximize 되는 data points 들을 찾는게 목표이다. 아래 수식에서 $\theta^v$는 before parameter를, $\theta^'$는 after parameter를 나타낸다고 이해했다. 

 


2. A-GEM

Efficient Lifelog Learning with A-GEM (ICLR'19)
  • Continual Learning을 위한 (1) Learning protocol / (2) New metrics / (3) A-GEM (GEM의 발전된 버전) 을 제시한다. 

 

[1] Learning protocol 

  • 기존의 방법론처럼 single pass가 아니라, 데이터의 ordered sequences를 two streams로 나누어 학습한다. 먼저 $D^{CV}$는 model hypr-parameter selection을 위한 cross-validation용 데이터들이고(기존 CL 모델들은 초모수에 민감한 경우가 많았음), $D^{EV}$는 actual training/testing에 쓰이는 데이터셋이다. 

 

[2] Metrics : Learning Curve Area (LCA) 

  • 기존의 CL 분야에서 쓰이는 대표 metrics인 Avrage Accuracy(A)와 Forgetting Measure(F) 말고 Learning Curve Area(LCA)를 새롭게 제시한다. 

  • Average Accuracy : 마지막 task까지 train이 끝난 모델을 모든 task에 걸쳐서 test하고 average를 낸다.

  • Forgetting measure: "forgetting" $f_j^k$에 대해 현재 task와 과거 task들 사이의 accuracy 차이로 정의하고, average를 구한다. 이 Forgetting measure는 모델이 새로운 task를 얼마나 빨리 배우는지에 대한 지표가 될 수 있다. 

Forgetting 정의
Forgetting measure

  • Learning Curve Area (LCA) : 얼마나 모델이 학습을 빠르게 하는지를 측정. 높으면 learning을 빠르게 하는 것이다. 

 

[3] A-GEM (Averaged Gradient Episodic Memory) 

  • GEM(Gradient Episodic Memory) : Task 별로 memory budget을 나누기 때문에 iid와 task boundaries가 필요함. 그리고 memory내 과거의 모든 task $k$에 대해 loss 계산이 필요하다. 이는 다음과 같이 나타낼 수 있다. 

  • 하지만 이는 computational cost 측면에서 너무 intensive하다. 따라서 이러한 computational burden을 해결하기 위해 A-GEM 제안. 따라서 기존의 GEM과 달리 "average episodic memory loss"를 이용해 previous task가 증가되지 않도록 함. 
     


3. GSS

Gradient based sample selection for online continual learning (NeurIPS'19)

# Motivation 

  • Rehearsal/ Retraining based에서 old dataset은 current learning의 constraints를 제공하는데 사용될 수 있다. 
  • 하지만 기존의 replay buffer 방법(iCaRL, GEM)은 각 task 별로 buffer의 memory를 allocate 하기 때문에, task boundaries를 요구한다. 그리고 iid assumtion을 가정한다. 하지만 이는 현재 상황에서 available 하지 않으므로, 이를 해결하려고 한다. 

# Methodology : CL == Constrained Optimization 

  • Replay buffer를 구성하는 sample selection 문제를 constraint reduction problem으로 formulate한 논문. 목표는 original constraints의 feasibel region을 가장 잘 approximate할 수 있는 constraints의 fixed subset을 선택하는 것이다. 그리고 이는 replay buffer 내 샘플들의 diversity를 maximize하는 것과 같다. 

  • 목표는 current examples의 loss를 과거의 학습된 examples의 loss를 증가시키지 않고 최적화 하는 것. 따라서 original constraints는 다음과 같고, 아래는 이 constraints를 gradient space에서 나타낸 것이다. 

  • 하지만 위 original constraints에서 많은 수의 constraints는 과거 training sample이 많아짐에 따라 linear하게 증가하지만, 우리의 replay buffer $M$의 memory는 한정되어 있다. 따라서 여기서 중요해지는 문제는, "어떻게 replay buffer에 들어갈 데이터를 general한 setting으로 구성할 것인가?"가 된다. 


 

4. OCS

Online Coreset Selection For Rehearsal-Based Continual Learning (ICLR'22)

 

# Motivation 

  • 기존의 rehearsal-based 방법론들은 replay buffer에 들어갈 샘플을 random하게 선택한다. 하지만 realworld dataset은 imbalanced, noisy하기에 이렇게 랜덤으로 추출하다가는 (1) current task의 학습을 막고 (2) 이전의 task 학습도 잘 잊게하는 (=catastrophic forgetting) 대참사가 일어날 수 있다. 

realworld dataset example

  • Data point 하나하나가 CL 학습에 중요하다는 가설은 아래로 증명된다. Task 1 (MNIST에서 학습된 모델)을 Task 2 (CIFAR-10) 환경에서 single data point로 업데이트 할 때, per-class accuracy와 average forgetting 성능이 하나의 data point의 업데이트 만으로도 천차만별로 달라짐을 알 수 있다.

  • 따라서 이를 해결하기 위해 class-imbalanced와 noisy instance에 robust한 coreset을 고르는 Online Coreset Selection(OCS) 방법을 제안한다. 이러한 selection에는 3가지의 gradient-based similarity 기준이 들어가며, 이렇게 골라진 배치별 top-k selected data instance들로 학습을 진행한다. 이 selection 과정은 모델 업데이트 전에 발생한다. 

Overall architecture

  • 위 기존의 방법들은 data streams에 도착하는 데이터 자체를 필터링하지 않는다. 그리고 보통 buffer에 저장되는 previous data들도 random으로 선정된다. 

 

# Methodology : OCS 

  • 목표는 전체 데이터에서 training에 사용할 coreset을 고르자는 것이고, 이러한 과정을 통해 memory buffer 또한 coreset으로 구성되게 된다. (배치단위로 작동) 따라서 결과적으로 previous task에도 affinity하고, current task-adaptation에도 도움이 되는 것이 목표이다. 
  • 따라서 3가지의 selection criterion을 통해 gradient similarity를 maximize하는 방향의 coreset을 선정하는 것이 목표이다. 아래와 같이 공식으로 나타내진다. 

 

  • 특정 배치 내에서 datapoint를 선택하는 3가지의 selection criterion은 다음과 같다. 기본적으로 cosine 유사도를 사용한다. 특히 Minibatch similarity와 Sample diversity는 current task adaptation을 위한 것이고, Coreset affinity는 previous task의 catastropic forgetting을 방지하기 위한 것이라고 이해했다. 
  • Minibatch similarity : 하나의 datapoint가 현재의 task를 얼마나 잘 describe 하고 있는지를 판단할 수 있다. (특히, 선택된 examples 들이 largest minibatch similarity를 갖는다면 이 task instances 들의 variance는 낮다고 판단 가능) 

  • Sample diversity : 하나의 datapoint가 같은 배치 안에 있는 다른 데이터셋과 얼마나 비슷한지를 판단한다. 이는 negative similarity로, 단순히 average similarity를 구하는 것이 아니라 하나의 datapoint와 배치 내 다른 datapoint들의 dissimilarity를 계산한다. 따라서 sample이 다양하게 구성되도록 함. 

  • Coreset affinity : 선택된 coreset이 과거의 previous tasks의 샘플과 얼마나 유사한지를 판단한다. 

 

  • 따라서 위 3가지의 조건을 모두 조합해 most beneficial training instances($u^*$) 들을 뽑아내는 공식은 다음과 같다. 이렇게 뽑아낸 dataset에서 current task의 coreset도 추출하므로, coreset 또한 중요한 데이터만을 가지게 될 것이다. 

 

  • 위 과정들을 알고리즘으로 정리해보자. coreset이 buffer 역할을 한다고 이해했다. 

 

반응형
저작자표시 (새창열림)

'Computer Vision💖 > Continual Learning' 카테고리의 다른 글

[Incremental Learning] Hybrid-based 방법론을 훑어보자(RPS-Net ,FRCL)  (0) 2023.03.08
[Incremental Learning] Architecture-based 방법론을 짚어보자  (0) 2023.03.06
[Incremental Learning] Continual learning 갈래 짚어보기  (0) 2023.03.06
[Incremental Learning] Scalable and Order-robust Continual learning with Additive Parameter Decomposition 논문 리뷰  (0) 2023.03.06
[Incremental Learning] Lifelong Learning with Dynamically Expandable Networks(DEN) 논문 리뷰  (0) 2023.03.06
'Computer Vision💖/Continual Learning' 카테고리의 다른 글
  • [Incremental Learning] Hybrid-based 방법론을 훑어보자(RPS-Net ,FRCL)
  • [Incremental Learning] Architecture-based 방법론을 짚어보자
  • [Incremental Learning] Continual learning 갈래 짚어보기
  • [Incremental Learning] Scalable and Order-robust Continual learning with Additive Parameter Decomposition 논문 리뷰
당니이
당니이
씩씩하게 공부하기 📚💻
  • 당니이
    다은이의 컴퓨터 공부
    당니이
  • 전체
    오늘
    어제
    • 분류 전체보기 (136)
      • Achieved 👩🏻 (14)
        • 생각들 (2)
        • TIL (6)
        • Trial and Error (1)
        • Inspiration ✨ (0)
        • 미국 박사 준비 🎓 (1)
      • Computer Vision💖 (39)
        • Basic (9)
        • Video (5)
        • Continual Learning (7)
        • Generative model (2)
        • Domain (DA & DG) (5)
        • Multimodal (8)
        • Multitask Learning (1)
        • Segmentation (1)
        • Colorization (1)
      • RL 🤖 (1)
      • Autonomous Driving 🚙 (11)
        • Geometry (4)
        • LiDAR 3D Detection (1)
        • Trajectory prediction (2)
        • Lane Detection (1)
        • HDmap (3)
      • Linux (15)
      • PyTorch👩🏻‍💻 (10)
      • Linear Algebra (2)
      • Python (5)
      • NLP (10)
        • Article 📑 (1)
      • Algorithms 💻 (22)
        • Basic (8)
        • BAEKJOON (8)
        • Programmers (2)
      • ML (1)
        • 통계적 머신러닝(20-2) (1)
      • SQL (3)
      • 기초금융 💵 (1)
  • 블로그 메뉴

    • 홈
    • About me
  • 링크

    • 나의 소박한 github
    • Naver 블로그
  • 공지사항

  • 인기 글

  • 태그

    dfs
    Python
    continual learning
    CV
    자료구조
    NLP
    CL
    til
    domain generalization
    백트래킹
    리눅스
    코딩테스트
    pytorch
    Linux
    백준
    LLM
    알고리즘
    conda
    Incremental Learning
    domain adaptation
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
당니이
[Incremental Learning] Rehearsal-based 방법론을 훑어보자(ER-MIR, OCS)
상단으로

티스토리툴바