스터디

[8회차] torch.jit.trace vs. torch.jit.script (vs. torch.compile)

apark 2025. 6. 8. 23:41

요즈음 huggingface에는 모델이 대개 *.safetensor로 업로드되어 있지만 예전에는 *.bin이나 *.pt 파일을 종종 볼 수 있었다.
*.pt랑 *.bin은 pytorch 모델을 학습시키는 과정에서 `torch.save()` 명령어를 사용해 저장한 결과 파일이다.

이렇게 바뀐 이유는 성능, 안정성, 최적화 등등 여러 이유 때문이지만......

현재 맡고 있는 업무에서 python이 아닌 다른 프로그래밍 언어 (e.g. C++, java 등) 에서도 로드할 수 있도록 변환할 일이 종종 있는데, 이때 어떤 모델은 되고 어떤 모델은 안 되는 경우가 꽤 많았다. (torch.jit.trace를 썼을 때 torchscript를 쓰라고 에러가 난다거나)

모델을 저장하거나 배포하거나, 혹은 실행 속도를 높이기 위해 사용되지만, 내부 동작 방식은 조금씩 다를 테니 이참에 간략하게 정리를 해 두려고 한다. 

 

 

 

저장 포맷


`*.pt`와 `*.bin`

`*.pt`와 `*.bin`은 PyTorch 모델을 `torch.save`를 사용해 저장한 결과이다.

내부적으로는 Python의 `pickle`을 사용하여 모델 구조와 가중치를 직렬화하는데, 빠르고 간편하지만

코드 의존성이 있어서 불완전한 구조로 저장할 수 있고 보안 이슈가 있다고 한다.

그래서 나온 것이 safetensor인 듯하다.


safetensors

Hugging Face에서 주도한 새로운 저장 포맷이다. 앞에서도 언급했듯이 요즈음은 거의 다 safetensor 포맷으로 저장되어 있다.

pickle을 사용하지 않아서 보안이 향상되었고, 메모리 매핑 (mmap) 기반으로 로드 속도가 조금 향상된 것처럼 보인다. 

실제로 예전과는 다르게 [1/3], [2/3] 이런식으로 분할해서 다운로드 받는 것 같고...

그래서인지는 모르겠지만 멀티 GPU 환경에서 분산하기가 수월해졌다고 한다.

모델 구조가 아닌 가중치만 저장하는 방식이라 실행을 위해서는 여전히 모델 정의 코드가 필요하다.




TorchScript란 무엇인가?

TorchScript는 PyTorch 모델을 정적 그래프 형태로 변환한 것이다.

일반적인 Python 모델은 동적 실행 그래프(Dynamic Computational Graph)를 사용하는 반면,
TorchScript는 정적 그래프 기반이라서 (=입력 형태부터 모델 구조까지 그대로 고정) 모바일 환경이나 C++에서도 모델을 로드해서 사용할 수가 있다.

이때 TorchScript를 생성하는 데에는 2가지 방법이 있는데, 이것이 바로 torch.jit.trace와 torch.jit.script이다. 



1. `torch.jit.trace(fn, example_inputs)`

- 입력값을 주고 실행 흐름을 추적해 연산 그래프를 생성
- 제어 흐름(if, loop 등)은 입력값에 따라 고정됨

(그렇다 보니 모델 forward() 중간에 if문이 섞여 있거나 하면 warning이 뜬다.

혹은 모델 변환 후 처음에 input ids 길이보다 더 긴 입력을 넣어주면 에러가 발생하기도......) 

 

사용법

from transformers import AutoModel, AutoTokenizer

path = "사용할 모델 경로"
model = AutoModel.from_pretrained(path)
tokenizer = AutoTokenizer.from_pretrained(path)

text = "hi"
encoding = tokenizer(text, return_tensors='pt')
input_ids = encoding['input_ids']
attention_mask = encoding['attention_mask']

traced_model = torch.jit.trace(model, [input_ids, attention_mask])

2. `@torch.jit.script`

- 함수나 모듈 전체를 TorchScript로 분석하여 변환

- 데코레이터 방식으로 동작
- 제어 흐름 등 복잡한 동적 요소 지원 -> 단, 변환 실패 가능성 있고 디버깅이 어려움.


반면 torch.compile은 pytorch 2.0부터 생긴 새로운 기능이다.

여기서부터는 사용해 본 적이 없긴 하지만 앞으로 쓸 일이 생길지도 모르니 정리해 둔다.

모델의 실행 속도를 자동으로 최적화하는 컴파일러이고, 내부적으로 다음 세 가지 기술을 적용해 동작한다고 한다.

TorchDynamo: 파이썬 프레임 추적기를 활용해 연산 그래프 캡처
AOTAutograd: 자동 미분 로직을 ahead-of-time 방식으로 변환
Backend Compiler (nvFuser 등): 최적화된 커널 생성 및 실행

 

사용 예시

import torch
model = MyModel()
compiled_model = torch.compile(model)
output = compiled_model(input_tensor)

 

 

사용자 코드 수정 없이 성능 향상이 가능하고, torch.jit.trace/torch.jit.script보다 제약이 적으면서 동적 모델에 더 적합하다고 한다.

하지만 아무래도 도입된 지 얼마 되지 않았다 보니 아직 지원하지 않는 연산도 있다고 하고...... 당장 사용하기에는 어려울 듯하다. 


적당히 요약하자면...

목적 추천 방법
모델을 저장 및 복원 torch.save(), torch.load() -> *.bin, *.pt
안전하고 빠른 로딩 safetensors
모바일/C++ 배포 TorchScript
최대한의 성능 최적화 (과연)  torch.compile

 


PyTorch는 처음 등장할 때부터 "동적 그래프"의 자유도를 장점으로 내세운 프레임워크이다. (torch 논문 발췌)

하지만 이제는 다양한 운영 환경에 맞춰 여러 방식으로 모델을 서빙할 수 있도록 지원하고 있다는 점이 신기하다. (물론 당연하다고도 생각함)