1. 타일 / 스레드 구조 정의
constexpr int BM = REGEMM_TILE_M;
constexpr int BN = REGEMM_TILE_N;
constexpr int BK = REGEMM_TILE_K;
constexpr int TDX = REGEMM_BLOCK_TDX;
constexpr int TDY = REGEMM_BLOCK_TDY;
constexpr int THR_M = REGEMM_THREAD_TILE_M;
constexpr int THR_N = REGEMM_THREAD_TILE_N;
- BM, BN, BK : block 하나가 맡는 타일 크기
- TDX, TDY : 블록 내 스레드 배치 ( threadidx.x / y 최대값 )
- THR_M, THR_N : 스레드 하나가 담당는 C 타일 내 미니 타일 크기
- 한 스레드는 THR_M X THR_N 개의 (m, n) 출력 요소를 로컬 레지스터에 accum
2. (1) Smoke 커널 - 작은 문제용
__global__ void gemm_bias_act_f32_smoke(GemmBiasActParams p) {
const int m = blockIdx.y * blockDim.y + threadIdx.y;
const int n = blockIdx.x * blockDim.x + threadIdx.x;
if (m >= p.M || n >= p.N) return;
- 전형ㅈ억인 naive GEMM 인덱싱
- m : 전역 row index
- n : 전역 col index
- 범위 밖의 retrun 으로 결계 조건 처리
const float* __restrict__ A = reinterpret_cast<const float*>(p.A);
const float* __restrict__ B = reinterpret_cast<const float*>(p.B);
const float* __restrict__ C = reinterpret_cast<const float*>(p.C);
float* __restrict__ D = reinterpret_cast<float*>(p.D);
- params 에 들어있는 void* 나 generic ptr 을 f32 로 캐스팅
- __restrict__ 로 aliasing 없다고 힌트 -> 컴파일러 최적화 도움
float acc = 0.f;
for (int k = 0; k < p.K; ++k) {
float a = A[m * p.lda + k];
float b = B[k * p.ldb + n];
acc = fmaf(a, b, acc);
}
- 한 스레드가 C[m,n] 하나를 담당
- row-majore A, B
- fmaf 로 FMA 사용
float pre = p.alpha * acc;
if (p.beta != 0.f && C) pre = fmaf(p.beta, C[m * p.ldc + n], pre);
pre += load_bias(p, m, n);
D[m * p.ldd + n] = apply_act_runtime(pre, p.act, p.leaky_slope);
}
- pre : epilogue 이전 "pre-activaion" 값
- alpha * acc
- beta * C add (옵션)
- bias add
- activation 은 런타임스위치, ActKind 에 따라 ReLU, GELU 등의 적용
static inline void launch_gemm_bias_act_f32_smoke(const GemmBiasActParams& p, cudaStream_t s) {
dim3 block(16, 16);
dim3 grid((p.N + block.x - 1) / block.x, (p.M + block.y - 1) / block.y);
AI_NVTX_RANGE("regemm.smoke", nvtx::Color::Gray);
gemm_bias_act_f32_smoke<<<grid, block, 0, s>>>(p);
}
- 16 X 16 블록,
- 작은 문제에서 tile 보다 더 나은 성능 보이도록
3. (2) Tiled FWD 커널
3-1. 블록/스레드 인덱싱
template<int BM_, int BN_, int BK_, ActKind AK, BiasMode BMmode, bool HasC>
__global__ void gemm_bias_act_f32_tiled_kernel(GemmBiasActParams p) {
const int m0 = blockIdx.y * BM_;
const int n0 = blockIdx.x * BN_;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int tm0 = m0 + ty * THR_M;
const int tn0 = n0 + tx * THR_N;
- 블록이 담당하는 C 타일릐 좌상단 좌표 (m0, n0)
- 블록 내 스레드 좌표 (tx, ty)에 따라, 각 스레드가 맡는 "시작 row/col"
- 스레드 (tx, ty) 는 (tm0 ... tm0+THR_M-1, tn0 ... tn0+THR_N-1) 범위의 C 요소 담당
3-2 shared memory 버퍼
#if REGEMM_USE_DB
__shared__ float As[2][BM_][BK_ + PADK];
__shared__ float Bs[2][BK_][BN_ + PADN];
#else
__shared__ float As[1][BM_][BK_ + PADK];
__shared__ float Bs[1][BK_][BN_ + PADN];
#endif
- A 타일 : BM_ X BK_ / B 타일 : BK_ X BN_
- 더블 버퍼링 옵션
- As[2][...] / Bs[2][...] : 두 스테이지를 번갈아 쓰면서 load-compute overlap 시도
- stage 변수로 현재/다음 스테이지 선택
- PADK, PADN 만큼 열 패딩 - > bank conflict 완화
3-3 레지스터 accum 초기화
float acc[THR_M][THR_N];
#pragma unroll
for (int i = 0; i < THR_M; ++i)
#pragma unroll
for (int j = 0; j < THR_N; ++j)
acc[i][j] = 0.f;
- 각 스레드의 로컬 output 블록 : THR_M X THR_N 개
- 모두 0으로 초기화
3-4 전역 포인터 캐스팅
const float* __restrict__ A = reinterpret_cast<const float*>(p.A);
const float* __restrict__ B = reinterpret_cast<const float*>(p.B);
const float* __restrict__ C = reinterpret_cast<const float*>(p.C);
float* __restrict__ D = reinterpret_cast<float*>(p.D);
3-5 A/B 타일 로더 람다
auto load_A_tile = [&](int stage, int k0) {
const int tid = ty * TDX + tx;
const int elems = BM_ * BK_;
for (int e = tid; e < elems; e += (TDX * TDY)) {
const int r = e / BK_;
const int c = e % BK_;
const int gm = m0 + r;
const int gk = k0 + c;
float v = 0.f;
if (gm < p.M && gk < p.K) v = A[gm * p.lda + gk];
As[stage][r][c] = v;
}
};
- 블록 내 모든 스레드를 1D tid 로 펴서, BM_ X BK_ 개의 A 타일 요소를 나눠 로드,
- e 는 해당 타일 내 평면 인덱스
- 글로벌 좌표
- 경계 밖이면 0 패딩
auto load_B_tile = [&](int stage, int k0) {
const int tid = ty * TDX + tx;
const int elems = BK_ * BN_;
for (int e = tid; e < elems; e += (TDX * TDY)) {
const int r = e / BN_;
const int c = e % BN_;
const int gk = k0 + r;
const int gn = n0 + c;
float v = 0.f;
if (gk < p.K && gn < p.N) v = B[gk * p.ldb + gn];
Bs[stage][r][c] = v;
}
};
- B 타일도 동일 패턴, BK_ X BN_ 형태에 맞게 변경
3-6 초기 스테이지 로드 + 동기화
int stage = 0;
if (p.K > 0) {
load_A_tile(stage, 0);
load_B_tile(stage, 0);
__syncthreads();
}
- k0 = 0 인 첫 K 타일을 shared 로 올리고, 모든 스레드 동기화
3-7 K 타일 루프 + 더블 버퍼링
for (int k0 = 0; k0 < p.K; k0 += BK_) {
#if REGEMM_USE_DB
const int next = stage ^ 1;
if (k0 + BK_ < p.K) {
load_A_tile(next, k0 + BK_);
load_B_tile(next, k0 + BK_);
}
#endif
- 현재 stage 의 타일을 compute 하는 동안, 다음 K 타일을 다른 stage 에 미리 로드
- SM 레벨에서 어느 정도 overlap 기대
#pragma unroll
for (int kk = 0; kk < BK_; ++kk) {
float a_vec[THR_M];
#pragma unroll
for (int i = 0; i < THR_M; ++i) {
const int rm = tm0 + i;
a_vec[i] = (rm < p.M) ? As[stage][rm - m0][kk] : 0.f;
}
- 현재 K 타일 안에서 kk(로컬 k index) 를 돌면서 micro-kernel 수행
- a_vec[i]
- 스레드가 담당하는 THR_M 개의 row 에 대해, 해당 kk 칼럼의 값을 벡터로 뽑음
- rm - m0 로 타일 내부 row index 변환
float b_vec[THR_N];
#pragma unroll
for (int j = 0; j < THR_N; ++j) {
const int cn = tn0 + j;
b_vec[j] = (cn < p.N) ? Bs[stage][kk][cn - n0] : 0.f;
}
- b_vec[j]
- 스레드가 담당하는 THR_N 개의 col 에 대해, kk row 의 값을 벡터로 뽑음
#pragma unroll
for (int i = 0; i < THR_M; ++i)
#pragma unroll
for (int j = 0; j < THR_N; ++j)
acc[i][j] = fmaf(a_vec[i], b_vec[j], acc[i][j]);
}
- 진짜 micro-kernel
- THR_M X THR_N 레지스터 블록에 대해 outer-product accumulate
__syncthreads();
#if REGEMM_USE_DB
stage ^= 1;
#endif
}
- 한 K 타일 끝나면 동기화, 다음 루프 iteratino 에서 stage 토글
3-8 PerN bias 프리패치
float bias_j[THR_N];
#pragma unroll
for (int j = 0; j < THR_N; ++j) {
const int n = tn0 + j;
bias_j[j] = (n < p.N) ? load_bias(p, 0, n) : 0.f;
}
- BiasMode == PerN 인 경우를 위해, col 방향 bias 를 미리 캐싱
- load_bias(p, 0, n) 에서 m 은 의미없는 자리로 쓰고 있을 가능성이 큼
using EP = Epilogue<AK, BMmode, HasC, /*SaveZ*/false>;
const int ldc = p.ldc, ldd = p.ldd;
- compile-time 템플릿 기반 epilogue 정책
- AK : activation compile-time 선택
- BMmode : bias 모드 compile-time
- HasC : C add 여부
- SaveZ-false : pre-activation 저장 안 함
3-9 최종 쓰기 루프
#pragma unroll
for (int i = 0; i < THR_M; ++i) {
const int m = tm0 + i;
if (m >= p.M) continue;
float bias_m_cached = 0.f;
if constexpr (BMmode == BiasMode::PerM) {
bias_m_cached = load_bias(p, m, 0);
}
- 스레드의 각 row 에 대해
- out-of-range 면 continue
- perM bias 인 경우, row-wise bias 를 한 번만 로드해서 캐시
#pragma unroll
for (int j = 0; j < THR_N; ++j) {
const int n = tn0 + j;
if (n < p.N) {
EP::apply(
/*D*/ reinterpret_cast<float*>(p.D), ldd,
/*C*/ reinterpret_cast<const float*>(p.C), ldc,
/*Z*/ nullptr, 0,
/*p*/ p,
/*m*/ m, /*n*/ n,
/*acc*/ acc[i][j],
/*bias_j*/ (BMmode == BiasMode::PerN ? bias_j[j] : 0.f),
/*bias_m*/ bias_m_cached
);
}
}
}
}
- Epilogue 정책 객체에 모든 정보 전달
- (m, n) 위치
- acc 값
- PerN / PerM bias 값
- C pointer & beta, alpha, activation 등은 p 안에
- EP::apply 안에서 alpba acc + Beta C + bias + activation 까지 완전히 처리
4. 런처들
4-1. FWD tiled 런처
template<ActKind AK, BiasMode BMmode, bool HasC>
static inline void launch_fwd_cfg(const GemmBiasActParams& p, cudaStream_t s) {
dim3 block(TDX, TDY);
dim3 grid((p.N + BN - 1) / BN, (p.M + BM - 1) / BM);
AI_NVTX_RANGE("regemm.tiled", nvtx::Color::Blue);
gemm_bias_act_f32_tiled_kernel<BM, BN, BK, AK, BMmode, HasC><<<grid, block, 0, s>>>(p);
}
- 템플릿에서 activation/bias/HasC 를 compile-time 으로 확정한 버전
- 그 위에서 BiasMode 를 런타임에서 분기
template<ActKind AK, bool HasC>
static inline void launch_fwd_cfg_bm(const GemmBiasActParams& p, BiasMode bm, cudaStream_t s) {
switch (bm) {
case BiasMode::PerM: launch_fwd_cfg<AK, BiasMode::PerM, HasC>(p, s); break;
case BiasMode::PerN: launch_fwd_cfg<AK, BiasMode::PerN, HasC>(p, s); break;
case BiasMode::Full: launch_fwd_cfg<AK, BiasMode::Full, HasC>(p, s); break;
case BiasMode::None:
default: launch_fwd_cfg<AK, BiasMode::None, HasC>(p, s); break;
}
}
- biasMode 는 런타임 enum - 스위치 - 각 case 마다 템플릿 인스턴스 선택
static inline void launch_gemm_bias_act_f32_tiled(const GemmBiasActParams& p, cudaStream_t s) {
const BiasMode bm = to_bias_mode(p.bias_kind);
const bool hasC = (p.beta != 0.f && p.C != nullptr);
switch (p.act) {
case ActKind::ReLU:
if (hasC) launch_fwd_cfg_bm<ActKind::ReLU, true >(p, bm, s);
else launch_fwd_cfg_bm<ActKind::ReLU, false>(p, bm, s);
break;
...
case ActKind::None:
default:
if (hasC) launch_fwd_cfg_bm<ActKind::None, true >(p, bm, s);
else launch_fwd_cfg_bm<ActKind::None, false>(p, bm, s);
break;
}
}
- ActKind 는 런타임 enum - 스위치 - 템플릿 ActKind 인자로 고정
- C add 여부는 bool hasC 로 결정 - HasC 템플릿에 반영
- 이렇게 해서 핫패스에서는 완전한 compile-time epilogue specialization