본문 바로가기

dev_AI_framework

GEMM + Bias + Activation “1-Write” 설계 노트 (Register-Epilogue 방식)

목적: **행렬 곱(accumulate)**이 끝난 직후, 레지스터에 있는 결과(fragment/accumulator)에 바로 bias와 activation을 적용한 다음 글로벌 메모리로 단 한 번만 저장하는 커널을 만든다. (중간 글로벌 메모리 왕복 없음)

 

1) 문제 정의 & 수식

  • 표준 GEMM:
    C = A · B (A: M×K, B: K×N, C: M×N)
  • 우리가 구현할 형태(일반화): D = act(alpha * (A @ B) + Bias * beta * C_in)
  • 여기서 Act는 ReLU, GELU 등, Bias는 일반적으로 **per-column(N방향)**가 가장 흔함(FC 레이어). 필요하면 per-row, scalar 도 지원.
  • 핵심 제약:
    • Activation은 누산이 끝난 뒤 적용해야 한다(비선형이기 때문).
      ↳ 레지스터에 있는 accumulator(AB + optional β·C_in + bias)를 완성한 다음 in-register로 바로 activation → store.
    • Bias
      • (A) 누산 완료 후 한 번 더하는 방식, 또는
      • (B) acc 초기값에 bias를 미리 더해두고 FMA를 누적하는 방식(β·C_in을 함께 초기화해도 됨).
      • 어떤 방식이든 글로벌 메모리 추가 왕복 없이 레지스터에서 처리.

2) 목표 & 원칙

  • 한 번만 store: 중간 결과를 메모리에 쓰고 다시 읽지 않는다.
  • in-register epilogue:
    1. accumulate 완료 (frag/acc)
    2. (옵션) +β·C_in
    3. +bias
    4. activation
    5. vectorized store
  • 데이터 레이아웃: Row-major 기본. (필요 시 템플릿/파라미터로 Row/Col 선택 가능)
  • 형식: FP32(FMA), FP16/BF16(텐서코어/WMMA), TF32 옵션.
    • 누산은 가급적 FP32(정확도) 유지 후 cast.
  • 메모리 병목 최소화: A/B는 공유메모리 타일링, C_in/bias는 읽기 패턴을 coalesce, D는 vectorized store(float4/half2 등).

3) 커널 아키텍처 개요

3.1 타일링 계층

  • CTA tile (block): M_tile × N_tile
  • Warp tile: CTA 타일을 워프별로 분할
  • Thread tile / Fragment: 각 스레드(또는 wmma fragment)가 담당하는 소타일
  • K 루프: K를 K_step씩 나누어 공유메모리에 A_sub/B_sub 로드 → 누산

3.2 데이터 이동

  1. 글로벌 → 공유메모리: A_sub, B_sub (coalesced, 벡터로드)
  2. 공유메모리 → 레지스터: 매 step 마다 fragment 로드
  3. 누산(FMA/WMMA): acc_frag += a_frag × b_frag
  4. 에필로그(in-register):
    • (옵션) β·C_in 더하기: C_in은 글로벌에서 바로 레지스터로 로드 후 acc에 더하고 버림(별도 store 없음)
    • Bias 더하기: 일반적으로 N-방향 broadcast. warp 단위로 L1/L2 hit 유도
    • Activation: acc_frag = ReLU/GELU/…(acc_frag)
    • Cast & Vectorized store: (예: float → half2/float4) 글로벌에 최초이자 유일한 write

4) 인터페이스(안)

 
struct GemmBiasActParams {
  int M, N, K;            // sizes
  float alpha, beta;      // GEMM scaling
  const void* A; int lda; // row-major: lda = K
  const void* B; int ldb; // row-major: ldb = N
  const void* C; int ldc; // optional input (beta != 0)
  void* D;       int ldd; // output
  const void* bias;       // nullable; layout: per-column (N) by default
  int bias_kind;          // 0:none, 1:perN, 2:perM, 3:scalar
  int act;                // 0:none, 1:ReLU, 2:GELU_tanh, ...
  int dtype;              // 0:f32, 1:f16_acc_f32, 2:bf16_acc_f32, 3:tf32
  int use_tensor_cores;   // 0/1
};

