[CV] Adversarial Learning(적대적 학습)이란? + 응용

2022. 4. 24. 13:43·Computer Vision💖/Basic
반응형

오늘은 PointAugment paper를 읽다가 main 원리로 나온 Adversarial learning에 대해 포스팅해보려고 한다. 주로 적대적 학습이라고 하는데, 그냥 내가 이해한 내용을 간단히 정리해보려고 한다. (미팅준비로 인해... 자세한 포스팅은... 미뤄두겠다.) 

 

# Idea

우선 적대적 학습에서, 적대적(adversarial)이란 '서로 대립관계에 있는' 이라는 뜻이다. 흔히 들어봤을 GAN의 속 의미가 Generative adversarial networks 인데, 여기 들어가는 adversarial과 비슷한 의미라고 이해할 수 있다. GAN은 흔히 두개의 네트워크가 경쟁하며 학습하는 모델이라고 잘 알려져 있는데, Discriminator를 잘 속이기 위한 데이터를 Generator가 얼마나 잘 만들 것인지가 중요한 문제이다. 

여기서 중요한 부분은 '잘 속이기 위한'이다. 오늘 소개할 적대적 학습의 방식도 신경망을 '잘 속이기 위한' 방법으로 알려져있기 때문이다. 우리는 오늘 분류기를 속이기 위한 방법을 알아볼 것이다. 

분류기를 속이는 방법이란 그렇게 어렵지 않다.  아래 두 그림을 보자. 

출처 : Reference[0]

왼쪽은 연두와 빨강을 아주 잘 분류한 분류기이다. 이 분류기를 오분류를 위해 조작하려면, 오른쪽처럼 빨간색 point를 결정경계에서 아주 살짝 이동해주면 될 것이다. 하지만 이렇게 어떻게 "살짝" 움직일 것인가는 어려운 문제이다. 왜냐면 우리가 학습시켜야 할 NN은 위 2차원 그래프 처럼 작은 차원의 데이터도 아닐뿐더러, 아주 복잡한 합성곱의 연속이기 때문이다. 조작해야할 weight의 값이 수천개나 될 수 있다. 

따라서 우리는 네트워크의 가중치를 직접 수정하는 것보다, 신경망에 입력하는 이미지를 조정하는 방법으로 적대적 학습(Adversarial network)을 구현할 수 있다. Input image에서 속임수를 쓰는 것이다. 

 

 

# Process 

우선 흔히 Adversarial Network는 다음과 같은 프로세스로 구현된다. 아래 프로세스를 이해하면, paper에서 응용되는 적대적 학습 개념을 이해하기에는 어려움이 없을 것이다. 핵심은 Real image와 조금 변형된 image를 네트워크에 제공해, 변형된 image가 Real image와 가까워지도록 학습한다는 것이다. 

출처 : Reference[0]

위 프로세스를 절차로 설명해보면 다음과 같다. 

  1. 원본 Image와 조금 변형된 이미지를 네트워크에 제공한다. (여기서 변형된 이미지란 픽셀의 일부분을 변형시킬수도 있고, augmentation을 한 이미지일 수도 있다) 
  2. 1번에서 제공한 Image를 네트워크가 예측하도록 하고, Real Image와 얼마나 차이나는지 확인한 후 역전파를 이용해 Real Image에 대한 output과 가까워지도록 만든다. 
  3. 1, 2를 반복한다. 

즉, 변형된 Image에도 잘 대응할 수 있도록 네트워크의 파라미터(가중치)를 조정하는 과정으로 이해하면 될 것 같다. 

 

 

# Application 

그렇다면 이러한 적대적 학습 방식이 어떻게 쓰이는지 간단한 응용을 소개하고 포스팅을 마치려고 한다. 사실 내가 요즘 읽고 있는 PointAugment: an Auto-Augmentation Framework for Point Cloud Classification (CVPR 2020, Oral) 라는 paper이다. 이 paper의 아키텍처를 보면 위 개념적인 내용의 이해가 쉬울 것 같아 가져왔다!

Pointaugment 아키텍처

이 paper는 pointcloud를 augmentation하는 augmentor를 학습시키는 것을 목표로 하는 paper이다. (무려 CVPR에서 Oral까지 받은 아주 아주 잘 써진 paper인 것 같다! 개인적으로 재밌게 읽었다는..!) 따라서 위 아키텍처를 보면 Classifier에 augmented된 sample과 원래의 origin sample이 각각 들어가는 형태임을 알 수 있다. 여기서 augmented된 sample이란 앞서 설명한 변형된 input image에 해당한다고 이해하면 될 것 같다. 

