오늘은 facebook에서 발표한 DomainBed를 정리하며 Domain Generalization의 대표 알고리즘을 알아보겠습니다.
우선 Domain Generalization에 대해 간단히 소개하자면, train과 test의 distribution shift (domain shift)를 완화할 수 있는 generalize된 기법을 고안하는 분야입니다. 대표적으로는 train과 test distribution 모두에서 invariant 한 feature를 추출하는 것에 집중하는 알고리즘 계열이 있습니다.
DomainBed는 facebook에서 발표한 DG 분야의 benchmark dataset과 알고리즘 등을 모아 정리한 논문입니다. DomainBed에 수록되어있는 데이터셋과 알고리즘 등은 우선 다음과 같습니다. (참고로 DomainBed는 DG 알고리즘이 출시될 때마다 계속 업데이트됩니다)
Dataset
위의 표에서 가로축은 8개의 서로 다른 벤치마크 dataset을 의미합니다. 총 8가지의 데이터셋은 다음과 같습니다. 이러한 8개의 대표 데이터셋을 이용해 기존의 DG 알고리즘들의 성능을 일관되게 측정한 후, 이를 모아 비교한 것이 DomainBed입니다.
그럼 이제 DomainBed에 수록된 몇가지 Domain Generalization (DG) 알고리즘들을 pytorch code와 함께 알아보겠습니다. 전반적인 code는 DomainBed의 공식 github을 참고해 loss 부분만 사용하기 편하도록 제가 수정했습니다. 또한 loss부분만 첨부했으니, 자세한 알고리즘은 논문을 참고 부탁드립니다. 관련 논문도 함께 링크로 첨부할테니, 잘못된 부분이 있다면 댓글로 남겨주세요! 👀
Algorithms
우선 알고리즘 소개에 들어가기 앞서 아래 알고리즘들은 공통된 code 구조를 공유하고 있습니다. 이는 다음과 같습니다.
(자세한 내용은 domainbed의 code 구조를 참고 부탁드립니다)
|
그럼 이제 전반적인 DomainBed의 DG 알고리즘과 loss 부분 pytorch code를 알아보겠습니다.
1) VReX
(Out-of-Distribution Generalization via Risk Extrapolation)
특정 iteration 이상에서 loss에 lambda penalty를 부여하는 알고리즘입니다. Lambda의 default는 0.01로 설정되어 있으며, 특정 epoch은 50으로 설정되어 있습니다. loss를 설정하는 방식은 다음 code와 같습니다.
def get_loss(self, minibatches):
loss, cls_loss, vrex_loss = 0, 0, 0
num_domains = len(minibatches)
## default hyperparam
PENALTY_ITERS = 50 # penalty anneal iters
LAMBDA = 0.01 # penalty weight (default)
nll = 0.
i = 0
losses = torch.zeros(num_domains)
if self.update_count >= PENALTY_ITERS:
penalty_weight = LAMBDA
else:
penalty_weight = 1.0
## domain별 반복문
for minibatch in minibatches:
x, y, ti, tt, l = minibatch
image_features = self.featurizer(x)
cls_outputs = self.classifier(image_features)
nll = F.cross_entropy(cls_outputs, y)
cls_loss += nll
losses[i] = nll
i += 1
mean = losses.mean()
penalty = ((losses - mean) **2).mean() # penalty 설정
vrex_loss = mean + penalty_weight * penalty
cls_loss /= num_domains
loss = vrex_loss + cls_loss
return OrderedDict({'loss': loss, 'cls_loss':cls_loss, 'vrex_loss': vrex_loss})
2) GroupDRO
(Robust ERM minimizes the error at the worst minibatch)
GroupDRO는 domain 별 loss에 따른 가중치를 부여한 register_buffer (학습되지 않는 layer)를 만든 후 최종 loss matrix에서 product하는 기법을 사용합니다. 여기서 q는 register_buffer를 뜻합니다.
def get_loss(self, minibatches):
loss, dro_loss = 0, 0
device = 'cuda' if minibatches[0][0].is_cuda else 'cpu'
GROUPDRO_ETA = 0.01 # default
if not len(self.q):
self.q = torch.ones(num_domains).to(device)
losses = torch.zeros(num_domains).to(device)
for m in range(len(minibatches)) :
x, y, ti, tt, l = minibatches[m]
image_features = self.featurizer(x)
cls_outputs = self.classifier(image_features)
losses[m] = F.cross_entropy(cls_outputs, y) # 4개의 도메인이 모두 담긴 losses들
# 도메인 loss에 따른 가중치 q
self.q[m] *= (GROUPDRO_ETA * losses[m].data).exp()
self.q /= self.q.sum() # q는 정규화 해줌
dro_loss = torch.dot(losses, self.q)
# 그리고 domain별로 나온 둘을 product 해서 전체 loss로 삼음
loss = dro_loss
return OrderedDict({'loss': loss, 'dro_loss': dro_loss})
References
[0] https://arxiv.org/pdf/1911.08731.pdf
[1] https://arxiv.org/pdf/2003.00688.pdf
'Computer Vision💖 > Domain (DA & DG)' 카테고리의 다른 글
[CV] Self-training에 대한 간단한 설명 - 가짜 라벨을 학습에 이용하기 (0) | 2022.09.02 |
---|---|
[CV] Test-Time Domain Adaptation의 의미와 간단 정리 (0) | 2022.05.08 |
[DG] Deep CORAL(CORelation ALignment, 2016) 논문리뷰 (0) | 2021.08.27 |