[PyTorch] 데코레이터랑 친해지기 @torch.jit.script / @torch.no_grad..

2023. 7. 29. 17:23·PyTorch👩🏻‍💻
반응형

 

한달에 한번은 블로그 포스팅 하려고 하는데.. 절대 의무감에 하는 포스팅은 아니고 암튼 ㅎㅎ.. 

PyTorch 코드를 보다보면 @가 붙어있는 데코레이터를 볼 수 있다. 매번 흐린눈 하기에는 요즘 너무 자주 보이는 것 같아서 더이상 외면하지 않으려고 포스팅을 준비했다. 대표적은 PyTorch 데코레이터에 대해 포스팅하려고 한다. 다음에 기회가 되면 python 데코레이터도 포스팅 하는걸로.. 


@torch.jit.script 

@torch.jit.script는 pytorch의 데코레이터 중 하나로 이 데코레이터를 사용하여 Python 함수를 TorchScript로 변환할 수 있다. 

여기서 TorchScript는 PyTorch의 JIT(Just-In-Time) 컴파일러를 통해 Python 코드를 최적화된 실행 그래프로 변환하는 방법이다. "최적화"를 하기 때문에 말 그대로 모델의 실행속도를 향상시킬 수 있다. 

다만 이 데코레이터를 특정 함수에 사용하려면, 그 함수의 코드는 PyTorch로 이루어져야 한다는걸 기억하자. 아래와 같이 사용하면 된다.

import torch

@torch.jit.script
def add(a, b):
    return a + b

x = torch.tensor(2)
y = torch.tensor(3)
result = add(x, y)
print(result)

위 @torch.jit.script 데코레이터를 add에 사용함으로써, Torchscript로 변환되었고 함수의 실행속도를 더 빠르게 만들 수 있다. 하지만 모든 상황에서 사용할 수 있는 것은 아니고, 때로는 @torch.jit.trace 같은 데코레이터를 사용하는게 더 나을 때도 있다. 


@torch.no_grad()

가장 흔한 국민 데코레이터이다. 이 데코레이터를 통해 자동으로 해당 함수 내의 모든 gradient 계산을 멈춘다. 즉, inference 단계에서 많이 이용하며 이 때는 자연스럽게 모델의 파라미터가 업데이트 되지 않는다. 아래와 같이 사용한다.

x = torch.tensor([1.], requires_grad=True)
with torch.no_grad():
    y = x * 2
y.requires_grad
# >> FALSE

@torch.no_grad()
def doubler(x):
    return x * 2
z = doubler(x)
z.requires_grad
# >> FALSE

그렇다면 @torch.no_grad() 데코레이터와 우리가 흔히 사용하는 with torch.no_grad(): 의 차이점은 무엇일까? (갑자기 필자가 궁금해졌다) 정답은 데코레이터는 해당 함수 내의 모든 gradient 연산을 멈추고, with torch.no_grad()는 해당하는 block의 gradient 연산만 멈춘다는 것이다. 아래와 같이 사용할 수 있다. 

import torch

def my_evaluation_function(input_data):
    # 이 함수 내에서는 그래디언트 계산이 활성화되어 모델 파라미터가 업데이트됩니다.
    model = MyModel()
    with torch.no_grad():
        # 이 블록 내에서 모든 연산은 그래디언트 계산이 비활성화됩니다.
        # 따라서 모델 파라미터는 업데이트되지 않습니다.
        output = model(input_data)
    return output

 

반응형
저작자표시 (새창열림)

'PyTorch👩🏻‍💻' 카테고리의 다른 글

[PyTorch] 모델 efficiency 측정하기 (used gpu memory / parameter 개수 / Inference time)  (0) 2023.06.06
[PyTorch] mmcv 설치하기 / cuda 버전에 맞게 mmcv downgrade하기 / mmcv._ext error 해결  (5) 2023.05.19
[PyTorch] nvcc가 안될 때 ~/.bashrc 수정해 환경변수 설정하기  (1) 2023.03.27
[PyTorch] Multi-GPU 사용하기 (torch.distributed.launch)  (0) 2022.06.10
[TIL] OpenPCDet 가상환경 세팅하기 (cuda11.1 + spconv)  (1) 2022.06.10
'PyTorch👩🏻‍💻' 카테고리의 다른 글
  • [PyTorch] 모델 efficiency 측정하기 (used gpu memory / parameter 개수 / Inference time)
  • [PyTorch] mmcv 설치하기 / cuda 버전에 맞게 mmcv downgrade하기 / mmcv._ext error 해결
  • [PyTorch] nvcc가 안될 때 ~/.bashrc 수정해 환경변수 설정하기
  • [PyTorch] Multi-GPU 사용하기 (torch.distributed.launch)
당니이
당니이
씩씩하게 공부하기 📚💻
  • 당니이
    다은이의 컴퓨터 공부
    당니이
  • 전체
    오늘
    어제
    • 분류 전체보기 (136)
      • Achieved 👩🏻 (14)
        • 생각들 (2)
        • TIL (6)
        • Trial and Error (1)
        • Inspiration ✨ (0)
        • 미국 박사 준비 🎓 (1)
      • Computer Vision💖 (39)
        • Basic (9)
        • Video (5)
        • Continual Learning (7)
        • Generative model (2)
        • Domain (DA & DG) (5)
        • Multimodal (8)
        • Multitask Learning (1)
        • Segmentation (1)
        • Colorization (1)
      • RL 🤖 (1)
      • Autonomous Driving 🚙 (11)
        • Geometry (4)
        • LiDAR 3D Detection (1)
        • Trajectory prediction (2)
        • Lane Detection (1)
        • HDmap (3)
      • Linux (15)
      • PyTorch👩🏻‍💻 (10)
      • Linear Algebra (2)
      • Python (5)
      • NLP (10)
        • Article 📑 (1)
      • Algorithms 💻 (22)
        • Basic (8)
        • BAEKJOON (8)
        • Programmers (2)
      • ML (1)
        • 통계적 머신러닝(20-2) (1)
      • SQL (3)
      • 기초금융 💵 (1)
  • 블로그 메뉴

    • 홈
    • About me
  • 링크

    • 나의 소박한 github
    • Naver 블로그
  • 공지사항

  • 인기 글

  • 태그

    conda
    Python
    til
    domain generalization
    백준
    NLP
    백트래킹
    Linux
    알고리즘
    continual learning
    CV
    코딩테스트
    자료구조
    Incremental Learning
    리눅스
    LLM
    CL
    pytorch
    dfs
    domain adaptation
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
당니이
[PyTorch] 데코레이터랑 친해지기 @torch.jit.script / @torch.no_grad..
상단으로

티스토리툴바