본문 바로가기

AI Compiler framework

PR3: Torch Golden 과 1-Step Training 일치까지의 전체 과정 정리

목표 : AICF Framework 에서

  • forward, 
  • backward
  • optimizer step

이 한 스텝 학습 결과가 PyTroch 와 수치적으로 완전히 동일함을 보장F Framework 에서

 

forward, 

backward

optimizer step

이 한 스텝 학습 결과가 PyTroch 와 수치적으로 완전히 동일함을 보장

검증 기준

  • 모든 intermediate op 는 out-buffer 기반
  • CUDA Graph capture-safe 설계
  • Torch golden 과 allclose 통과

 

전체 스텝 개요

PR1  GEMM correctness (transA/transB)
 ↓
PR2  Autograd + basic backward path
 ↓
PR3  ReLU + Linear + Adam 포함 1-step training
     (torch golden 완전 일치)

 

Step 1. GEMM stride / transpose 통합 (PR1)

문제

  •  transA, transB 별로 커널 분기하면
    • variant 폭증
    • backward 에서 조합 관리 어려움

해결

  • Logical View 기반 GEMM
    • MatView2D
    • transpose 는 stride swap 으로 표현
  • f32
    • 단일 naive strided kernel
  • f16
    • WMMA kernel
    • GLOBAL - SMEM packing 시 transpose-at-load

검증

[OK] op:gemm transA sanity

GEMM 은 이제 forward / backward 공통 기반

 

Step 2. MSE Grad 를 명시적으로 분리

이유

  • loss backward 를 autograd graph 에 넣으면
    • capture-safe 보장 어려움
    • scalar grad 생성이 불안정

결정

  • mse_grad 를 명시적 op 로 분리
  • output 은 항상 out-buffer
  • scale = 2 / numel

검증

[OK] op:mse_grad matches torch golden

loss gradient 는 항상 명시적 공급

 

 Step 3. Autograd 설계 원칙 확정

핵심 정책

  • Leaf gradient 만 persistent
    • parameter.grad 만 유지
    • 중간 tensor grad 는 gmap 에 임시 저장
  • overwrite vs accumulate 분리
    • accumulate = False 
      • copy() 로 leaf grad overwrite
    • accumulate = True
      • add - 새 텐서 생성 가능
  • Capture - safe 강제 규칙
    • capture 중
      • grad = None 금지
      • tensor allocation 금지
      • accumulate = True 금지

이게 지금 autograd 코드의 정체성

 

Step 4. Linear backward 수식 검증

버그, linear 0.W gradient 폭발

ReLU backward 에서 mask 가 깨짐

  • forward y 버퍼를 pool 에서 재사용
  • 그 y 를 그대로 relu_bwd 에 사용
  • 이미 다른ㄹ op 가 y 를 overwrite

해결

  • forward 출력 y_buf 는 downstream 용
  • backward 용 y_saved 는 재사용 금지

 

최종 검증 결과

[OK] relu_bwd sanity
[OK] gemm sanity
[OK] mse_grad sanity

[OK] train1:grad 0.W
[OK] train1:grad 0.b
[OK] train1:grad 2.W
[OK] train1:grad 2.b

[OK] train1:param 0.W
[OK] train1:param 0.b
[OK] train1:param 2.W
[OK] train1:param 2.b

[OK] one-step train correctness matches torch golden
  • AICF 의 autograd 설계는 수식적으로 torch 와 동일
  • CUDA Graph capture-safe 정책이 실제로 동작
  • out-buffer 기반 연산 + buffer pool 전략이 유효
  • optimizer / forwrad / backward 책임 분리가 정확