1. FP32 GEMM 커널 분석
__global__
void fp32_gemm_kernel(const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
int M, int N, int K, int iters)
- _global_
- GPU 에서 실행되는 커널 함수, host 에서 <<< >>> 로 호출
- const float* __restrict__ A
- A 는 읽기 전용 포인터, __restrict__ 는 이 포인터를 통해서만 접근한다는 힌트, 별칭 alias 없으니 최적화 수행
- B, C 도 같은 개념
- M, N, K, iters : 행렬 크기와 반복 횟수
int row = blockIdx.y * blockDim.y + threadIdx.y; // 0..M-1
int col = blockIdx.x * blockDim.x + threadIdx.x; // 0..N-1
if (row >= M || col >= N) return;
- grid 가 딱 맞게 안 나누어 떨어질 수 있으니, 범위 체크
- 유효 범위 밖 스레드는 바로 리턴
// 여러 번 반복해서 연산량 증가
for (int it = 0; it < iters; ++it) {
float acc = 0.0f;
- 커널의 여러 번 실행 대신, 커널 내부에서 iters 번 반복해서 연산량 / 실행 시간을 늘림
- acc = 누적 결과
// 단순 i-k 루프
for (int k = 0; k < K; ++k) {
float a = A[row * K + k];
float b = B[k * N + col];
acc += a * b;
}
// 마지막 it 기준으로 덮어쓰기 (결과 자체는 크게 중요하지 않음)
C[row * N + col] = acc;
}
- 각 it 마다 다시 누적, 매 반복마다 C 에 쓰고 있음, 얼마나 많은 연산, 실행 시간 비교 역할 수행
WMMA Tensor Core GEMM 커널 분석
__global__
void wmma_gemm_kernel(const half* __restrict__ A_half,
const half* __restrict__ B_half,
float* __restrict__ C,
int M, int N, int K, int iters)
{
#if __CUDA_ARCH__ >= 700
- 컴파일 타겟 아키텍처가 7.0 이상인 경우에만 활성화
- FP16 버전, A_half, B_half
- C 는 FP32 accumulator
// warp 단위 tile
constexpr int WMMA_M = 16;
constexpr int WMMA_N = 16;
constexpr int WMMA_K = 16;
- WMMA 기본 타일 크기 설정 : 16 x 16 x 16
- 하나의 MMA 연산이 M x N 출력, K 축을 16 씩 누적
// warp 하나가 16x16 tile 하나 담당
int tile_row = blockIdx.y; // 0..(M/16-1)
int tile_col = blockIdx.x; // 0..(N/16-1)
- grid 의 blockIdx.y, blockIdx.x 각각이 하나의 16 x 16 타일으르 ㄷ마당
- 블록 하나 = warp 하나 = 타일 하나
// warp 내 lane id
int lane_id = threadIdx.x % 32;
// blockDim.x는 32로 가정 (warp 1개)
if (tile_row * WMMA_M >= M || tile_col * WMMA_N >= N)
return;
- lane_id 는warp 내 스레드 인덱스, 현재 직접 사용중은 아님
- blockDim.x == 32 가정 : 한 블록에 딱 한 warp
- grid 크기와 M/N 이 항상 딱 맞더라도, 방어적 범위 체크
// 반복 횟수만큼 GEMM 수행
for (int it = 0; it < iters; ++it) {
// C fragment 0으로 초기화
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K,
half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K,
half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K,
float> c_frag;
wmma::fill_fragment(c_frag, 0.0f);
- wmma::fragment:
- Tensor Core 연산에 들어가는 조각, matrix tile 을 warp 의 레지스터에 담는 추상 타입
- matrix_a / matrix_b : 입력 해열 조각
- accumulator : C 타일 ( 출력 / 누적 )
- row_major : 메모리 layout 정보
- wmma::fill_fragment(c_frag, 0.0f);
// K dimension을 16씩 잘라서 누적
for (int kk = 0; kk < K; kk += WMMA_K) {
const half* tile_ptr_A = A_half + (tile_row * WMMA_M) * K + kk;
const half* tile_ptr_B = B_half + kk * N + (tile_col * WMMA_N);
- 일반 GEMM 에서 C_tile = A_tile * B_tile 을 타일 기준으로 구현하는 모습
- 현재 타일이 시작하는 행 번호
- kk 는 K 축에서의 현재 16 슬라이스 offset
- tile_ptr_A
- A 를 row-major 로 볼 때,
- 시작 행 = tile_row * 16
- 시작 열 = kk
- tile_ptr_B
- 시작 행 = kk,
- 시작 열 = tile_col * 16
// row-major / row-major로 load
wmma::load_matrix_sync(a_frag, tile_ptr_A, K);
wmma::load_matrix_sync(b_frag, tile_ptr_B, N);
- 글로벌 메모리에서 half 행렬 조각을 warp 레지스터에 로드
- 세 번째 인자는 leading dimension, row-major 에 서 한 행의 stride
// C += A * B
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
- 핵심 Tensor Core MMA 연산
- c_frag = a_frag * b_frag + c_frag 을 한 번에 수행
- 내부적으로 FP16 입력을 FP32 accumulate 로 처리
// C에 store
float* tile_ptr_C = C + (tile_row * WMMA_M) * N + (tile_col * WMMA_N);
wmma::store_matrix_sync(tile_ptr_C, c_frag, N, wmma::mem_row_major);
}
- tile_ptr_C
- C row-major 에서 현재 타일의 좌상단 포인터
- store_matrix_sync
- 레지스터에 있는 C 프레그먼트를 global 메모리로 저장
float - half 변환 커널
__global__
void to_half_kernel(const float* __restrict__ src,
half* __restrict__ dst,
int n)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
dst[idx] = __float2half(src[idx]);
}
}
- 1D grid + 1D block
- idx = 전체 배열에서 이 스레드가 담당하는 원소 인덱스