당니이
다은이의 컴퓨터 공부
당니이
전체 방문자
오늘
어제
  • 분류 전체보기 (140)
    • Achieved 👩🏻 (14)
      • 생각들 (2)
      • TIL (6)
      • Trial and Error (1)
      • Inspiration ✨ (0)
      • 미국 박사 준비 🎓 (1)
    • Computer Vision💖 (39)
      • Basic (9)
      • Video (5)
      • Continual Learning (7)
      • Generative model (2)
      • Domain (DA & DG) (5)
      • Multimodal (8)
      • Multitask Learning (1)
      • Segmentation (1)
      • Colorization (1)
    • RL 🤖 (4)
    • Autonomous Driving 🚙 (11)
      • Geometry (4)
      • LiDAR 3D Detection (1)
      • Trajectory prediction (2)
      • Lane Detection (1)
      • HDmap (3)
    • Linux (15)
    • PyTorch👩🏻‍💻 (10)
    • Linear Algebra (2)
    • Python (5)
    • NLP (11)
      • Article 📑 (1)
    • Algorithms 💻 (22)
      • Basic (8)
      • BAEKJOON (8)
      • Programmers (2)
    • ML (1)
      • 통계적 머신러닝(20-2) (1)
    • SQL (3)
    • 기초금융 💵 (1)

블로그 메뉴

  • 홈
  • About me

공지사항

인기 글

태그

  • pytorch
  • continual learning
  • Linux
  • Incremental Learning
  • CL
  • NLP
  • CV
  • LLM
  • domain adaptation
  • conda
  • 리눅스
  • domain generalization
  • dfs
  • 백준
  • 자료구조
  • 코딩테스트
  • 알고리즘
  • Python
  • til
  • 백트래킹

최근 댓글

최근 글

티스토리

hELLO · Designed By 정상우.
당니이

다은이의 컴퓨터 공부

[PyTorch] torchvision model들의 input channel 변경이 안될 때
PyTorch👩🏻‍💻

[PyTorch] torchvision model들의 input channel 변경이 안될 때

2021. 8. 8. 23:47
반응형

나를 일주일 째 괴롭혔던 ARM ...  ARM은 기존 input으로 들어가는 x에 context net의 output을 채널에 대해 concat해서 문맥 정보를 추가해주는 형식의 알고리즘이다. 하지만 ARM을 구현하는 과정에서 내가 직면했던 에러는 다음과 같다. 

 

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[11, 6, 448, 448] to have 3 channels, but got 6 channels instead 

ARM Input은 torchvision 내의 resnet50에 들어가는 형식이었는데, 알고리즘상 채널에 대해 concat을 하다보니 기존 resnet50이 받는 채널 3의 input과는 달라질 수 밖에 없었던 것이다.... 그래서 결론부터 말하자면 torchvision 내의 pretrain model인 resnet50의 input 차원을 변경해줘야하는 상황이었다. 

저 에러로 검색을 해보니 가장 많이 나오는 방법이 아래와 같이 채널 수를 단순히 변경해주는 방법이었다. 

net.conv1.in_channels = 4

## conv1의 채널을 4로 바꾼다는 뜻

 

이렇게 변경을 해 모델을 출력해보면 채널이 변경되어 있는 모습이었는데, 정작 code를 실행시키면 여전히 위의 runtime error가 났다. (이미 pretrain 된 모델이기 때문에 당연한 결과이다 ) 

그래서 구글링과 많은 실험끝에 해결한 방법은 바로 nn.Conv2d layer를 pretrain model 에 넣기 전에 추가해주는 것이다. 

first_conv = nn.Conv2d(3, 1, kernel_size, stride, padding) # you could use e.g. a 1x1 kernel
model = pretrained_model()

x = # load data, should have the shape [batch_size, 3, height, width]
out = first_conv(x)
out = pretrained_model(out)

이런식으로 이 전에 레이어를 추가해주니 문제없이 작동했다 ~!~!~!~!~! 도움을 준 파이토치 커뮤니티 한 외국인님께 감사를...

 

https://discuss.pytorch.org/t/runtimeerror-given-groups-1-weight-of-size-64-3-7-7-expected-input-3-1-224-224-to-have-3-channels-but-got-1-channels-instead/30153/26

 

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[3, 1, 224, 224] to have 3 channels, but got 1 channel

Thanks for your response. I am facing this issue in a different context. If I don’t want to change my original images and also don’t have the option to change the architecture (since I am using resnet18 from models), how can I do slicing? Is there an e

discuss.pytorch.org

 

+ 첨언

오늘 교수님과의 랩미팅으로 알게됐는데, 위 conv 레이어 추가는 아래 resnet50의 pretrain된 아키텍처를 망칠 수 있다. 그래서 위 방법은 아래 아키텍처를 손상시켜도 성능이나 분석에 문제가 없을 때 사용해야한다 .. ! 보통은 pretrain와 input channel 수가 다르면 다시 학습을 시키는 방법을 사용한다고 한다. 

반응형
저작자표시 (새창열림)

'PyTorch👩🏻‍💻' 카테고리의 다른 글

[PyTorch] Multi-GPU 사용하기 (torch.distributed.launch)  (0) 2022.06.10
[TIL] OpenPCDet 가상환경 세팅하기 (cuda11.1 + spconv)  (1) 2022.06.10
[PyTorch] torch-sparse, torch-scatter, torch-geometric 패키지 install 하기 + 오류 해결 방법  (0) 2022.04.30
[PyTorch] PyTorch Autograd 이젠 공부하자 - pytorch.autograd 총정리하기 (+code)  (0) 2022.02.16
[PyTorch] CUDA 11.2 + RTX3090에 맞는 torch version 세팅하기  (3) 2022.01.30
    'PyTorch👩🏻‍💻' 카테고리의 다른 글
    • [TIL] OpenPCDet 가상환경 세팅하기 (cuda11.1 + spconv)
    • [PyTorch] torch-sparse, torch-scatter, torch-geometric 패키지 install 하기 + 오류 해결 방법
    • [PyTorch] PyTorch Autograd 이젠 공부하자 - pytorch.autograd 총정리하기 (+code)
    • [PyTorch] CUDA 11.2 + RTX3090에 맞는 torch version 세팅하기
    당니이
    당니이
    씩씩하게 공부하기 📚💻

    티스토리툴바