ML | DL/딥러닝 방법론|실습

대규모 모델 학습·추론 최적화 시리즈 3편: Gradient Checkpointing

Leeys 2025. 9. 14. 23:02
반응형

"VRAM이 부족해서 모델을 학습할 수 없다면?"
Gradient Checkpointing은 중간 activation을 저장하지 않고 필요할 때 재계산하여
메모리를 절약하는 강력한 방법입니다.


 

Gradient Checkpointing

 

Gradient Checkpointing이란?

일반적으로 모델 학습 시 Forward Pass에서 나온 모든 중간 activation
Backward Pass 때 gradient 계산에 사용하기 위해 저장합니다.
하지만 이 activation 저장이 메모리의 대부분을 차지합니다.

Gradient Checkpointing은 중간 activation을 저장하지 않고,
Backward Pass 시 필요한 구간만 다시 forward 계산하여 gradient를 구합니다.

즉, 메모리를 희생하지 않고, 대신 연산량(FLOPs)을 조금 더 쓰는 방식입니다.


왜 필요한가?

  • 대형 모델(수억~수십억 파라미터) 학습 시 activation 메모리가 전체 VRAM의 50~70% 차지
  • GPU 메모리 한계 때문에 batch size를 줄이거나 모델 크기를 줄여야 하는 문제 발생
  • Gradient Checkpointing으로 activation 저장량을 크게 줄이면 더 큰 모델 / 더 큰 batch 학습 가능

동작 원리

모델을 여러 구간(segments)으로 나눈 뒤,
각 구간의 마지막 output만 저장하고 나머지 activation은 버립니다.

Backward 시:

  • 저장된 output부터 해당 구간만 다시 forward pass 실행
  • 새로 계산한 activation으로 gradient 계산 진행

PyTorch 예시 코드

 
import torch
from torch.utils.checkpoint import checkpoint

def custom_forward(module, x):
    return module(x)

model = get_large_model()
x = torch.randn(16, 3, 224, 224).cuda()

# 예시: 특정 블록에 checkpoint 적용
out = checkpoint(custom_forward, model.block1, x)

torch.utils.checkpoint를 사용하면 자동으로 해당 구간 activation 저장 안 하고, backward 시 재계산 수행


장점

1. 메모리 사용량 대폭 감소 (최대 50% ↓)
2. 더 큰 모델 / batch size 학습 가능
3. OOM 방지 → 학습 안정성 증가


단점 & 고려사항

1. 연산량 증가 (forward pass 재계산 필요 → 약 20~30% 학습 시간 증가)
2. 모든 레이어에 적용하면 속도 저하 심해짐 → 중요한 구간에만 선택적으로 적용
3. RNN 등 stateful 연산은 checkpoint 적용 어려움


실제 효과 (벤치마크)

모델 메모리 사용량 감소 속도
GPT-2 1.5B -45% 1.3× 느려짐
BERT-Large -40% 1.2× 느려짐

메모리 절약 효과가 크기 때문에, 메모리 부족 → 속도 조금 손해 트레이드오프에서 대부분 유리


실무 적용 팁

  • Transformer 모델: Attention + FFN block 단위로 checkpoint 적용
  • AMP, Gradient Accumulation과 함께 쓰면 메모리 효율 극대화
  • 중요한 건 적절한 segment size → 너무 잘게 쪼개면 오히려 속도 손해 큼

결론

Gradient Checkpointing은 VRAM 부족 문제를 해결하는 필수 기법입니다.
조금 느려지더라도 더 큰 모델을 학습할 수 있고,
OOM 없이 안정적인 실험이 가능해집니다.

 

 

 

다음 편은 아래에 있습니다!

https://machineindeep.tistory.com/75

 

대규모 모델 학습·추론 최적화 시리즈 4편: FSDP · ZeRO · DeepSpeed

"단일 GPU로는 불가능한 대형 모델, 어떻게 학습할까?"FSDP, ZeRO, DeepSpeed는 모델 파라미터와 optimizer state를 여러 GPU에 분산해거대한 모델을 효율적으로 학습할 수 있도록 해줍니다.문제 정의: 대규

machineindeep.tistory.com

 

반응형