본문 바로가기

dev_AI_framework

graph_capture - loss 까지 완료

🎯 현재 파이프라인 구조

 
CudaGraphTrainer.compile()
 ├─ make_plan_for_sequential()           # 각 layer의 입출력 버퍼/grad 버퍼 계획
 ├─ try_rebind_grads()                   # optimizer grad 포인터를 plan 버퍼와 연결
 ├─ record_step_graph(...)               # capture_stream 내부에서
 │    ├─ _run_fwd(model, plan, X_buf)
 │    ├─ loss_fn.forward(cur, y_buf) → loss_dev
 │    ├─ loss_out[...] = loss_dev        # ✅ 손실 스칼라를 디바이스 버퍼에 기록
 │    ├─ _zero_bwd_buffers(plan)
 │    ├─ _run_bwd(model, plan, dY)
 │    └─ optimizer_step_fn()             # (opt.step + zero_grad)
 │
 └─ gexec = cap.graph.instantiate()      # 그래프 객체 완성

CudaGraphTrainer.one_step()
 ├─ tg.set_batch(X, y)                   # 고정 버퍼에 복사
 ├─ tg.launch()                          # fwd→loss→bwd→opt 그래프 실행
 └─ loss = loss_buf.get()                # ✅ 그래프 내에서 갱신된 손실값을 D2H로 읽음

즉,

  • forward 패스, 손실 계산, backward, optimizer step, zero_grad
    전부 한 번의 CUDA Graph 캡처 스트림 안에서 실행됩니다.
  • loss_out 버퍼가 그래프 안에서 매 스텝 loss 값을 덮어쓰므로,
    호스트에서는 get() 한 번으로 D2H 읽기만 하면 됩니다.
  • CudaGraphTrainer.one_step()은 추가 fwd 연산이 전혀 없습니다.

🔍 검증 포인트

  • smoke 테스트 로그에서이런 식으로 손실이 자연스럽게 감소 → 그래프 내부에서 loss 계산/업데이트 모두 정상 동작.
  • 파라미터 업데이트(AdamWOpt) 후 loss_buf의 값이 바뀜 →
  • loss_out이 제대로 그래프 내에서 덮어쓰였다는 뜻입니다.
step=00 loss=1.685245
step=01 loss=1.665294
step=02 loss=1.642423

🚀 다음 단계 (선택적)

  1. GradScaler / AMP 통합
    → loss × scale, overflow 플래그도 같은 방식으로 버퍼에 포함시켜 캡처.
  2. 모델별 고정 입력으로 멀티-그래프 지원
    → 배치/시퀀스 길이가 다른 경우 TrainGraph 인스턴스를 여러 개 보관.
  3. NVTX 범위 태깅
    → loss, bwd, opt 범위별로 시각화.