forward 에 이은 학습까지 필요, MLP 를 목표로 forward / backward / update 로 쪼갤 수 있음
Forward
- GEMM
- AddBias
- ReLU
- GEMM
- AddBias
- Loss
Backward
- dLoss / dY
- GEMM for dW
- GEMM for dX
- bias grad : reduce_sum
- ReLU backward : dX *= mask
Update
- SGD : W -= lr dW, b -= lr db
역전파의 핵심, GEMM 의 변형 필요
GEMM 은 forward 용 A B 임, backward 의 경우 추가 2개가 더 필요
- GEMM_TN
- GEMM_NT
반드시 필요한 Elementwise / Reduction
- ReLU backward ( maks 곱 )
- forward 에서 mask 를 저장하지 않으면, backward 에서 input 또는 output 기준으로 재계산 가능
- reduce_sum / bias_grad
- db = sum, 이게 없으면 bias 학습이 막함
- Loss
- MSE loss + backward
- Update
- SGD update
- W += alpha * dW
- bias 도 동일
- add 가 있으면scaled add 가 필요
- SGD update
학습 1-step 성공 체크리스트
end-to-end correctness : 결과 비교
capture-safe / deterministic 확인