5) 에필로그 설계(핵심)

  • 적용 순서 (권장):
    1. acc = alpha * acc
    2. if (beta != 0) acc += beta * C_in
    3. if (bias) acc += bias_broadcasted
    4. acc = activation(acc)
    5. store(acc) ← 글로벌 최초/유일 write
  • 왜 이 순서?
    • BLAS 호환(α,β) 유지
    • bias/activation은 최종 linear+nonlinear epilogue로 적용
    • 모든 단계 레지스터에서 끝냄
  • Bias broadcast:
    • per-N: 각 스레드가 담당하는 열 인덱스로 bias[n] 레지스터 캐시
    • per-M: 행 기준 broadcast (warp-wide shfl/broadcast 고려)
    • scalar: 단순 더하기
  • Activation 구현:
    • ReLU: x = max(x, 0)
    • GELU(tanh 근사): 0.5*x*(1 + tanh(√(2/π)*(x + 0.044715x^3)))
    • 기타(SiLU, LeakyReLU 등) 추가 가능—조건 분기 최소화, FMA 우선

6) 성능 포인트

  • 메모리 트래픽:
    • 표준 “분리형” 구현: C_temp = AB → (글로벌 write) → 다시 read → +bias → Act → write
    • 본 설계: AB 누산 완료 후 bias/Act 모두 레지스터에서 처리, D 한 번만 write
      → 글로벌 R/W 대폭 감소, L2 압박 완화
  • Vectorized IO:
    • load/store: float4 / half2 정렬(16B 정렬) 보장
    • bias는 L1/L2 힌트: per-N이면 n-stride 연속 접근으로 coalesce
  • 분기 최소화:
    • 템플릿/정적 파라미터로 act/bias 유무를 specialization → 분기 제거
    • 런타임 플래그도 버전 하나로 유지하고 싶다면 warp-uniform branch 유지
  • 형식 정책:
    • FP16/BF16 입력 + FP32 누산 + 최종 cast가 대세
    • 텐서코어(WMMA/cublasLt와 유사) 사용 시 fragment 크기(MMA tile) 맞춰 타일 설계

7) 커널 스켈레톤(요지)

 
template <typename T, typename AccT, typename ActOp, typename BiasPolicy>
__global__ void gemm_bias_act_kernel(GemmBiasActParams p) {
  // 1) 블록/스레드 타일 인덱스 계산
  const int m0 = blockIdx.y * M_TILE;
  const int n0 = blockIdx.x * N_TILE;

  // 2) 레지스터 accumulator 초기화 (AccT = float)
  AccT acc[THREAD_TILE_M][THREAD_TILE_N];
  #pragma unroll
  for (int i=0;i<THREAD_TILE_M;i++)
    #pragma unroll
    for (int j=0;j<THREAD_TILE_N;j++)
      acc[i][j] = AccT(0);

  // 3) K 루프: 공유메모리에 A_sub/B_sub 로드 → 레지스터로 FMA
  for (int k0 = 0; k0 < p.K; k0 += K_STEP) {
    // smem load A_sub, B_sub (coalesced)
    __syncthreads();
    // register fragment load & FMA
    // acc[i][j] += a_frag[i][kk] * b_frag[kk][j];
    __syncthreads();
  }

  // 4) (옵션) β·C_in 더하기 (in-register)
  if (p.beta != 0) {
    // C_in을 글로벌에서 레지스터로 직접 로드(벡터화)
    // acc[i][j] += p.beta * Cin_val;
  }

  // 5) α 스케일
  #pragma unroll
  for (int i=0;i<THREAD_TILE_M;i++)
    #pragma unroll
    for (int j=0;j<THREAD_TILE_N;j++)
      acc[i][j] *= AccT(p.alpha);

  // 6) Bias broadcast & add (레지스터)
  BiasPolicy::add(acc, p, m0, n0 /* thread tile 좌표 등 */);

  // 7) Activation (레지스터)
  #pragma unroll
  for (int i=0;i<THREAD_TILE_M;i++)
    #pragma unroll
    for (int j=0;j<THREAD_TILE_N;j++)
      acc[i][j] = ActOp::apply(acc[i][j]);

  // 8) Vectorized store (글로벌 최초/유일 write)
  // cast(AccT -> T) 후 coalesced store
}

 

