🎯 현재 파이프라인 구조
CudaGraphTrainer.compile()
├─ make_plan_for_sequential() # 각 layer의 입출력 버퍼/grad 버퍼 계획
├─ try_rebind_grads() # optimizer grad 포인터를 plan 버퍼와 연결
├─ record_step_graph(...) # capture_stream 내부에서
│ ├─ _run_fwd(model, plan, X_buf)
│ ├─ loss_fn.forward(cur, y_buf) → loss_dev
│ ├─ loss_out[...] = loss_dev # ✅ 손실 스칼라를 디바이스 버퍼에 기록
│ ├─ _zero_bwd_buffers(plan)
│ ├─ _run_bwd(model, plan, dY)
│ └─ optimizer_step_fn() # (opt.step + zero_grad)
│
└─ gexec = cap.graph.instantiate() # 그래프 객체 완성
CudaGraphTrainer.one_step()
├─ tg.set_batch(X, y) # 고정 버퍼에 복사
├─ tg.launch() # fwd→loss→bwd→opt 그래프 실행
└─ loss = loss_buf.get() # ✅ 그래프 내에서 갱신된 손실값을 D2H로 읽음
즉,
- forward 패스, 손실 계산, backward, optimizer step, zero_grad
전부 한 번의 CUDA Graph 캡처 스트림 안에서 실행됩니다. - loss_out 버퍼가 그래프 안에서 매 스텝 loss 값을 덮어쓰므로,
호스트에서는 get() 한 번으로 D2H 읽기만 하면 됩니다. - CudaGraphTrainer.one_step()은 추가 fwd 연산이 전혀 없습니다.
🔍 검증 포인트
- smoke 테스트 로그에서이런 식으로 손실이 자연스럽게 감소 → 그래프 내부에서 loss 계산/업데이트 모두 정상 동작.
- 파라미터 업데이트(AdamWOpt) 후 loss_buf의 값이 바뀜 →
- loss_out이 제대로 그래프 내에서 덮어쓰였다는 뜻입니다.
step=00 loss=1.685245
step=01 loss=1.665294
step=02 loss=1.642423
🚀 다음 단계 (선택적)
- GradScaler / AMP 통합
→ loss × scale, overflow 플래그도 같은 방식으로 버퍼에 포함시켜 캡처. - 모델별 고정 입력으로 멀티-그래프 지원
→ 배치/시퀀스 길이가 다른 경우 TrainGraph 인스턴스를 여러 개 보관. - NVTX 범위 태깅
→ loss, bwd, opt 범위별로 시각화.
'dev_AI_framework' 카테고리의 다른 글
| 템플릿화된 최적화!!! (0) | 2025.10.16 |
|---|---|
| 사용자 정의 조합, 연산들에 대해 Epilogue 확장 가이드 ( graph_capture-safe, epilogue condition 조건 만족) (0) | 2025.10.16 |
| sequential model 의 architecture 확인을 통한 fine kernels matching algorithm 구현 필요, (0) | 2025.10.13 |
| CUDA Graph 학습 경량 가이드 (0) | 2025.10.10 |
| gemm 헬퍼 모듈 요약 및 graph_capture 정리 (0) | 2025.10.10 |