딥러닝 실험을 하다 보면 모델을 저장하고 다시 불러오는 상황이 크게 두 가지로 나뉜다.
1. 하나는 훈련을 중단했다가 이어서(resume) 학습하기 위한 경우,
2. 다른 하나는 학습이 끝난 모델을 배포하거나 다른 코드에서 재사용하기 위한 경우다.
이 두 경우에 모델을 어떻게 저장하고 load하면 좋을지 배워보고, pytorch 내부에서 어떤 일들이 일어나는지 공부해보자.
1. 훈련 resume을 위한 저장: weight + optimizer 상태
훈련을 중단했다가 다시 이어서 학습하려면, 단순히 모델의 weight만 있어서는 충분하지 않다.
optimizer가 내부적으로 가지고 있는 상태(momentum, variance 등)까지 함께 복구되어야 완전히 동일한 학습을 이어갈 수 있기 때문이다.
그래서 resume용 checkpoint는 보통 dict 형태로 저장된다.
checkpoint = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"epoch": epoch,
"args": args,
}
torch.save(checkpoint, path)
여기서 중요한 점은:
- model.state_dict()
→ 모델의 모든 weight (parameter + buffer) - optimizer.state_dict()
→ Adam, AdamW 등의 optimizer가 내부적으로 유지하던 상태 - epoch, args
→ 학습을 정확히 이어가기 위한 메타 정보
이런 형태로 저장해야만, 나중에 resume 시 weight뿐 아니라 optimizer의 “학습 관성”까지 그대로 복구할 수 있다.
2. 배포 또는 재사용을 위한 저장: weight만 저장
반면, 학습이 끝난 모델을 배포하거나, 다른 프로젝트에서 encoder만 가져다 쓰는 경우에는
optimizer 상태는 전혀 필요하지 않다.
이 경우에는 보통 weight만 저장한다.
torch.save(model.state_dict(), path)
이렇게 저장된 파일은 순수하게 모델의 파라미터 값들만 담고 있으며,
- inference
- fine-tuning
- encoder 재사용
같은 용도에 적합하다.
3. PyTorch는 weight를 어떻게 관리해서 저장할까?
여기서 한 가지 중요한 개념이 있다.
PyTorch는 모델의 weight를 “이름(key) → tensor(value)” 형태로 관리한다.
nn.Module을 상속한 모델에서,
self.patch_embed = PatchEmbed(...)
self.blocks = nn.ModuleList([...])
처럼 self.xxx = ...로 할당되는 순간,
PyTorch는 이 attribute 이름(xxx)을 기준으로 내부 registry에 파라미터를 등록한다.
그 결과:
model.state_dict()
를 호출하면 내부적으로 다음과 같은 dict가 만들어진다.
{
"patch_embed.proj.weight": ...,
"patch_embed.proj.bias": ...,
"blocks.0.attn.qkv.weight": ...,
...
}
즉, weight의 이름은 self.{name} 구조를 그대로 따라간다.
또한, forward에서 실제로 사용되는지 여부와는 무관하게,
nn.Module에 등록된 파라미터라면 전부 state_dict에 포함된다.
4. 모델 load
- load는 기본적으로 밑과 같이 weight, optimizer 각각 load해줘야한다.
# 1. 파일 읽기 (딕셔너리 형태로 반환됨)
checkpoint = torch.load('checkpoint.pth')
# 2. 모델 가중치 덮어씌우기
model.load_state_dict(checkpoint['model_state_dict'])
# 3. 옵티마이저 상태(momentum 등) 덮어씌우기
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# 4. (선택) 에포크 정보 등 복구
start_epoch = checkpoint['epoch']
- 보통 optimizer나 epoch는 별 신경 안써도 의도한대로 잘 load가 될 것이다.
- 이제 model load에 초점을 맞춰서 알면 좋은 지식들을 배워보자.
모델을 로드할 때 PyTorch가 보는 기준은 단 하나다.
state_dict의 key 이름과 tensor shape이 현재 모델과 정확히 일치하는가?
기본 로딩 방식은 다음과 같다.
model.load_state_dict(torch.load(path))
이때는:
- 저장된 모델과
- 로드할 모델의 구조, 이름, 차원
이 완전히 동일해야 한다.
조금이라도 다르면 에러가 난다.
예를 들어,
- checkpoint에는 a.weight가 있는데
- 현재 모델에는 a 레이어가 없다면
에러가 발생한다.
하지만, 그러면 우리가 만약 (encoder-decoder)로 먼저 학습하고, encoder만 따로 떼서 이용하고 싶으면 어떻게 해야할까?
5. strict=False: 있는 것만 로드하기
이럴 때 사용할 수 있는 옵션이 바로 strict=False다.
# 2. 모델 가중치 덮어씌우기
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
이 옵션을 사용하면:
- 이름과 shape가 모두 맞는 weight만 로드
- checkpoint에만 있는 weight → 무시
- 현재 모델에만 있는 weight → 랜덤 초기화 유지
이 규칙대로 작동한다.
대신 PyTorch는 다음과 같은 정보를 출력한다.
Missing keys: [...]
Unexpected keys: [...]
에러는 발생하지 않는다.
끝~~~~
'AI 기본 지식' 카테고리의 다른 글
| Categorization of Self-Supervised Learning Methods for Computer Vision (0) | 2026.01.29 |
|---|---|
| cosine similarity vs L2 (0) | 2026.01.26 |
| Pytorch로 딥러닝 훈련할 때 weight가 업데이트 되는 process (0) | 2026.01.21 |
| invariant vs equivariant (0) | 2026.01.20 |
| High-level Server Architecture for LLM Inference (0) | 2025.12.15 |