WMMA(텐서코어) 버전은 nvcuda::wmma::fragment를 사용해 유사한 흐름으로 작성.
모든 에필로그는 fragment → 레지스터 상에서 처리 후 store_matrix_sync 또는 커스텀 vectorized store.


8) 정확도/수치 이슈

  • 누산은 FP32 권장 (특히 FP16/BF16 입력 시)
  • GELU 등 비선형은 FP32로 계산 후 cast
  • Denorm/Inf/NAN 방지:
    • 입력·bias에 대한 범위 점검(옵션), fast-math 사용 시 검토
    • tanh 근사 상수(0.044715)는 FMA 결합으로 구현

9) NVTX 주석(프로파일링)

  • 범위(tag):
    • "load_A/B", "mma_kloop", "beta_C_add", "bias_add", "activation", "store_D"
  • 블록/워프 식별용 payload로 (m0, n0) 주입 → 타일별 병목 시각화

10) 기능/테스트 체크리스트

  • M,N,K 다양한 배수/잔여 처리(경계 스칼라 경로)
  • dtype 조합: f32, f16/bf16 in + f32 acc + cast
  • bias kind: none/perN/perM/scalar
  • activation: none/ReLU/GELU(tanh) (unit tests with tolerance)
  • α,β 케이스: (α=1,β=0), (α≠1), (β≠0)
  • 레이아웃: Row-major 기본, 필요 시 Col 지원
  • 성능: 분리형 대비 글로벌 R/W 감소 확인, cuBLASLt epilogue 대비 baseline 비교
  • NVTX/NSight Compute 보고서 템플릿

11) cuBLASLt와의 관계

  • cuBLASLt의 epilogue(bias/activation 등)는 우리 목표와 개념적으로 동일:
    “accumulator 완료 → in-register epilogue → single store”
  • 차별점: 우리는 커스텀 커널로 내부 타일링/로드/에필로그를 통제.
    • 특수한 bias 형태(예: per-row + per-channel), 사용자 정의 activation, 추가 fusion(예: residual add) 등 확장 용이.
    • 반면, 유지보수/튜닝 비용은 증가.

12) 확장 아이디어

  • Residual add: acc += residual (in-register) → activation
  • Clamp/Quantize epilogue: int8/FP8 경로 준비
  • Dropout(훈련시): RNG tile-wise → mask를 레지스터에서 적용(단, 결정적 실행/성능 trade-off 고려)

13) 미니 코드 조각 (ReLU, per-N bias)

struct ActReLU {
  __device__ static inline float apply(float x) { return x > 0.f ? x : 0.f; }
};

struct BiasPerN {
  __device__ static inline void add(float acc[][THREAD_TILE_N],
                                    const GemmBiasActParams& p,
                                    int m0, int n0 /* tile origin */) {
    if (!p.bias) return;
    const float* bias = reinterpret_cast<const float*>(p.bias);
    #pragma unroll
    for (int j=0;j<THREAD_TILE_N;j++) {
      int n = n0 + /* thread local offset */ j;
      float b = (n < p.N) ? bias[n] : 0.f;
      #pragma unroll
      for (int i=0;i<THREAD_TILE_M;i++) acc[i][j] += b;
    }
  }
};
 

14) 결론

  • 우리가 하려는 것은 “진짜” in-register epilogue:
    누산 완료 직후 레지스터에서 bias/activation을 모두 처리하고, 글로벌 메모리에는 단 한 번만 저장한다.
  • Activation은 누산 중간에 끼워 넣을 수 없고(비선형), 누산 끝난 시점에 적용하는 것이 올바른/가장 빠른 경로다.
  • 이 문서를 바탕으로 먼저 **FP32 FMA 경로(비WMMA)**로 스모크 → FP16/BF16-acc-FP32 → WMMA 확장 순서로 진행을 권장.