[PyTorch] torchvision model들의 input channel 변경이 안될 때

2021. 8. 8. 23:47·PyTorch👩🏻‍💻
반응형

나를 일주일 째 괴롭혔던 ARM ...  ARM은 기존 input으로 들어가는 x에 context net의 output을 채널에 대해 concat해서 문맥 정보를 추가해주는 형식의 알고리즘이다. 하지만 ARM을 구현하는 과정에서 내가 직면했던 에러는 다음과 같다. 

 

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[11, 6, 448, 448] to have 3 channels, but got 6 channels instead 

ARM Input은 torchvision 내의 resnet50에 들어가는 형식이었는데, 알고리즘상 채널에 대해 concat을 하다보니 기존 resnet50이 받는 채널 3의 input과는 달라질 수 밖에 없었던 것이다.... 그래서 결론부터 말하자면 torchvision 내의 pretrain model인 resnet50의 input 차원을 변경해줘야하는 상황이었다. 

저 에러로 검색을 해보니 가장 많이 나오는 방법이 아래와 같이 채널 수를 단순히 변경해주는 방법이었다. 

net.conv1.in_channels = 4

## conv1의 채널을 4로 바꾼다는 뜻

 

이렇게 변경을 해 모델을 출력해보면 채널이 변경되어 있는 모습이었는데, 정작 code를 실행시키면 여전히 위의 runtime error가 났다. (이미 pretrain 된 모델이기 때문에 당연한 결과이다 ) 

그래서 구글링과 많은 실험끝에 해결한 방법은 바로 nn.Conv2d layer를 pretrain model 에 넣기 전에 추가해주는 것이다. 

first_conv = nn.Conv2d(3, 1, kernel_size, stride, padding) # you could use e.g. a 1x1 kernel
model = pretrained_model()

x = # load data, should have the shape [batch_size, 3, height, width]
out = first_conv(x)
out = pretrained_model(out)

이런식으로 이 전에 레이어를 추가해주니 문제없이 작동했다 ~!~!~!~!~! 도움을 준 파이토치 커뮤니티 한 외국인님께 감사를...

 

https://discuss.pytorch.org/t/runtimeerror-given-groups-1-weight-of-size-64-3-7-7-expected-input-3-1-224-224-to-have-3-channels-but-got-1-channels-instead/30153/26

 

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[3, 1, 224, 224] to have 3 channels, but got 1 channel

Thanks for your response. I am facing this issue in a different context. If I don’t want to change my original images and also don’t have the option to change the architecture (since I am using resnet18 from models), how can I do slicing? Is there an e

discuss.pytorch.org

 

+ 첨언

오늘 교수님과의 랩미팅으로 알게됐는데, 위 conv 레이어 추가는 아래 resnet50의 pretrain된 아키텍처를 망칠 수 있다. 그래서 위 방법은 아래 아키텍처를 손상시켜도 성능이나 분석에 문제가 없을 때 사용해야한다 .. ! 보통은 pretrain와 input channel 수가 다르면 다시 학습을 시키는 방법을 사용한다고 한다. 

반응형
저작자표시

'PyTorch👩🏻‍💻' 카테고리의 다른 글

[PyTorch] Multi-GPU 사용하기 (torch.distributed.launch)  (0) 2022.06.10
[TIL] OpenPCDet 가상환경 세팅하기 (cuda11.1 + spconv)  (1) 2022.06.10
[PyTorch] torch-sparse, torch-scatter, torch-geometric 패키지 install 하기 + 오류 해결 방법  (0) 2022.04.30
[PyTorch] PyTorch Autograd 이젠 공부하자 - pytorch.autograd 총정리하기 (+code)  (0) 2022.02.16
[PyTorch] CUDA 11.2 + RTX3090에 맞는 torch version 세팅하기  (3) 2022.01.30
'PyTorch👩🏻‍💻' 카테고리의 다른 글
  • [TIL] OpenPCDet 가상환경 세팅하기 (cuda11.1 + spconv)
  • [PyTorch] torch-sparse, torch-scatter, torch-geometric 패키지 install 하기 + 오류 해결 방법
  • [PyTorch] PyTorch Autograd 이젠 공부하자 - pytorch.autograd 총정리하기 (+code)
  • [PyTorch] CUDA 11.2 + RTX3090에 맞는 torch version 세팅하기
당니이
당니이
씩씩하게 공부하기 📚💻
  • 당니이
    다은이의 컴퓨터 공부
    당니이
  • 전체
    오늘
    어제
    • 분류 전체보기 (136)
      • Achieved 👩🏻 (14)
        • 생각들 (2)
        • TIL (6)
        • Trial and Error (1)
        • Inspiration ✨ (0)
        • 미국 박사 준비 🎓 (1)
      • Computer Vision💖 (39)
        • Basic (9)
        • Video (5)
        • Continual Learning (7)
        • Generative model (2)
        • Domain (DA & DG) (5)
        • Multimodal (8)
        • Multitask Learning (1)
        • Segmentation (1)
        • Colorization (1)
      • RL 🤖 (1)
      • Autonomous Driving 🚙 (11)
        • Geometry (4)
        • LiDAR 3D Detection (1)
        • Trajectory prediction (2)
        • Lane Detection (1)
        • HDmap (3)
      • Linux (15)
      • PyTorch👩🏻‍💻 (10)
      • Linear Algebra (2)
      • Python (5)
      • NLP (10)
        • Article 📑 (1)
      • Algorithms 💻 (22)
        • Basic (8)
        • BAEKJOON (8)
        • Programmers (2)
      • ML (1)
        • 통계적 머신러닝(20-2) (1)
      • SQL (3)
      • 기초금융 💵 (1)
  • 블로그 메뉴

    • 홈
    • About me
  • 링크

    • 나의 소박한 github
    • Naver 블로그
  • 공지사항

  • 인기 글

  • 태그

    Linux
    자료구조
    til
    NLP
    코딩테스트
    LLM
    conda
    domain generalization
    dfs
    Incremental Learning
    백트래킹
    리눅스
    Python
    CL
    백준
    domain adaptation
    pytorch
    알고리즘
    CV
    continual learning
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
당니이
[PyTorch] torchvision model들의 input channel 변경이 안될 때
상단으로

티스토리툴바