GEMM 연산에서 pre-activation 의 값을 저장할 필요가 있음, 이는 역전파 시 불필요한 연산을 최소화하기 위함,
Forward 경로
입력/출력
- 입력 텐서:
- A:[M,K], B:[K,N], Bias: scalar | (M,1)|(M,) | (1,N)|(N,) | None
- 출력 텐서:
- Y:[M,N] (post-activation)
- Z_saved:[M,N] (옵션, pre-activation 저장용)
- 제약: 모두 CUDA f32 row-major, 스트라이드 유효, M/K/N 및 Ld가 int32 범위.
연산 순서 (단일 패스)
- GEMM
Z = A @ B - bias 덧셈 (있다면)
Z += bias- Broadcasting 규칙은 BiasKind에 따름: Scalar / PerM / PerN
- Z 저장 (save_z=True일 때)
- Z_saved 버퍼로 pre-activation Z를 기록
- Z_saved는 Y와 alias 가능(커널이 Z를 먼저 쓰고 이후 Y를 계산)
- activation 적용
- Y = act(Z)
- 지원: none | relu | leakyrelu | tanh | sigmoid | gelu(approx)
- leaky_slope는 LeakyReLU용 파라미터
Backward 경로
입력/출력
- 입력: A:[M,K], B:[K,N], gY:[M,N], Z:[M,N] (pre-activation), C(Optional)
- 출력(옵션): gA:[M,K], gB:[K,N], gC:[M,N], gBias:[N] (Dense는 PerN 권장)
- 제약: 모두 CUDA f32 row-major, 스트라이드/shape/int32 체크 동일.
핵심 수식
- activation 미분
dZ = gY * act'(Z)- Z는 pre-activation이어야 정확 (ReLU 등에서 필요)
- GEMM 미분
- gA = dZ @ B^T
- gB = A^T @ dZ
- bias 미분
- Scalar: sum(dZ)
- PerM : sum(dZ, axis=1)
- PerN : sum(dZ, axis=0) ← Dense 레이어는 보통 PerN(= (N,) 혹은 (1,N))
- C 경로(있다면)
- gC += dZ (스케일/β 조건에 따라 누적)
Bias 처리 규약
호스트 측 레이아웃 (정리)
- Scalar: float* 1개
- PerM: float* M개 (예: (M,) 또는 (M,1))
- PerN: float* N개 (예: (N,) 또는 (1,N))
런처의 kind 추론
- (N,), (1,N) → PerN, (M,), (M,1) → PerM, (1,1) → Scalar
- 위에 정확 매칭이 안 되더라도 numel 기반 보정: numel==N→PerN, numel==M→PerM, numel==1→Scalar
- 커널 쪽 load_bias()는 p.bias_kind만 보고 인덱싱하므로, 런처가 올바른 kind를 세팅하는 게 핵심
save_z 관련 주의점
- attrs.save_z=True면 반드시 Z_saved 버퍼를 넘겨야 함(바인딩에서 선제 체크).
- Z_saved는 Y와 alias 가능(커널이 Z를 먼저 쓰도록 보장).
- 저장되는 값은 항상 pre-activation.
→ 테스트에서 act!='none'이면 Z_saved ≈ Z(pre), Z_saved != Y(post)여야 정상.
실패/디버깅 체크리스트
- Y는 맞는데 bias가 빠진다
- p.bias 포인터 유효?
- p.bias_kind가 None으로 떨어지지 않았는지(추론 로직 확인)
- Z_saved가 post처럼 보인다
- 커널 에필로그 순서 확인: GEMM → (+bias) → save(Z) → act → write(Y)
- 레이아웃/전치 오차
- 현재는 no-transpose only. 전치 허용 전에는 테스트에서 best_match로 orientation 확인 권장.
- 역전파 오차
- Z를 pre로 넘겼는지 확인. post를 넘기면 act'()가 틀어짐.
테스트 포인트(샘플 지표)
- act='none', with_bias=False:
Y ≈ A@B, Z_saved ≈ A@B - act='none', with_bias=True:
Y ≈ A@B + bias, Z_saved ≈ A@B + bias - act in {relu,…}, with_bias=True:
Z_saved ≈ A@B + bias, Y ≈ act(Z_saved), max_abs(Z_saved - Y) > eps
성능/안전 메모
- 모든 차원/ld가 int32 범위 내여야 함.
- Z_saved를 사용할 때 불필요한 추가 패스 없음(단일 패스 저장).
- 스트림(void* → cudaStream_t) 전달 가능.