오늘은 논문용 그림을 그리다가.. T-SNE plotting에 대해 나중에 또 쓸 일이 있을 것 같아서 간단히 정리해놓는다! 😎
여러 공모전 참여에서 얻은 overfitting의 상처로.. 연구를 처음 배울 때부터 나는 domain shift에 관심이 많았다. DG/DA paper들에서 꼭 보이는 plot이 T-SNE plot 인데, 데이터들 사이에 domain shift를 보여주기에 딱이다.
이 포스팅에서는 ResNet18에서 얻은 feature들을 T-SNE으로 차원축소해, 이들이 얼마나 떨어져있는지 plotting 하는 방법을 다룬다.
[1] Pretrained ResNet18 Setting
먼저 PyTorch 내장된 ResNet을 불러온다. ResNet18은 특히 Input size가 (224, 224)여야 하므로, normalization이 포함된 transformation step이 아래와 같이 필요함을 유의하자.
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image
# Load the pretrained model
model = models.resnet18(pretrained=True)
# Use the model object to select the desired layer
layer = model._modules.get('avgpool')
# Set model to evaluation mode
model.eval()
# Image transforms
scaler = transforms.Scale((224, 224)) # resnet input
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()
[2] Extract features from ResNet
두 데이터셋 image의 경로를 all 이라는 리스트에 담아주면 된다. 바로 PIL로 이미지를 열고, 위 pretrained ResNet에 통과시켜서 output feature를 적재한다.
from tqdm import tqdm
tu_features = []
for image_name in tqdm(all) :
img = Image.open(image_name)
t_img = Variable(normalize(to_tensor(scaler(img))).unsqueeze(0))
output = list(model(t_img).squeeze().detach().numpy())
# np.concatenate(output, )
tu_features.append(output)
[3] T-SNE train
sklearn의 T-SNE을 이용해 2차원으로 위 feature space를 축소한다.
from sklearn.manifold import TSNE
features = np.array(tu_features)
tsne = TSNE(n_components=2).fit_transform(features) # tsne
[4] Plotting
위에서 나온 T-SNE feature들을 scaling 후 scatterplot으로 plotting 한다.
def scale_to_01_range(x): # scaling
value_range = (np.max(x) - np.min(x))
starts_from_zero = x - np.min(x)
return starts_from_zero / value_range
tx = tsne[:, 0]
ty = tsne[:, 1]
tx = scale_to_01_range(tx)
ty = scale_to_01_range(ty)
plt.scatter(tx, ty, color='skyblue', label='CULane')
plt.legend()
plt.show()
그럼 아래와 같은 이미지를 얻는다!
반응형
'Computer Vision💖 > Basic' 카테고리의 다른 글
[CV] Hidden dimension이 너무 클 때 flatten 하지 말고 똑똑하게 layer 추가하기 (0) | 2023.08.30 |
---|---|
[CV] ResNet-18로 특정 Image의 feature 추출하기 (PyTorch) (0) | 2022.05.30 |
[CV] Adversarial Learning(적대적 학습)이란? + 응용 (0) | 2022.04.24 |
[CV] Self-supervised learning(자기주도학습)과 Contrastive learning - 스스로 학습하는 알고리즘 (4) | 2021.07.02 |
[CV] AlexNet(2012) 논문을 code로 구현 해보자 (Keras, PyTorch) (0) | 2021.06.25 |