[DG] Domain Generalization의 대표 알고리즘을 DomainBed로 알아보자 (+ Code)

2021. 8. 6. 23:10·Computer Vision💖/Domain (DA & DG)
반응형

오늘은 facebook에서 발표한 DomainBed를 정리하며 Domain Generalization의 대표 알고리즘을 알아보겠습니다.

우선 Domain Generalization에 대해 간단히 소개하자면, train과 test의 distribution shift (domain shift)를 완화할 수 있는 generalize된 기법을 고안하는 분야입니다. 대표적으로는 train과 test distribution 모두에서 invariant 한 feature를 추출하는 것에 집중하는 알고리즘 계열이 있습니다. 

DomainBed는 facebook에서 발표한 DG 분야의 benchmark dataset과 알고리즘 등을 모아 정리한 논문입니다. DomainBed에 수록되어있는 데이터셋과 알고리즘 등은 우선 다음과 같습니다. (참고로 DomainBed는 DG 알고리즘이 출시될 때마다 계속 업데이트됩니다) 

DomainBed에 현재 수록되어있는 벤치마크 dataset과 알고리즘

 

Dataset

위의 표에서 가로축은 8개의 서로 다른 벤치마크 dataset을 의미합니다. 총 8가지의 데이터셋은 다음과 같습니다. 이러한 8개의 대표 데이터셋을 이용해 기존의 DG 알고리즘들의 성능을 일관되게 측정한 후, 이를 모아 비교한 것이 DomainBed입니다. 

DomainBed의 8가지 벤치마크 데이터셋 

그럼 이제 DomainBed에 수록된 몇가지 Domain Generalization (DG) 알고리즘들을 pytorch code와 함께 알아보겠습니다. 전반적인 code는 DomainBed의 공식 github을 참고해 loss 부분만 사용하기 편하도록 제가 수정했습니다. 또한 loss부분만 첨부했으니, 자세한 알고리즘은 논문을 참고 부탁드립니다. 관련 논문도 함께 링크로 첨부할테니, 잘못된 부분이 있다면 댓글로 남겨주세요! 👀


Algorithms 

우선 알고리즘 소개에 들어가기 앞서 아래 알고리즘들은 공통된 code 구조를 공유하고 있습니다. 이는 다음과 같습니다.
(자세한 내용은 domainbed의 code 구조를 참고 부탁드립니다)

  • minibatch 단위로 domain 별 data를 가져옵니다.
  • self.featurizer(x) 는 image를 featurize 하는 부분으로, Convnet 기반의 network를 사용하면 됩니다. (ResNet 등)
  • self.classifier(x) 는 image feature를 통해 분류를 수행하는 부분으로, 단순 FC layer 등을 사용하면 됩니다. 
  • classifier는 다중분류기로 가정하고, loss는 이에 따라 cross_entropy로 가정합니다. 

그럼 이제 전반적인 DomainBed의 DG 알고리즘과 loss 부분 pytorch code를 알아보겠습니다. 

 

1) VReX

(Out-of-Distribution Generalization via Risk Extrapolation) 

특정 iteration 이상에서 loss에 lambda penalty를 부여하는 알고리즘입니다. Lambda의 default는 0.01로 설정되어 있으며, 특정 epoch은 50으로 설정되어 있습니다. loss를 설정하는 방식은 다음 code와 같습니다. 

def get_loss(self, minibatches):
        loss, cls_loss, vrex_loss = 0, 0, 0
        num_domains = len(minibatches)
        ## default hyperparam  
        PENALTY_ITERS =  50             # penalty anneal iters 
        LAMBDA =  0.01               # penalty weight (default)
        nll = 0.
        i = 0
        losses = torch.zeros(num_domains)

        if self.update_count >= PENALTY_ITERS:   
            penalty_weight = LAMBDA       
        else:
            penalty_weight = 1.0

        ## domain별 반복문 
        for minibatch in minibatches:
            x, y, ti, tt, l = minibatch
            image_features = self.featurizer(x)
            cls_outputs = self.classifier(image_features)
            nll = F.cross_entropy(cls_outputs, y)
            cls_loss += nll
            losses[i] = nll
            i += 1
        mean = losses.mean()
        penalty = ((losses - mean) **2).mean()  # penalty 설정

        vrex_loss = mean + penalty_weight * penalty 
        cls_loss /= num_domains
        loss = vrex_loss + cls_loss
        return OrderedDict({'loss': loss, 'cls_loss':cls_loss, 'vrex_loss': vrex_loss})

 

2) GroupDRO 

(Robust ERM minimizes the error at the worst minibatch)

