본문 바로가기

dev_AI_framework

현재 구현된 gemm 의 fwd 부분 커널 코드 확인

1. 타일 / 스레드 구조 정의

constexpr int BM  = REGEMM_TILE_M;
constexpr int BN  = REGEMM_TILE_N;
constexpr int BK  = REGEMM_TILE_K;

constexpr int TDX = REGEMM_BLOCK_TDX;
constexpr int TDY = REGEMM_BLOCK_TDY;

constexpr int THR_M = REGEMM_THREAD_TILE_M;
constexpr int THR_N = REGEMM_THREAD_TILE_N;
  • BM, BN, BK : block 하나가 맡는 타일 크기
  • TDX, TDY : 블록 내 스레드 배치 ( threadidx.x / y 최대값 )
  • THR_M, THR_N : 스레드 하나가 담당는 C 타일 내 미니 타일 크기 
    • 한 스레드는 THR_M X THR_N 개의 (m, n) 출력 요소를 로컬 레지스터에 accum

 

2. (1) Smoke 커널 - 작은 문제용

__global__ void gemm_bias_act_f32_smoke(GemmBiasActParams p) {
  const int m = blockIdx.y * blockDim.y + threadIdx.y;
  const int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (m >= p.M || n >= p.N) return;
  • 전형ㅈ억인 naive GEMM 인덱싱
    • m : 전역 row index
    • n : 전역 col index
  • 범위 밖의 retrun 으로 결계 조건 처리
  const float* __restrict__ A = reinterpret_cast<const float*>(p.A);
  const float* __restrict__ B = reinterpret_cast<const float*>(p.B);
  const float* __restrict__ C = reinterpret_cast<const float*>(p.C);
  float* __restrict__       D = reinterpret_cast<float*>(p.D);
  • params 에 들어있는 void* 나 generic ptr 을 f32 로 캐스팅
  • __restrict__ 로 aliasing 없다고 힌트 -> 컴파일러 최적화 도움

 

  float acc = 0.f;
  for (int k = 0; k < p.K; ++k) {
    float a = A[m * p.lda + k];
    float b = B[k * p.ldb + n];
    acc = fmaf(a, b, acc);
  }
  • 한 스레드가 C[m,n] 하나를 담당
  • row-majore A, B
  • fmaf 로 FMA 사용

 

  float pre = p.alpha * acc;
  if (p.beta != 0.f && C) pre = fmaf(p.beta, C[m * p.ldc + n], pre);
  pre += load_bias(p, m, n);

  D[m * p.ldd + n] = apply_act_runtime(pre, p.act, p.leaky_slope);
}
  • pre : epilogue 이전 "pre-activaion" 값
    • alpha * acc
    • beta * C add (옵션)
    • bias add
  • activation 은 런타임스위치, ActKind 에 따라 ReLU, GELU 등의 적용

 

static inline void launch_gemm_bias_act_f32_smoke(const GemmBiasActParams& p, cudaStream_t s) {
  dim3 block(16, 16);
  dim3 grid((p.N + block.x - 1) / block.x, (p.M + block.y - 1) / block.y);
  AI_NVTX_RANGE("regemm.smoke", nvtx::Color::Gray);
  gemm_bias_act_f32_smoke<<<grid, block, 0, s>>>(p);
}
  • 16 X 16 블록, 
  • 작은 문제에서 tile 보다 더 나은 성능 보이도록

 

3. (2) Tiled FWD 커널

3-1. 블록/스레드 인덱싱

