[CV] 이미지들 사이의 관계를 T-SNE plot으로 나타내기

2023. 8. 27. 01:51·Computer Vision💖/Basic
반응형

오늘은 논문용 그림을 그리다가.. T-SNE plotting에 대해 나중에 또 쓸 일이 있을 것 같아서 간단히 정리해놓는다! 😎

여러 공모전 참여에서 얻은 overfitting의 상처로.. 연구를 처음 배울 때부터 나는 domain shift에 관심이 많았다. DG/DA paper들에서 꼭 보이는 plot이 T-SNE plot 인데, 데이터들 사이에 domain shift를 보여주기에 딱이다. 

이 포스팅에서는 ResNet18에서 얻은 feature들을 T-SNE으로 차원축소해, 이들이 얼마나 떨어져있는지 plotting 하는 방법을 다룬다. 


[1] Pretrained ResNet18 Setting

먼저 PyTorch 내장된 ResNet을 불러온다. ResNet18은 특히 Input size가 (224, 224)여야 하므로, normalization이 포함된 transformation step이 아래와 같이 필요함을 유의하자. 

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image

# Load the pretrained model
model = models.resnet18(pretrained=True)

# Use the model object to select the desired layer
layer = model._modules.get('avgpool')

# Set model to evaluation mode
model.eval()

# Image transforms
scaler = transforms.Scale((224, 224))   # resnet input 
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()

 

[2] Extract features from ResNet

두 데이터셋 image의 경로를 all 이라는 리스트에 담아주면 된다. 바로 PIL로 이미지를 열고, 위 pretrained ResNet에 통과시켜서 output feature를 적재한다. 

from tqdm import tqdm 
tu_features = []
for image_name in tqdm(all) : 
    img = Image.open(image_name)
    t_img = Variable(normalize(to_tensor(scaler(img))).unsqueeze(0))
    output = list(model(t_img).squeeze().detach().numpy())
    # np.concatenate(output, )
    tu_features.append(output)

 

[3] T-SNE train 

sklearn의 T-SNE을 이용해 2차원으로 위 feature space를 축소한다. 

from sklearn.manifold import TSNE
features = np.array(tu_features)
tsne = TSNE(n_components=2).fit_transform(features)    # tsne

 

[4] Plotting 

위에서 나온 T-SNE feature들을 scaling 후 scatterplot으로 plotting 한다. 

def scale_to_01_range(x):    # scaling 
    value_range = (np.max(x) - np.min(x))
    starts_from_zero = x - np.min(x)
    return starts_from_zero / value_range

tx = tsne[:, 0]
ty = tsne[:, 1]
tx = scale_to_01_range(tx)
ty = scale_to_01_range(ty)

plt.scatter(tx, ty, color='skyblue', label='CULane')
plt.legend()
plt.show()

그럼 아래와 같은 이미지를 얻는다! 

반응형
저작자표시 (새창열림)

'Computer Vision💖 > Basic' 카테고리의 다른 글

[CV] Hidden dimension이 너무 클 때 flatten 하지 말고 똑똑하게 layer 추가하기  (0) 2023.08.30
[CV] ResNet-18로 특정 Image의 feature 추출하기 (PyTorch)  (0) 2022.05.30
[CV] Adversarial Learning(적대적 학습)이란? + 응용  (0) 2022.04.24
[CV] Self-supervised learning(자기주도학습)과 Contrastive learning - 스스로 학습하는 알고리즘  (4) 2021.07.02
[CV] AlexNet(2012) 논문을 code로 구현 해보자 (Keras, PyTorch)  (0) 2021.06.25
'Computer Vision💖/Basic' 카테고리의 다른 글
  • [CV] Hidden dimension이 너무 클 때 flatten 하지 말고 똑똑하게 layer 추가하기
  • [CV] ResNet-18로 특정 Image의 feature 추출하기 (PyTorch)
  • [CV] Adversarial Learning(적대적 학습)이란? + 응용
  • [CV] Self-supervised learning(자기주도학습)과 Contrastive learning - 스스로 학습하는 알고리즘
당니이
당니이
씩씩하게 공부하기 📚💻
  • 당니이
    다은이의 컴퓨터 공부
    당니이
  • 전체
    오늘
    어제
    • 분류 전체보기 (136)
      • Achieved 👩🏻 (14)
        • 생각들 (2)
        • TIL (6)
        • Trial and Error (1)
        • Inspiration ✨ (0)
        • 미국 박사 준비 🎓 (1)
      • Computer Vision💖 (39)
        • Basic (9)
        • Video (5)
        • Continual Learning (7)
        • Generative model (2)
        • Domain (DA & DG) (5)
        • Multimodal (8)
        • Multitask Learning (1)
        • Segmentation (1)
        • Colorization (1)
      • RL 🤖 (1)
      • Autonomous Driving 🚙 (11)
        • Geometry (4)
        • LiDAR 3D Detection (1)
        • Trajectory prediction (2)
        • Lane Detection (1)
        • HDmap (3)
      • Linux (15)
      • PyTorch👩🏻‍💻 (10)
      • Linear Algebra (2)
      • Python (5)
      • NLP (10)
        • Article 📑 (1)
      • Algorithms 💻 (22)
        • Basic (8)
        • BAEKJOON (8)
        • Programmers (2)
      • ML (1)
        • 통계적 머신러닝(20-2) (1)
      • SQL (3)
      • 기초금융 💵 (1)
  • 블로그 메뉴

    • 홈
    • About me
  • 링크

    • 나의 소박한 github
    • Naver 블로그
  • 공지사항

  • 인기 글

  • 태그

    LLM
    리눅스
    Python
    CV
    dfs
    domain adaptation
    conda
    백준
    domain generalization
    알고리즘
    코딩테스트
    CL
    Incremental Learning
    continual learning
    자료구조
    NLP
    백트래킹
    pytorch
    til
    Linux
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
당니이
[CV] 이미지들 사이의 관계를 T-SNE plot으로 나타내기
상단으로

티스토리툴바