
나를 일주일 째 괴롭혔던 ARM ... ARM은 기존 input으로 들어가는 x에 context net의 output을 채널에 대해 concat해서 문맥 정보를 추가해주는 형식의 알고리즘이다. 하지만 ARM을 구현하는 과정에서 내가 직면했던 에러는 다음과 같다.
RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[11, 6, 448, 448] to have 3 channels, but got 6 channels instead
ARM Input은 torchvision 내의 resnet50에 들어가는 형식이었는데, 알고리즘상 채널에 대해 concat을 하다보니 기존 resnet50이 받는 채널 3의 input과는 달라질 수 밖에 없었던 것이다.... 그래서 결론부터 말하자면 torchvision 내의 pretrain model인 resnet50의 input 차원을 변경해줘야하는 상황이었다.
저 에러로 검색을 해보니 가장 많이 나오는 방법이 아래와 같이 채널 수를 단순히 변경해주는 방법이었다.
net.conv1.in_channels = 4
## conv1의 채널을 4로 바꾼다는 뜻
이렇게 변경을 해 모델을 출력해보면 채널이 변경되어 있는 모습이었는데, 정작 code를 실행시키면 여전히 위의 runtime error가 났다. (이미 pretrain 된 모델이기 때문에 당연한 결과이다 )
그래서 구글링과 많은 실험끝에 해결한 방법은 바로 nn.Conv2d layer를 pretrain model 에 넣기 전에 추가해주는 것이다.
first_conv = nn.Conv2d(3, 1, kernel_size, stride, padding) # you could use e.g. a 1x1 kernel
model = pretrained_model()
x = # load data, should have the shape [batch_size, 3, height, width]
out = first_conv(x)
out = pretrained_model(out)이런식으로 이 전에 레이어를 추가해주니 문제없이 작동했다 ~!~!~!~!~! 도움을 준 파이토치 커뮤니티 한 외국인님께 감사를...
RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[3, 1, 224, 224] to have 3 channels, but got 1 channel
Thanks for your response. I am facing this issue in a different context. If I don’t want to change my original images and also don’t have the option to change the architecture (since I am using resnet18 from models), how can I do slicing? Is there an e
discuss.pytorch.org
+ 첨언
오늘 교수님과의 랩미팅으로 알게됐는데, 위 conv 레이어 추가는 아래 resnet50의 pretrain된 아키텍처를 망칠 수 있다. 그래서 위 방법은 아래 아키텍처를 손상시켜도 성능이나 분석에 문제가 없을 때 사용해야한다 .. ! 보통은 pretrain와 input channel 수가 다르면 다시 학습을 시키는 방법을 사용한다고 한다.
'PyTorch👩🏻💻' 카테고리의 다른 글
| [PyTorch] Multi-GPU 사용하기 (torch.distributed.launch) (0) | 2022.06.10 | 
|---|---|
| [TIL] OpenPCDet 가상환경 세팅하기 (cuda11.1 + spconv) (1) | 2022.06.10 | 
| [PyTorch] torch-sparse, torch-scatter, torch-geometric 패키지 install 하기 + 오류 해결 방법 (0) | 2022.04.30 | 
| [PyTorch] PyTorch Autograd 이젠 공부하자 - pytorch.autograd 총정리하기 (+code) (0) | 2022.02.16 | 
| [PyTorch] CUDA 11.2 + RTX3090에 맞는 torch version 세팅하기 (3) | 2022.01.30 | 
![[PyTorch] torchvision model들의 input channel 변경이 안될 때](https://img1.daumcdn.net/thumb/R750x0/?scode=mtistory2&fname=https%3A%2F%2Fblog.kakaocdn.net%2Fdna%2FcxjUmT%2FbtrAKwDqUMO%2FAAAAAAAAAAAAAAAAAAAAANMPLdU8q8hGaGb4tuQdSm_G9OkWam_ncft6OCK8ltKO%2Fimg.png%3Fcredential%3DyqXZFxpELC7KVnFOS48ylbz2pIh7yKj8%26expires%3D1761922799%26allow_ip%3D%26allow_referer%3D%26signature%3D6rHzzzlJCMMzkQvabt0%252F3UbdK0w%253D)