본문 바로가기

dev_AI_framework

GEMM(+bias+act)에서 Z(pre-activation) 저장/활용 설계

GEMM 연산에서 pre-activation 의 값을 저장할 필요가 있음, 이는 역전파 시 불필요한 연산을 최소화하기 위함, 

 

Forward 경로

입력/출력

  • 입력 텐서:
    • A:[M,K], B:[K,N], Bias: scalar | (M,1)|(M,) | (1,N)|(N,) | None
  • 출력 텐서:
    • Y:[M,N] (post-activation)
    • Z_saved:[M,N] (옵션, pre-activation 저장용)
  • 제약: 모두 CUDA f32 row-major, 스트라이드 유효, M/K/N 및 Ld가 int32 범위.

연산 순서 (단일 패스)

  1. GEMM
    Z = A @ B
  2. bias 덧셈 (있다면)
    Z += bias
    • Broadcasting 규칙은 BiasKind에 따름: Scalar / PerM / PerN
  3. Z 저장 (save_z=True일 때)
    • Z_saved 버퍼로 pre-activation Z를 기록
    • Z_saved는 Y와 alias 가능(커널이 Z를 먼저 쓰고 이후 Y를 계산)
  4. activation 적용
    • Y = act(Z)
    • 지원: none | relu | leakyrelu | tanh | sigmoid | gelu(approx)
    • leaky_slope는 LeakyReLU용 파라미터

 

Backward 경로

입력/출력

  • 입력: A:[M,K], B:[K,N], gY:[M,N], Z:[M,N] (pre-activation), C(Optional)
  • 출력(옵션): gA:[M,K], gB:[K,N], gC:[M,N], gBias:[N] (Dense는 PerN 권장)
  • 제약: 모두 CUDA f32 row-major, 스트라이드/shape/int32 체크 동일.

핵심 수식

  1. activation 미분
    dZ = gY * act'(Z)
    • Z는 pre-activation이어야 정확 (ReLU 등에서 필요)
  2. GEMM 미분
    • gA = dZ @ B^T
    • gB = A^T @ dZ
  3. bias 미분
    • Scalar: sum(dZ)
    • PerM : sum(dZ, axis=1)
    • PerN : sum(dZ, axis=0) ← Dense 레이어는 보통 PerN(= (N,) 혹은 (1,N))
  4. C 경로(있다면)
    • gC += dZ (스케일/β 조건에 따라 누적)

 

Bias 처리 규약

호스트 측 레이아웃 (정리)

  • Scalar: float* 1개
  • PerM: float* M개 (예: (M,) 또는 (M,1))
  • PerN: float* N개 (예: (N,) 또는 (1,N))

런처의 kind 추론

  • (N,), (1,N) → PerN, (M,), (M,1) → PerM, (1,1) → Scalar
  • 위에 정확 매칭이 안 되더라도 numel 기반 보정: numel==N→PerN, numel==M→PerM, numel==1→Scalar
  • 커널 쪽 load_bias()는 p.bias_kind만 보고 인덱싱하므로, 런처가 올바른 kind를 세팅하는 게 핵심

save_z 관련 주의점

  • attrs.save_z=True면 반드시 Z_saved 버퍼를 넘겨야 함(바인딩에서 선제 체크).
  • Z_saved는 Y와 alias 가능(커널이 Z를 먼저 쓰도록 보장).
  • 저장되는 값은 항상 pre-activation.
    → 테스트에서 act!='none'이면 Z_saved ≈ Z(pre), Z_saved != Y(post)여야 정상.

실패/디버깅 체크리스트

  1. Y는 맞는데 bias가 빠진다
    • p.bias 포인터 유효?
    • p.bias_kind가 None으로 떨어지지 않았는지(추론 로직 확인)
  2. Z_saved가 post처럼 보인다
    • 커널 에필로그 순서 확인: GEMM → (+bias) → save(Z) → act → write(Y)
  3. 레이아웃/전치 오차
    • 현재는 no-transpose only. 전치 허용 전에는 테스트에서 best_match로 orientation 확인 권장.
  4. 역전파 오차
    • Z를 pre로 넘겼는지 확인. post를 넘기면 act'()가 틀어짐.

테스트 포인트(샘플 지표)

  • act='none', with_bias=False:
    Y ≈ A@B, Z_saved ≈ A@B
  • act='none', with_bias=True:
    Y ≈ A@B + bias, Z_saved ≈ A@B + bias
  • act in {relu,…}, with_bias=True:
    Z_saved ≈ A@B + bias, Y ≈ act(Z_saved), max_abs(Z_saved - Y) > eps

성능/안전 메모

  • 모든 차원/ld가 int32 범위 내여야 함.
  • Z_saved를 사용할 때 불필요한 추가 패스 없음(단일 패스 저장).
  • 스트림(void* → cudaStream_t) 전달 가능.