template<int BM_, int BN_, int BK_, ActKind AK, BiasMode BMmode, bool HasC>
__global__ void gemm_bias_act_f32_tiled_kernel(GemmBiasActParams p) {
  const int m0 = blockIdx.y * BM_;
  const int n0 = blockIdx.x * BN_;
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int tm0 = m0 + ty * THR_M;
  const int tn0 = n0 + tx * THR_N;
  • 블록이 담당하는 C 타일릐 좌상단 좌표 (m0, n0)
  • 블록 내 스레드 좌표 (tx, ty)에 따라, 각 스레드가 맡는 "시작 row/col"
    • 스레드 (tx, ty) 는 (tm0 ... tm0+THR_M-1, tn0 ... tn0+THR_N-1) 범위의 C 요소 담당

 

3-2 shared memory 버퍼

#if REGEMM_USE_DB
  __shared__ float As[2][BM_][BK_ + PADK];
  __shared__ float Bs[2][BK_][BN_ + PADN];
#else
  __shared__ float As[1][BM_][BK_ + PADK];
  __shared__ float Bs[1][BK_][BN_ + PADN];
#endif
  • A 타일 : BM_ X BK_ / B 타일 : BK_ X BN_
  • 더블 버퍼링 옵션
    • As[2][...] / Bs[2][...] : 두 스테이지를 번갈아 쓰면서 load-compute overlap 시도
    • stage 변수로 현재/다음 스테이지 선택
  • PADK, PADN 만큼 열 패딩 - > bank conflict 완화

 

3-3 레지스터 accum 초기화

  float acc[THR_M][THR_N];
  #pragma unroll
  for (int i = 0; i < THR_M; ++i)
    #pragma unroll
    for (int j = 0; j < THR_N; ++j)
      acc[i][j] = 0.f;
  • 각 스레드의 로컬 output 블록 : THR_M X THR_N 개
  • 모두 0으로 초기화

 

3-4 전역 포인터 캐스팅

  const float* __restrict__ A = reinterpret_cast<const float*>(p.A);
  const float* __restrict__ B = reinterpret_cast<const float*>(p.B);
  const float* __restrict__ C = reinterpret_cast<const float*>(p.C);
  float* __restrict__       D = reinterpret_cast<float*>(p.D);

 

3-5 A/B 타일 로더 람다

  auto load_A_tile = [&](int stage, int k0) {
    const int tid   = ty * TDX + tx;
    const int elems = BM_ * BK_;
    for (int e = tid; e < elems; e += (TDX * TDY)) {
      const int r = e / BK_;
      const int c = e % BK_;
      const int gm = m0 + r;
      const int gk = k0 + c;
      float v = 0.f;
      if (gm < p.M && gk < p.K) v = A[gm * p.lda + gk];
      As[stage][r][c] = v;
    }
  };
  • 블록 내 모든 스레드를 1D tid 로 펴서, BM_ X BK_ 개의 A 타일 요소를 나눠 로드,
  • e 는 해당 타일 내 평면 인덱스 
    • r = e / BK_
    • c = e % BK_
  • 글로벌 좌표
    • gm = m0 + r
    • gk = k0 + c
  • 경계 밖이면 0 패딩 

 

  auto load_B_tile = [&](int stage, int k0) {
    const int tid   = ty * TDX + tx;
    const int elems = BK_ * BN_;
    for (int e = tid; e < elems; e += (TDX * TDY)) {
      const int r = e / BN_;
      const int c = e % BN_;
      const int gk = k0 + r;
      const int gn = n0 + c;
      float v = 0.f;
      if (gk < p.K && gn < p.N) v = B[gk * p.ldb + gn];
      Bs[stage][r][c] = v;
    }
  };
  • B 타일도 동일 패턴, BK_ X BN_ 형태에 맞게 변경

 

3-6 초기 스테이지 로드 + 동기화

  int stage = 0;
  if (p.K > 0) {
    load_A_tile(stage, 0);
    load_B_tile(stage, 0);
    __syncthreads();
  }
  • k0 = 0 인 첫 K 타일을 shared 로 올리고, 모든 스레드 동기화

 

3-7 K 타일 루프 + 더블 버퍼링

  for (int k0 = 0; k0 < p.K; k0 += BK_) {
#if REGEMM_USE_DB
    const int next = stage ^ 1;
    if (k0 + BK_ < p.K) {
      load_A_tile(next, k0 + BK_);
      load_B_tile(next, k0 + BK_);
    }
#endif
  • 현재 stage 의 타일을 compute 하는 동안, 다음 K 타일을 다른 stage 에 미리 로드
  • SM 레벨에서 어느 정도 overlap 기대

 

    #pragma unroll
    for (int kk = 0; kk < BK_; ++kk) {
      float a_vec[THR_M];
      #pragma unroll
      for (int i = 0; i < THR_M; ++i) {
        const int rm = tm0 + i;
        a_vec[i] = (rm < p.M) ? As[stage][rm - m0][kk] : 0.f;
      }
  • 현재 K 타일 안에서 kk(로컬 k index) 를 돌면서 micro-kernel 수행
  • a_vec[i]
    • 스레드가 담당하는 THR_M 개의 row 에 대해, 해당 kk 칼럼의 값을 벡터로 뽑음
    • rm - m0 로 타일 내부 row index 변환

 

      float b_vec[THR_N];
      #pragma unroll
      for (int j = 0; j < THR_N; ++j) {
        const int cn = tn0 + j;
        b_vec[j] = (cn < p.N) ? Bs[stage][kk][cn - n0] : 0.f;
      }
  • b_vec[j]
    • 스레드가 담당하는 THR_N 개의 col 에 대해, kk row 의 값을 벡터로 뽑음

 

      #pragma unroll
      for (int i = 0; i < THR_M; ++i)
        #pragma unroll
        for (int j = 0; j < THR_N; ++j)
          acc[i][j] = fmaf(a_vec[i], b_vec[j], acc[i][j]);
    }
  • 진짜 micro-kernel
    • THR_M X THR_N 레지스터 블록에 대해 outer-product accumulate

 

    __syncthreads();
#if REGEMM_USE_DB
    stage ^= 1;
#endif
  }
  • 한 K 타일 끝나면 동기화, 다음 루프 iteratino 에서 stage 토글

 

3-8 PerN bias 프리패치

  float bias_j[THR_N];
  #pragma unroll
  for (int j = 0; j < THR_N; ++j) {
    const int n = tn0 + j;
    bias_j[j] = (n < p.N) ? load_bias(p, 0, n) : 0.f;
  }
  • BiasMode == PerN 인 경우를 위해, col 방향 bias 를 미리 캐싱
  • load_bias(p, 0, n) 에서 m 은 의미없는 자리로 쓰고 있을 가능성이 큼

 

  using EP = Epilogue<AK, BMmode, HasC, /*SaveZ*/false>;
  const int ldc = p.ldc, ldd = p.ldd;
  • compile-time 템플릿 기반 epilogue 정책
    • AK : activation compile-time 선택
    • BMmode : bias 모드 compile-time
    • HasC : C add 여부
    • SaveZ-false : pre-activation  저장 안 함

 

3-9 최종 쓰기 루프

  #pragma unroll
  for (int i = 0; i < THR_M; ++i) {
    const int m = tm0 + i;
    if (m >= p.M) continue;

    float bias_m_cached = 0.f;
    if constexpr (BMmode == BiasMode::PerM) {
      bias_m_cached = load_bias(p, m, 0);
    }
  • 스레드의 각 row 에 대해
    • out-of-range 면 continue
    • perM bias 인 경우, row-wise bias 를 한 번만 로드해서 캐시 

 

    #pragma unroll
    for (int j = 0; j < THR_N; ++j) {
      const int n = tn0 + j;
      if (n < p.N) {
        EP::apply(
          /*D*/ reinterpret_cast<float*>(p.D), ldd,
          /*C*/ reinterpret_cast<const float*>(p.C), ldc,
          /*Z*/ nullptr, 0,
          /*p*/ p,
          /*m*/ m, /*n*/ n,
          /*acc*/ acc[i][j],
          /*bias_j*/ (BMmode == BiasMode::PerN ? bias_j[j] : 0.f),
          /*bias_m*/ bias_m_cached
        );
      }
    }
  }
}
  • Epilogue 정책 객체에 모든 정보 전달
    • (m, n) 위치
    • acc 값
    • PerN / PerM bias 값
    • C pointer & beta, alpha, activation 등은 p 안에
  • EP::apply 안에서 alpba acc + Beta C + bias + activation 까지 완전히 처리

 

