1. 해당 커널이 구현한 기본 패턴
크게 보면, C = alpha A B + beta C + bias -> activation
을 M x N 타일로 쪼개고, K 축은 BK 씩 잘라가면서 A/B 의 K-슬라이스를 shared 에 올려서 여러 번 재사용하는 구조
2. 공간 쪼개기 : block / thread / thread-tile
2-1. block -> C 타일
const int m0 = blockIdx.y * BM_;
const int n0 = blockIdx.x * BN_;
- blockIdx.y, blockIdx.x 가 각각 M 축, N 축 타일 인덱스
- (m0 ,n0) = 이 블록이 담당하는 C 타일의 좌상단 global index
즉, 이 블록은
- A 의 [m0 ... m0+BM_-1, :]
- B 의 [:, n0 ... n0+BN_-1]
- C/D 의 [m0 ..., n0 ...] 타일을 담당
2-2. block 내부에서 thread-> thread tile
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int tm0 = m0 + ty * THR_M;
const int tn0 = n0 + tx * THR_N;
- 블록 내 스레드는TDX x TDY 로 배치
- 스레드 (tx, ty) 가 맡는 C 영역
- row : tm0 ... tm0 + THR_M -1
- col : tn0 ... tn0 + THR_n -1
그래서 전체 매핑은
- block 하나 BM X BN 타일
- thread 하나 THR_M X THR_N 의 미니 타일
- static_assert 로 보장
- TDY * THR_M = BM
- TDX * THR_N = BN
즉, TDX x TDY 개의 스레드가 서로 겹치지 않고 C 타일 전체를 딱 채운다
3. 메모리 계층 : global - shared - register
이 커널의 핵심은
- global A/B 에서 K 축 슬라이스를 shared 로 올리고
- 각스레드가 shared 에서 읽어서 레지스터에 누산
- 모든 K 블록이 끝나면 레지스터 acc 를 global D/C/bias/act 로 epilogue
이 흐름을 K 축 타일 루프 안에서 반복
4. A/B 타일 로딩 알고리즘
4-1. A 타일 로딩딩
auto load_A_tile = [&](int stage, int k0) {
const int tid = ty * TDX + tx; // block 내 1D 스레드 인덱스
const int elems = BM_ * BK_; // A 타일 총 원소 수
for (int e = tid; e < elems; e += (TDX * TDY)) {
const int r = e / BK_;
const int c = e % BK_;
const int gm = m0 + r;
const int gk = k0 + c;
float v = 0.f;
if (gm < p.M && gk < p.K) v = A[gm * p.lda + gk];
As[stage][r][c] = v;
}
};
해석
- 블록 내 스레드가 1D tid 로 번호를 갖고, BM_*BK_ 개의 A 타일 요소를 균등하게 나눠 가져감
- e 는 타일 내부의 편평 인덱스
- r : 타일 내 row
- c : 타일 내 col
- global index
- row : gm = m0 + r
- col : gk = k0 + c
- 경계조건
- gm 가 M 밖이거나 gk 가 K 밖이면 0 채움
- 최종 저장
- As[stage][r][c] : shared memory 3D 배열
- stage : 더블 버퍼링 스테이지
- r : 0 ... BM_ - 1
- c : 0 ... BK_ 1
- As[stage][r][c] : shared memory 3D 배열
4-2 B 타일 로딩
auto load_B_tile = [&](int stage, int k0) {
const int tid = ty * TDX + tx;
const int elems = BK_ * BN_;
for (int e = tid; e < elems; e += (TDX * TDY)) {
const int r = e / BN_;
const int c = e % BN_;
const int gk = k0 + r;
const int gn = n0 + c;
float v = 0.f;
if (gk < p.K && gn < p.N) v = B[gk * p.ldb + gn];
Bs[stage][r][c] = v;
}
};
- B 는 shape 이 BK_ x BN_ 이라 분자/분모만 바뀜
- global index
- row : gk = k0 + r
- col : gn = n0 + c
4-3. 더블 버퍼링과 stage
int stage = 0;
if (p.K > 0) {
load_A_tile(stage, 0);
load_B_tile(stage, 0);
__syncthreads();
}
- 첫번째 K타일을 stage 0 에 로드 - compute 준비
K 루프 안
for (int k0 = 0; k0 < p.K; k0 += BK_) {
#if REGEMM_USE_DB
const int next = stage ^ 1;
if (k0 + BK_ < p.K) {
load_A_tile(next, k0 + BK_);
load_B_tile(next, k0 + BK_);
}
#endif
// (여기서 compute)
__syncthreads();
#if REGEMM_USE_DB
stage ^= 1;
#endif
}
- 현재 stage 에 있는 타일을 가지고 compute
- 동시에 다음 타일을 next stage 에 로드
- 루프 끝에서 __syncthreads 로
- 다음 stage 로드 완료 보장
- stage 토글
5. compute(마이크로 커널) 알고리즘
K 타일 하나에 대해
#pragma unroll
for (int kk = 0; kk < BK_; ++kk) {
float a_vec[THR_M];
#pragma unroll
for (int i = 0; i < THR_M; ++i) {
const int rm = tm0 + i;
a_vec[i] = (rm < p.M) ? As[stage][rm - m0][kk] : 0.f;
}
float b_vec[THR_N];
#pragma unroll
for (int j = 0; j < THR_N; ++j) {
const int cn = tn0 + j;
b_vec[j] = (cn < p.N) ? Bs[stage][kk][cn - n0] : 0.f;
}
#pragma unroll
for (int i = 0; i < THR_M; ++i)
#pragma unroll
for (int j = 0; j < THR_N; ++j)
acc[i][j] = fmaf(a_vec[i], b_vec[j], acc[i][j]);
}
5-1. kk loop = "rank-1 업데이트"
- kk 는 현재 K 타일 내부의 local k index
- As[stage][*, kk] : 현재 kk에 대한 A 열 벡터
- Bs[stage][kk, *] : 현재 kk에 대한 B 행 벡터
이걸 생각하면 이 루프는
C_tile += A_tile[:, kk] * B_tile[kk, :] (outer product) 를 kk = 0...BK-1 까지 쌓는 것과 동일
5-2. thread 마이크로 커널의 구조
스레드 관점에서 보면
- 내가 담당하는 row 블록에 대해
a_vec[i] = As[stage][row_i][kk];
- 내가 담당하는 col 블록에 대해
b_vec[j] = Bs[stage][kk][col_j];
- 그리고 모든 (i, j) 조합에 대해 FMA
acc[i][j] = fmaf(a_vec[i], b_vec[j], acc[i][j]);
재사용 관점
- 이 스레드 안에서
- a_vec[i] 하나당 THR_N 번 사용
- b_vec[j] 하나당 THR_M 번 사용
- 블록 전체 관점에서
- shared 에 올라온 As[r][kk] 하나는,
- 같은 row 를 담당하는 여러 thread 들이 나눠서 여러 D 요소에 사용.
- Bs[kk][c] 도 마찬가지
- shared 에 올라온 As[r][kk] 하나는,
이것이 tiled GEMM 에서 핵심인 데이터 재사용
- global 에서 A/B 각 원소는 한 번만 읽고
- 그 값을 여러 FMA 에서 재사용 - 메모리 bandwidth 대비 FLOP 수를 극대화
6. 경계조건 처리 방식
타일링의 골칫덩이인 M, N, K 가 타일 크기 딱 안 ㅁ낮는 경우
- 로딩 단계에서 out-of-range -> 0 채우기
- compute 단계에서 출력 쓰기 시 m/n 체크
이렇게 나뉘어 있음
7. epilogue 는 타일 밖 계산이 아니라 레지스터 -> global 마지막 단계를 캡슐화
'dev_AI_framework' 카테고리의 다른 글
| 어느 환경에서 학습 시 매 번 새로운 graph 가 생성될까 (0) | 2025.12.01 |
|---|---|
| 상용 딥러닝 프레임 워크는 Inference Runtime 기반 구조이다...! (0) | 2025.12.01 |
| 현재 구현된 gemm 의 fwd 부분 커널 코드 확인 (0) | 2025.11.24 |
| Shared Memory - Bank Conflict (0) | 2025.11.23 |
| 실험적 test 코드 작성 - Thread / Block / Grid 인덱싱 감각 잡기 (0) | 2025.11.16 |