이번 포스팅에서는 저번 포스팅에서 다루었던 CLIP 논문의 Experiment를 중심으로 포스팅하겠습니다. 특히 이번 게시글에서는 zero shot learning과 representation learning에 관해 포스팅할텐데요, 역시 잘못된 점이 있다면 댓글로 알려주시면 감사하겠습니다. 👀
CLIP의 전반적인 구조
우선 CLIP의 전반적인 구조는 다음과 같습니다. 등장 배경 및 자세한 원리를 알고싶으시다면 제 이전 게시글을 참고 부탁드립니다!
CLIP은 이미지와 텍스트 쌍을 input으로 부여하고, 이러한 가능한 쌍을 예측하도록 학습됩니다. 만약 실제 (이미지, 텍스트) 쌍이라면 이들의 코사인 유사도를 최대화 하고, 나머지 쌍들은 코사인 유사도를 최소화하는 방향으로 학습하는 것입니다. 이러한 과정은 multi-modal 임베딩 공간을 학습하게 됩니다. 그 과정에서 이미지 인코더는 ResNet과 ViT, 텍스트 인코더는 Transformer를 사용합니다.
본 고에서는 이러한 원리의 CLIP을 3가지 챕터로 나누어 Experiment 한 바를 기록하고 있습니다. 이 세가지는 다음과 같습니다.
1. Zero-shot transfer 2. Representation Learning 3. Robustness to Natural Distribution Shift |
저는 이 세가지 챕터 중에서 1, 2번은 그냥 단순한 원리와 결과만을 짚고 넘어가고, 3번째 챕터에 집중해 포스팅하도록 하겠습니다. 그럼 1번 제로샷 트렌스퍼부터 알아보겠습니다.
Zero-shot Transfer
우선 Zero-shot transfer에 들어가기 앞서 transfer learning과 zero-shot에 대한 이해가 선행되어야할 것 같아 관련 개념을 먼저 포스팅하고자 합니다. 아래 그림을 보시면 이해가 쉬우실 것 같습니다.
우선 Transfer learning(전이학습, 이전학습)은 특정 데이터로 이미 학습된 모델을 다른 태스크에 재사용하는 기법을 가리킵니다. 위 그림처럼 Task2를 수행하는 모델을 만들고자 할 때, Data 1 으로 Task1을 수행하도록 미리 학습된 모델을 이용하는 것입니다. 여기서 Task1을 업스트림(Upstream) 태스크, Task2를 다운스트림(Down stream) 태스크라고 합니다. 그리고 업스트림 태스크를 학습하는 파란색 과정을 Pre-train이라고 합니다. CLIP은 Pretrain 모델이었죠. 따라서 Pretrain은 결론적으로 다운스트림 태스크를 잘 수행하기 위한 전단계라고 생각할 수 있습니다.
그럼 다운스트림 태스크를 학습하는 과정은 무엇이라고 부를까요? 이 과정이 fine tuning, zero-shot learning, one-shot learning, few-shot learning 등으로 나뉘는 것입니다. 각각의 개념을 살펴보겠습니다.
- 파인튜닝(Fine tuning) : 다운스트림 태스크에 해당하는 데이터 전체(그림에서는 data2) 를 모두 사용
- 제로샷 러닝(Zero-shot learning) : 다운스트림 태스크의 데이터를 전혀 사용하지 않고 pretrain 모델로 다운스트림 태스크를 바로 수행
- 원샷 러닝(One-shot learning) : 다운스트림 태스크의 데이터를 한 건만 사용해 어떻게 수행되는지 참고한 뒤 바로 다운스트림 태스크 수행
- 퓨샷 러닝(Few-shot learning) : 다운스트림 태스크의 데이터를 몇 건만 사용하고, pre-train 모델을 몇개의 데이터에 맞게 업데
여기서 Zero-shot 이란 한번도 보지 못한 데이터에 대한 라벨을 예측하는 방식을 말합니다. 따라서 본 고에서는 한번도 보지 못한 데이터셋을 이용해 제로샷 성능을 측정합니다.
위 그림은 제로샷 러닝 과정을 도식화한 것입니다. 우선 plane, car 과 같은 라벨에서 'A photo of a 라벨' 의 문장을 만든 후, 토큰화합니다. 그리고 Text encoder(transformer) 를 이용해 텍스트 인코딩을 생성하고, 이미지도 이미지 인코딩을 생성합니다. 그 후, 한 이미지와 라벨의 정보가 담긴 텍스트 인코딩 사이 코사인 유사도를 구해, 가장 유사도가 큰 값을 라벨로 반환하는 형식입니다. 논문에서는 텍스트 인코딩 정보가 weight 를 부여하고 있다고 칭하며, hypernetwork의 기능을 하고 있다고 합니다.
대략적인 코드는 다음과 같습니다.
import os
import clip
import torch
from torchvision.datasets import CIFAR100
# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)
# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
# Calculate features
with torch.no_grad():
image_features = model.encode_image(image_input)
text_features = model.encode_text(text_inputs)
# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)
# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
이러한 CLIP을 이용한 제로샷 러닝의 방식은 뛰어난 성과를 가져왔습니다. 다음 그래프를 보시죠.
위 좌측의 그래프와 같이 우선 CLIP으로 인한 제로샷 학습은 supervised baseline보다 27개의 데이터셋 중 16개의 데이터셋에서 우수한 성능을 보이고 있음을 알 수 있습니다. 또한 우측의 그래프는 제로샷 학습이 퓨 샷 학습보다 더 나은 성능을 보임을 나타냅니다.
Representation Learning
다음은 Representation learning 입니다. Representation learning(표현학습)은 이미지의 특성을 가장 잘 나타낼 수 있는 특성을 최대한 잘 뽑아 이를 down stream task에 이용하겠다는 방법론인데요, CLIP도 이러한 표현학습의 방법론으로 사용될 수 있습니다.
이러한 Representation learning의 성능은 representation 된 feature를 선형모델에 넣은 성능으로 평가하는데요, 본 고에서도 이러한 방법을 이용했습니다. 특히 선형모델을 선택한 이유는 flexibility 가 제한되어있기 때문입니다. 이런 과정을 거친 후 state of arts 모델과 비교했습니다. 선형모델로 representation learning 의 성능을 비교하는 방법은 다음 code를 보시면 이해가 쉽습니다.
import os
import clip
import torch
import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm
# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)
# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)
def get_features(dataset):
all_features = []
all_labels = []
with torch.no_grad():
for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
features = model.encode_image(images.to(device))
all_features.append(features)
all_labels.append(labels)
return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()
# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)
# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)
# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(np.float)) * 100.
print(f"Accuracy = {accuracy:.3f}")
get_features 라는 함수로 features를 추출하고, 로지스틱 선형모델을 사용하여 성능을 측정하고 있는 모습입니다. 이러한 결과를 다른 모델과 비교해보면 다음과 같았다고 합니다.
위 그래프를 보시다시피 CLIP-ViT 가 가장 높은 성능을 보이고 있음을 확인할 수 있습니다. CLIP이 representation learning 에서도 매우 우수한 성능을 보이고 있음을 알 수 있습니다
지금까지 제로샷러닝, 표현학습에서의 CLIP의 성능에 대해 알아보았습니다. 다음 포스팅에서는 Domain Generalization에서의 CLIP의 사용에 관해 포스팅하겠습니다.
Reference
[0] https://github.com/openai/CLIP/blob/main/notebooks/Interacting_with_CLIP.ipynb
[1] https://ratsgo.github.io/nlpbook/docs/introduction/transfer/
'Computer Vision💖 > Vision + Language' 카테고리의 다른 글
[Multimodal] 멀티모달 러닝 (Multimodal Learning)에 대한 아주 기초적인 이해 (1) | 2024.01.18 |
---|---|
[VQA] Zero-shot VQA + Domain Adaptation VQA 분야 개괄 (0) | 2023.08.01 |
[XAI] Generating Visual Explanations(2016) - 이미지 분류에 대한 설명을 생성하는 알고리즘 (0) | 2021.08.15 |
[XAI] OpenAI CLIP 논문 리뷰[3] - Domain Generalization (2) | 2021.07.19 |
[XAI] OpenAI CLIP 논문 리뷰[1] - 전반적인 아키텍처 (1) | 2021.07.15 |