GroupDRO는 domain 별 loss에 따른 가중치를 부여한 register_buffer (학습되지 않는 layer)를 만든 후 최종 loss matrix에서 product하는 기법을 사용합니다. 여기서 q는 register_buffer를 뜻합니다. 

def get_loss(self, minibatches):
			loss, dro_loss = 0, 0
      device = 'cuda' if minibatches[0][0].is_cuda else 'cpu'
      GROUPDRO_ETA = 0.01   # default    
      if not len(self.q):
          self.q = torch.ones(num_domains).to(device)

      losses = torch.zeros(num_domains).to(device)
      for m in range(len(minibatches)) :
          x, y, ti, tt, l = minibatches[m]
          image_features = self.featurizer(x)
          cls_outputs = self.classifier(image_features)
          losses[m] = F.cross_entropy(cls_outputs, y)     # 4개의 도메인이 모두 담긴 losses들 
          # 도메인 loss에 따른 가중치 q
	self.q[m] *= (GROUPDRO_ETA * losses[m].data).exp()       

      self.q /= self.q.sum()    # q는 정규화 해줌
      dro_loss = torch.dot(losses, self.q)    
			# 그리고 domain별로 나온 둘을 product 해서 전체 loss로 삼음 
      loss = dro_loss
      return OrderedDict({'loss': loss, 'dro_loss': dro_loss})

 

 

 


References 


[0] https://arxiv.org/pdf/1911.08731.pdf
[1] https://arxiv.org/pdf/2003.00688.pdf

 

반응형
저작자표시 (새창열림)

'Computer Vision💖 > Domain (DA & DG)' 카테고리의 다른 글

[Daily] CLIMB: CLustering-based Iterative Data Mixture Bootstrapping for Language Model Pre-training  (1) 2025.04.19
[CV] Self-training에 대한 간단한 설명 - 가짜 라벨을 학습에 이용하기  (0) 2022.09.02
[CV] Test-Time Domain Adaptation의 의미와 간단 정리  (0) 2022.05.08
[DG] Deep CORAL(CORelation ALignment, 2016) 논문리뷰  (0) 2021.08.27
'Computer Vision💖/Domain (DA & DG)' 카테고리의 다른 글
  • [Daily] CLIMB: CLustering-based Iterative Data Mixture Bootstrapping for Language Model Pre-training
  • [CV] Self-training에 대한 간단한 설명 - 가짜 라벨을 학습에 이용하기
  • [CV] Test-Time Domain Adaptation의 의미와 간단 정리
  • [DG] Deep CORAL(CORelation ALignment, 2016) 논문리뷰
당니이
당니이
씩씩하게 공부하기 📚💻
  • 당니이
    다은이의 컴퓨터 공부
    당니이
  • 전체
    오늘
    어제
    • 분류 전체보기 (136)
      • Achieved 👩🏻 (14)
        • 생각들 (2)
        • TIL (6)
        • Trial and Error (1)
        • Inspiration ✨ (0)
        • 미국 박사 준비 🎓 (1)
      • Computer Vision💖 (39)
        • Basic (9)
        • Video (5)
        • Continual Learning (7)
        • Generative model (2)
        • Domain (DA & DG) (5)
        • Multimodal (8)
        • Multitask Learning (1)
        • Segmentation (1)
        • Colorization (1)
      • RL 🤖 (1)
      • Autonomous Driving 🚙 (11)
        • Geometry (4)
        • LiDAR 3D Detection (1)
        • Trajectory prediction (2)
        • Lane Detection (1)
        • HDmap (3)
      • Linux (15)
      • PyTorch👩🏻‍💻 (10)
      • Linear Algebra (2)
      • Python (5)
      • NLP (10)
        • Article 📑 (1)
      • Algorithms 💻 (22)
        • Basic (8)
        • BAEKJOON (8)
        • Programmers (2)
      • ML (1)
        • 통계적 머신러닝(20-2) (1)
      • SQL (3)
      • 기초금융 💵 (1)
  • 블로그 메뉴

    • 홈
    • About me
  • 링크

    • 나의 소박한 github
    • Naver 블로그
  • 공지사항

  • 인기 글

  • 태그

    CV
    Python
    Linux
    conda
    dfs
    pytorch
    domain adaptation
    NLP
    백준
    CL
    continual learning
    리눅스
    알고리즘
    domain generalization
    til
    자료구조
    코딩테스트
    LLM
    Incremental Learning
    백트래킹
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
당니이
[DG] Domain Generalization의 대표 알고리즘을 DomainBed로 알아보자 (+ Code)
상단으로

티스토리툴바