오늘은 Domain Generalization의 대표 알고리즘인 Deep CORAL에 대해 간단히 리뷰해보도록 하겠습니다. 역시 잘못되거나 궁금한 부분은 댓글 부탁드립니다👀
Introduction
우리가 아는 대부분의 머신러닝 알고리즘들은 보통 IID (Independent & Identically distributed) 한 상황을 가정합니다. 하지만 아시다시피 이러한 IID 상황은 현실에서는 잘 보기 어렵습니다. 정말 이상적인 통계적인 상황을 가정한 것이기 때문이죠. 즉, 우리는 머신러닝 프로세스에서 보통 Domain shift를 겪습니다. 흔히 train data와 test data의 distribution이 달라 발생하는 기계학습의 한계인 것이죠.
따라서 Deep CORAL은 이러한 domain shift를 해결하는 Domain Adaptation 알고리즘의 일종입니다. 특히 Deep CORAL은 기존의 CORAL 방법론에 기초하고 있습니다. (자세한 기존 CORAL 방법론이 궁금하시면 Return of frustratingly easy domain adaptation(2016, AAAI)를 참고해주세요) 특히 기존 방법론의 논문 제목에 'frustratingly easy'가 들어가는만큼, CORAL은 정말 단순한 알고리즘입니다.
우선 CORAL은 CORelation ALignment 의 약자라고 할 수 있는데요, 말그래도 공분산을 정렬하여 domain shift를 최소화한다는 idea를 가지고 있습니다. 위 그림의 빨간색은 target domain data, 파랑색은 source domain data 라고 가정한다면, (a) 처럼 두 domain 사이에는 domain shift가 존재할 것입니다. 두 domain data는 zero mean으로 정규화 되었음에도 불구하고 다른 분포와 공분산을 보이고 있죠.
따라서 CORAL은 이러한 현상을 Target re-correlation으로 해결하려고 합니다. 이는 target domain의 correlation을 source domain의 feature에 추가하는 작업으로, 이러한 작업을 통해 우리는 (c) 처럼 source domain과 target domain의 분포가 잘 정렬되는 효과를 낼 수 있습니다. 그렇다면 확실히 classifier가 잘 작동할 수 있겠죠?
여기까지가 기존의 최초로 제안된 CORAL(2016, AAAI)에 대한 설명인데요, Deep CORAL 본 논문에서는 이러한 기존의 방법론이 linear transformation에 의존하며, end-to-end 모델이 아니라는 한계를 가진다고 지적합니다. 따라서 Deep CORAL에서는 non-linear transformation과 deep network를 통한 방법론을 제시하는데요, 그럼 이제부터 본격적으로 Deep CORAL 방법론에 대해 자세히 알아보겠습니다.
Architecture
Deep CORAL의 CNN을 이용한 전반적인 아키텍처는 다음과 같습니다. 특히 fc8 layer에 CORAL loss를 추가했는데요, 그럼 CORAL loss는 어떻게 구성되어 있는지 알아보겠습니다. 이 CORAL loss는 source, target domain의 feature의 공분산(covariances, second-order statistics) 의 distance를 다루는 부분이라고 할 수 있습니다. 이는 target domain에 더 잘 작용할 수 있도록 feature를 학습시키는 역할을 합니다. 수식으로 보면 다음과 같습니다.
여기서 Cs 와 Ct 는 각각 domain feature의 covariance matrix를 의미하며, Frobenius norm을 이용해 거리를 구하는 방법입니다. 이러한 coral loss를 input feature에 대해 gradient를 구하면 다음과 같은 결과를 얻을 수 있을 것입니다.
또한 기존 CORAL의 end-to-end 학습이 불가하다는 한계를 짚으며, 본 고에서는 CORAL loss를 활용한 end-to-end 학습을 제안합니다. 여기서 end-to-end learning이란 입력부터 출력까지 한번에 처리하는 네트워크를 뜻합니다.
특히 final deep feature는 classifier를 충분히 잘 학습시킬 만큼 discriminative 해야하며, source와 target domain에 invariant 해야할 것입니다. 따라서 이러한 장점을 모두 갖추기 위해 classification loss와 CORAL loss를 Joint training 합니다.
여기서 lambda는 CORAL loss에 대한 weight를 나타내는 하이퍼파라미터입니다.
Experiments
다른 알고리즘들과 비교했을 때, Deep CORAL은 꽤 준수한 성능을 보이고 있습니다. 또한 Domain Generalization 알고리즘 중에서도 꽤 준수한 성능을 보이고 있습니다.
지금까지 CORAL을 바탕으로 한 Deep CORAL에 대해 알아보았습니다☺️ 감사합니다!
References
[0] https://arxiv.org/pdf/1607.01719.pdf - Deep coral 논문
[1] https://arxiv.org/pdf/1511.05547.pdf - coral 논문
'Computer Vision💖 > Domain (DA & DG)' 카테고리의 다른 글
[CV] Self-training에 대한 간단한 설명 - 가짜 라벨을 학습에 이용하기 (0) | 2022.09.02 |
---|---|
[CV] Test-Time Domain Adaptation의 의미와 간단 정리 (0) | 2022.05.08 |
[DG] Domain Generalization의 대표 알고리즘을 DomainBed로 알아보자 (+ Code) (0) | 2021.08.06 |