본문 바로가기

dev_AI_framework

gemm_bias_act_f32_tilled_kernel 만 자세히 봐보자

1. 해당 커널이 구현한 기본 패턴

크게 보면, C = alpha A B + beta C + bias -> activation 

을 M x N 타일로 쪼개고, K 축은 BK 씩 잘라가면서 A/B 의 K-슬라이스를 shared 에 올려서 여러 번 재사용하는 구조

 

2. 공간 쪼개기 : block / thread / thread-tile

2-1. block -> C 타일

const int m0 = blockIdx.y * BM_;
const int n0 = blockIdx.x * BN_;
  • blockIdx.y, blockIdx.x 가 각각 M 축, N 축 타일 인덱스
  • (m0 ,n0) = 이 블록이 담당하는 C 타일의 좌상단 global index

즉, 이 블록은

  • A 의 [m0 ... m0+BM_-1, :]
  • B 의 [:, n0 ... n0+BN_-1]
  • C/D 의 [m0 ..., n0 ...] 타일을 담당

 

2-2. block 내부에서 thread-> thread tile

const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int tm0 = m0 + ty * THR_M;
const int tn0 = n0 + tx * THR_N;
  • 블록 내 스레드는TDX x TDY 로 배치
  • 스레드 (tx, ty) 가 맡는 C 영역
    • row : tm0 ... tm0 + THR_M -1
    • col : tn0 ... tn0 + THR_n -1

그래서 전체 매핑은

  • block 하나 BM X BN 타일
  • thread 하나 THR_M X THR_N 의 미니 타일
  • static_assert 로 보장
    • TDY * THR_M = BM
    • TDX * THR_N = BN

즉, TDX x TDY 개의 스레드가 서로 겹치지 않고 C 타일 전체를 딱 채운다

 

3. 메모리 계층 : global - shared - register

이 커널의 핵심은 

  1. global A/B 에서 K 축 슬라이스를 shared 로 올리고 
  2. 각스레드가 shared 에서 읽어서 레지스터에 누산
  3. 모든 K 블록이 끝나면 레지스터 acc 를 global D/C/bias/act 로 epilogue

이 흐름을 K 축 타일 루프 안에서 반복

 

4. A/B 타일 로딩 알고리즘

4-1. A 타일 로딩딩

auto load_A_tile = [&](int stage, int k0) {
  const int tid   = ty * TDX + tx;      // block 내 1D 스레드 인덱스
  const int elems = BM_ * BK_;         // A 타일 총 원소 수

  for (int e = tid; e < elems; e += (TDX * TDY)) {
    const int r = e / BK_;
    const int c = e % BK_;
    const int gm = m0 + r;
    const int gk = k0 + c;

    float v = 0.f;
    if (gm < p.M && gk < p.K) v = A[gm * p.lda + gk];
    As[stage][r][c] = v;
  }
};

해석

  • 블록 내 스레드가 1D tid 로 번호를 갖고, BM_*BK_ 개의 A 타일 요소를 균등하게 나눠 가져감
  • e 는 타일 내부의 편평 인덱스
    • r : 타일 내 row
    • c : 타일 내 col
  • global index
    • row : gm = m0 + r 
    • col : gk = k0 + c
  • 경계조건
    • gm 가 M 밖이거나 gk 가 K 밖이면 0 채움 
  • 최종 저장
    • As[stage][r][c] : shared memory 3D 배열
      • stage : 더블 버퍼링 스테이지
      • r : 0 ... BM_ - 1
      • c : 0 ... BK_  1

 

4-2 B 타일 로딩

auto load_B_tile = [&](int stage, int k0) {
  const int tid   = ty * TDX + tx;
  const int elems = BK_ * BN_;

  for (int e = tid; e < elems; e += (TDX * TDY)) {
    const int r = e / BN_;
    const int c = e % BN_;
    const int gk = k0 + r;
    const int gn = n0 + c;

    float v = 0.f;
    if (gk < p.K && gn < p.N) v = B[gk * p.ldb + gn];
    Bs[stage][r][c] = v;
  }
};
  • B 는 shape 이 BK_ x BN_ 이라 분자/분모만 바뀜
  • global index
    • row : gk = k0 + r
    • col : gn = n0 + c

 

