본문 바로가기

GPU-KERNEL

Tensorcore_wmma_vs_fp32 test code

1. FP32 GEMM 커널 분석

__global__
void fp32_gemm_kernel(const float* __restrict__ A,
                      const float* __restrict__ B,
                      float* __restrict__ C,
                      int M, int N, int K, int iters)
  • _global_
    • GPU 에서 실행되는 커널 함수, host 에서 <<< >>> 로 호출
  • const float* __restrict__ A
    • A 는 읽기 전용 포인터, __restrict__ 는 이 포인터를 통해서만 접근한다는 힌트, 별칭 alias 없으니 최적화 수행
  • B, C 도 같은 개념
  • M, N, K, iters : 행렬 크기와 반복 횟수

 

    int row = blockIdx.y * blockDim.y + threadIdx.y; // 0..M-1
    int col = blockIdx.x * blockDim.x + threadIdx.x; // 0..N-1
  • 2D grid + 2D block 구성

 

    if (row >= M || col >= N) return;
  • grid 가 딱 맞게 안 나누어 떨어질 수 있으니, 범위 체크
  • 유효 범위 밖 스레드는 바로 리턴

 

    // 여러 번 반복해서 연산량 증가
    for (int it = 0; it < iters; ++it) {
        float acc = 0.0f;
  • 커널의 여러 번 실행 대신, 커널 내부에서 iters 번 반복해서 연산량 / 실행 시간을 늘림
  • acc = 누적 결과

 

        // 단순 i-k 루프
        for (int k = 0; k < K; ++k) {
            float a = A[row * K + k];
            float b = B[k * N + col];
            acc += a * b;
        }
  • 전형적인 행렬곱 인덱싱

 

        // 마지막 it 기준으로 덮어쓰기 (결과 자체는 크게 중요하지 않음)
        C[row * N + col] = acc;
    }
  • 각 it 마다 다시 누적, 매 반복마다 C 에 쓰고 있음, 얼마나 많은 연산, 실행 시간 비교 역할 수행

 

WMMA Tensor Core GEMM 커널 분석

__global__
void wmma_gemm_kernel(const half* __restrict__ A_half,
                      const half* __restrict__ B_half,
                      float* __restrict__ C,
                      int M, int N, int K, int iters)
{
#if __CUDA_ARCH__ >= 700
  • 컴파일 타겟 아키텍처가 7.0 이상인 경우에만 활성화
  • FP16 버전, A_half, B_half
  • C 는 FP32 accumulator

 

    // warp 단위 tile
    constexpr int WMMA_M = 16;
    constexpr int WMMA_N = 16;
    constexpr int WMMA_K = 16;
  • WMMA 기본 타일 크기 설정 : 16 x  16 x 16
  • 하나의 MMA 연산이 M x N 출력, K 축을 16 씩 누적

 

    // warp 하나가 16x16 tile 하나 담당
    int tile_row = blockIdx.y; // 0..(M/16-1)
    int tile_col = blockIdx.x; // 0..(N/16-1)
  • grid 의 blockIdx.y, blockIdx.x 각각이 하나의 16 x 16 타일으르 ㄷ마당
  • 블록 하나 = warp 하나 = 타일 하나

 

    // warp 내 lane id
    int lane_id = threadIdx.x % 32;

    // blockDim.x는 32로 가정 (warp 1개)
    if (tile_row * WMMA_M >= M || tile_col * WMMA_N >= N)
        return;
  • lane_id 는warp 내 스레드 인덱스, 현재 직접 사용중은 아님
  • blockDim.x == 32 가정 : 한 블록에 딱 한 warp
  • grid 크기와 M/N 이 항상 딱 맞더라도, 방어적 범위 체크

 

    // 반복 횟수만큼 GEMM 수행
    for (int it = 0; it < iters; ++it) {
        // C fragment 0으로 초기화
        wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K,
                       half, wmma::row_major> a_frag;
        wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K,
                       half, wmma::row_major> b_frag;
        wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K,
                       float> c_frag;

        wmma::fill_fragment(c_frag, 0.0f);
  • wmma::fragment:
    • Tensor Core 연산에 들어가는 조각, matrix tile 을 warp 의 레지스터에 담는 추상 타입
    • matrix_a / matrix_b : 입력 해열 조각
    • accumulator : C 타일 ( 출력 / 누적 )
  • row_major : 메모리 layout 정보
  • wmma::fill_fragment(c_frag, 0.0f);
    • C 레지스터타일 전체를 0 으로 초기화

 

        // K dimension을 16씩 잘라서 누적
        for (int kk = 0; kk < K; kk += WMMA_K) {
            const half* tile_ptr_A = A_half + (tile_row * WMMA_M) * K + kk;
            const half* tile_ptr_B = B_half + kk * N + (tile_col * WMMA_N);
  • 일반 GEMM 에서 C_tile = A_tile * B_tile 을 타일 기준으로 구현하는 모습
  • 현재 타일이 시작하는 행 번호
  • kk 는 K 축에서의 현재 16 슬라이스 offset
  • tile_ptr_A
    • A 를 row-major 로 볼 때,
    • 시작 행 = tile_row * 16
    • 시작 열 = kk
  • tile_ptr_B
    • 시작 행 = kk,
    • 시작 열 = tile_col * 16

 

            // row-major / row-major로 load
            wmma::load_matrix_sync(a_frag, tile_ptr_A, K);
            wmma::load_matrix_sync(b_frag, tile_ptr_B, N);
  • 글로벌 메모리에서 half 행렬 조각을 warp 레지스터에 로드
  • 세 번째 인자는 leading dimension, row-major 에 서 한 행의 stride

 

            // C += A * B
            wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        }
  • 핵심 Tensor Core MMA 연산
  • c_frag = a_frag * b_frag + c_frag 을 한 번에 수행
  • 내부적으로 FP16 입력을 FP32 accumulate 로 처리

 

        // C에 store
        float* tile_ptr_C = C + (tile_row * WMMA_M) * N + (tile_col * WMMA_N);
        wmma::store_matrix_sync(tile_ptr_C, c_frag, N, wmma::mem_row_major);
    }
  • tile_ptr_C
    • C row-major 에서 현재 타일의 좌상단 포인터
  • store_matrix_sync
    • 레지스터에 있는 C 프레그먼트를 global 메모리로 저장

 

 

float - half 변환 커널

__global__
void to_half_kernel(const float* __restrict__ src,
                    half* __restrict__ dst,
                    int n)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        dst[idx] = __float2half(src[idx]);
    }
}
  • 1D grid + 1D block
  • idx = 전체 배열에서 이 스레드가 담당하는 원소 인덱스