본문 바로가기

dev_AI_framework

backward 까지 capture 하려면

“학습 전체 캡처(one-shot)”를 하려면?

  • 현재 conv2d.backward / gemm.backward가 출력(gX/gW/gB 등)을 내부에서 생성합니다.
  • CUDA Graph 캡처 중엔 새 GPU 할당이 금지되는 게 안전해요.
  • 따라서 시그니처를 확장해서 out= 파라미터로 gX/gW/gB 버퍼를 사전 할당(아레나) 후 넘겨서, backward가 해당 버퍼에 직접 써주는 형태로 바꿔야 합니다. 그러면 forward+loss(+gY 계산)+backward+update까지 완전 캡처 가능합니다.