4-3. 더블 버퍼링과 stage

int stage = 0;
if (p.K > 0) {
  load_A_tile(stage, 0);
  load_B_tile(stage, 0);
  __syncthreads();
}
  • 첫번째 K타일을 stage 0 에 로드 - compute 준비

K 루프 안

for (int k0 = 0; k0 < p.K; k0 += BK_) {
#if REGEMM_USE_DB
  const int next = stage ^ 1;
  if (k0 + BK_ < p.K) {
    load_A_tile(next, k0 + BK_);
    load_B_tile(next, k0 + BK_);
  }
#endif

  // (여기서 compute)

  __syncthreads();
#if REGEMM_USE_DB
  stage ^= 1;
#endif
}
  • 현재 stage 에 있는 타일을 가지고 compute
  • 동시에 다음 타일을 next stage 에 로드
  • 루프 끝에서 __syncthreads 로
    • 다음 stage 로드 완료 보장
    • stage 토글

 

5. compute(마이크로 커널) 알고리즘

K 타일 하나에 대해

#pragma unroll
for (int kk = 0; kk < BK_; ++kk) {
  float a_vec[THR_M];
  #pragma unroll
  for (int i = 0; i < THR_M; ++i) {
    const int rm = tm0 + i;
    a_vec[i] = (rm < p.M) ? As[stage][rm - m0][kk] : 0.f;
  }

  float b_vec[THR_N];
  #pragma unroll
  for (int j = 0; j < THR_N; ++j) {
    const int cn = tn0 + j;
    b_vec[j] = (cn < p.N) ? Bs[stage][kk][cn - n0] : 0.f;
  }

  #pragma unroll
  for (int i = 0; i < THR_M; ++i)
    #pragma unroll
    for (int j = 0; j < THR_N; ++j)
      acc[i][j] = fmaf(a_vec[i], b_vec[j], acc[i][j]);
}

 

5-1. kk loop = "rank-1 업데이트"

  • kk 는 현재 K 타일 내부의 local k index
  • As[stage][*, kk] : 현재 kk에 대한 A 열 벡터
  • Bs[stage][kk, *] : 현재 kk에 대한 B 행 벡터

이걸 생각하면 이 루프는

C_tile += A_tile[:, kk] * B_tile[kk, :] (outer product) 를 kk = 0...BK-1 까지 쌓는 것과 동일

 

5-2. thread 마이크로 커널의 구조

스레드 관점에서 보면

  • 내가 담당하는 row 블록에 대해
a_vec[i] = As[stage][row_i][kk];
  • 내가 담당하는 col 블록에 대해
b_vec[j] = Bs[stage][kk][col_j];
  • 그리고 모든 (i, j) 조합에 대해 FMA
acc[i][j] = fmaf(a_vec[i], b_vec[j], acc[i][j]);

재사용 관점

  • 이 스레드 안에서
    • a_vec[i] 하나당 THR_N 번 사용
    • b_vec[j] 하나당 THR_M 번 사용
  • 블록 전체 관점에서
    • shared 에 올라온 As[r][kk] 하나는,
      • 같은 row 를 담당하는 여러 thread 들이 나눠서 여러 D 요소에 사용.
    • Bs[kk][c] 도 마찬가지

이것이 tiled GEMM 에서 핵심인 데이터 재사용

  • global 에서 A/B 각 원소는 한 번만 읽고 
  • 그 값을 여러 FMA 에서 재사용 - 메모리 bandwidth 대비 FLOP 수를 극대화

 

6. 경계조건 처리 방식

타일링의 골칫덩이인 M, N, K 가 타일 크기 딱 안 ㅁ낮는 경우

  1. 로딩 단계에서 out-of-range -> 0 채우기
  2. compute 단계에서 출력 쓰기 시 m/n 체크

이렇게 나뉘어 있음

 

7. epilogue 는 타일 밖 계산이 아니라 레지스터 -> global 마지막 단계를 캡슐화