4. 런처들

4-1. FWD tiled 런처

template<ActKind AK, BiasMode BMmode, bool HasC>
static inline void launch_fwd_cfg(const GemmBiasActParams& p, cudaStream_t s) {
  dim3 block(TDX, TDY);
  dim3 grid((p.N + BN - 1) / BN, (p.M + BM - 1) / BM);
  AI_NVTX_RANGE("regemm.tiled", nvtx::Color::Blue);
  gemm_bias_act_f32_tiled_kernel<BM, BN, BK, AK, BMmode, HasC><<<grid, block, 0, s>>>(p);
}
  • 템플릿에서 activation/bias/HasC 를 compile-time 으로 확정한 버전
  • 그 위에서 BiasMode 를 런타임에서 분기

 

template<ActKind AK, bool HasC>
static inline void launch_fwd_cfg_bm(const GemmBiasActParams& p, BiasMode bm, cudaStream_t s) {
  switch (bm) {
    case BiasMode::PerM: launch_fwd_cfg<AK, BiasMode::PerM, HasC>(p, s); break;
    case BiasMode::PerN: launch_fwd_cfg<AK, BiasMode::PerN, HasC>(p, s); break;
    case BiasMode::Full: launch_fwd_cfg<AK, BiasMode::Full, HasC>(p, s); break;
    case BiasMode::None:
    default:             launch_fwd_cfg<AK, BiasMode::None, HasC>(p, s); break;
  }
}
  • biasMode 는 런타임 enum - 스위치 - 각 case 마다 템플릿 인스턴스 선택

 

static inline void launch_gemm_bias_act_f32_tiled(const GemmBiasActParams& p, cudaStream_t s) {
  const BiasMode bm = to_bias_mode(p.bias_kind);
  const bool hasC = (p.beta != 0.f && p.C != nullptr);

  switch (p.act) {
    case ActKind::ReLU:
      if (hasC) launch_fwd_cfg_bm<ActKind::ReLU,      true >(p, bm, s);
      else      launch_fwd_cfg_bm<ActKind::ReLU,      false>(p, bm, s);
      break;
    ...
    case ActKind::None:
    default:
      if (hasC) launch_fwd_cfg_bm<ActKind::None,      true >(p, bm, s);
      else      launch_fwd_cfg_bm<ActKind::None,      false>(p, bm, s);
      break;
  }
}
  • ActKind 는 런타임 enum - 스위치 - 템플릿 ActKind 인자로 고정
  • C add 여부는 bool hasC 로 결정 - HasC 템플릿에 반영
  • 이렇게 해서 핫패스에서는 완전한 compile-time epilogue specialization