본문 바로가기

AI Compiler framework

최소 학습을 위한 추가 구현 필요 내용 ( MLP 를 목표로, (Linaer - ReLU + Linear) + Loss + SGD

forward 에 이은 학습까지 필요, MLP 를 목표로 forward / backward / update 로 쪼갤 수 있음

Forward

  • GEMM
  • AddBias
  • ReLU
  • GEMM
  • AddBias
  • Loss

Backward

  • dLoss / dY 
  • GEMM for dW
  • GEMM for dX
  • bias grad : reduce_sum
  • ReLU backward : dX *= mask

Update

  • SGD : W -= lr dW, b -= lr db

 

역전파의 핵심, GEMM 의 변형 필요

GEMM 은 forward 용 A B 임, backward 의 경우 추가 2개가 더 필요

  • GEMM_TN 
  • GEMM_NT

 

반드시 필요한 Elementwise / Reduction

  • ReLU backward ( maks 곱 )
    • forward 에서 mask 를 저장하지 않으면, backward 에서 input 또는 output 기준으로 재계산 가능
  • reduce_sum / bias_grad
    • db = sum, 이게 없으면 bias 학습이 막함
  • Loss
    • MSE loss + backward
  • Update
    • SGD update
      • W += alpha * dW
      • bias 도 동일
      • add 가 있으면scaled add 가 필요

 

 

학습 1-step 성공 체크리스트

end-to-end correctness : 결과 비교

capture-safe / deterministic 확인