위 Pointaugment의 loss function도 어떻게 augmented된 sample과 원래의 sample의 거리를 좁혀 학습할 것인가?의 문제를 풀기 위해 설정된다. 간단히 소개하면 다음과 같다. 

여기서 L(P')는 augmented sample의 분류기 결과이고, L(P)는 원래의 sample의 분류기 결과라고 보면된다. ('가 붙은 것이 augmented sample이다.) 특히 Classifier loss에서 두 sample의 feature의 유클리디안 거리를 좁히는 부분이 들어가있음을 알 수 있다. 

 


References 

[0] https://medium.com/@jongdae.lim/%EA%B8%B0%EA%B3%84-%ED%95%99%EC%8A%B5-machine-learning-%EB%A8%B8%EC%8B%A0-%EB%9F%AC%EB%8B%9D-%EC%9D%80-%EC%A6%90%EA%B2%81%EB%8B%A4-part-8-d9507cf20352

반응형
저작자표시 (새창열림)

'Computer Vision💖 > Basic' 카테고리의 다른 글

[CV] 이미지들 사이의 관계를 T-SNE plot으로 나타내기  (0) 2023.08.27
[CV] ResNet-18로 특정 Image의 feature 추출하기 (PyTorch)  (0) 2022.05.30
[CV] Self-supervised learning(자기주도학습)과 Contrastive learning - 스스로 학습하는 알고리즘  (4) 2021.07.02
[CV] AlexNet(2012) 논문을 code로 구현 해보자 (Keras, PyTorch)  (0) 2021.06.25
[CV] AlexNet(2012)의 구조와 논문 리뷰  (0) 2021.06.23
'Computer Vision💖/Basic' 카테고리의 다른 글
  • [CV] 이미지들 사이의 관계를 T-SNE plot으로 나타내기
  • [CV] ResNet-18로 특정 Image의 feature 추출하기 (PyTorch)
  • [CV] Self-supervised learning(자기주도학습)과 Contrastive learning - 스스로 학습하는 알고리즘
  • [CV] AlexNet(2012) 논문을 code로 구현 해보자 (Keras, PyTorch)
당니이
당니이
씩씩하게 공부하기 📚💻
  • 당니이
    다은이의 컴퓨터 공부
    당니이
  • 전체
    오늘
    어제
    • 분류 전체보기 (136)
      • Achieved 👩🏻 (14)
        • 생각들 (2)
        • TIL (6)
        • Trial and Error (1)
        • Inspiration ✨ (0)
        • 미국 박사 준비 🎓 (1)
      • Computer Vision💖 (39)
        • Basic (9)
        • Video (5)
        • Continual Learning (7)
        • Generative model (2)
        • Domain (DA & DG) (5)
        • Multimodal (8)
        • Multitask Learning (1)
        • Segmentation (1)
        • Colorization (1)
      • RL 🤖 (1)
      • Autonomous Driving 🚙 (11)
        • Geometry (4)
        • LiDAR 3D Detection (1)
        • Trajectory prediction (2)
        • Lane Detection (1)
        • HDmap (3)
      • Linux (15)
      • PyTorch👩🏻‍💻 (10)
      • Linear Algebra (2)
      • Python (5)
      • NLP (10)
        • Article 📑 (1)
      • Algorithms 💻 (22)
        • Basic (8)
        • BAEKJOON (8)
        • Programmers (2)
      • ML (1)
        • 통계적 머신러닝(20-2) (1)
      • SQL (3)
      • 기초금융 💵 (1)
  • 블로그 메뉴

    • 홈
    • About me
  • 링크

    • 나의 소박한 github
    • Naver 블로그
  • 공지사항

  • 인기 글

  • 태그

    pytorch
    자료구조
    Linux
    알고리즘
    리눅스
    백트래킹
    continual learning
    CL
    NLP
    domain generalization
    dfs
    Python
    conda
    백준
    CV
    til
    LLM
    domain adaptation
    코딩테스트
    Incremental Learning
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
당니이
[CV] Adversarial Learning(적대적 학습)이란? + 응용
상단으로

티스토리툴바