본문 바로가기

GPU-KERNEL

누산 개념을 Grid-Block-Warp-Thread 계층에 끼워넣기

커널을 실행할 때는 단일 커널 함수로 전체 C 를 완성시켜야 하는지?

1. 기본 mental model : 한 개의 GEMM 커널로 C 전체를 완성한다.

일반적인 고성능 GEMM 은 이렇게 생각,

  • 하나의 커널 함수
  • 하나의 kernel launch 로 
  • C[M x N] 전체를 블록 타일로 나눠서 전부 계산

즉, 

  • Grid 의 각 Block -> C 의 한 block tile (예: 128 x 128) 담당
  • Block 안의 warps -> 그 타일의 warp tile 들 (예: 64 x 64) 담당
  • Warp 안 Threads -> thread tile (예: 4 x 4, 2 x 4) 담당
  • BK 루프에서 모두 자기 tile 에 누산해서 그 타일을 완성

단일 커널로 전체 C 완성이라는 이해는 기본적인 모델로 맞다고 보면 된다.

 

2. 그 누산을 계층에 정확히 대입하면 이렇게 된다.

가정:

  • Block tile : BM = 128, BN = 128
  • Warp tile : WM = 64, WN = 64
  • Thread tile : TM = 4, TN = 4
  • K tile : BK = 32

(1) Grid / Block 레벨

// 각 block이 C의 (BM x BN) 타일 하나 담당
int block_row = blockIdx.y;  // 0..ceil(M/BM)-1
int block_col = blockIdx.x;  // 0..ceil(N/BN)-1

int C_row0 = block_row * BM;
int C_col0 = block_col * BN;
  • 이 Block 하나가 C[C_row0 : C_row0 + 128, C_col0 : C_col0+128] 전체를 담당.
  • 이 128 * 128 영역이 BK 타일 누산으로 완성되는 C 의 부분(블록 타일)

(2) Block 안 Warp 레벨

int warp_id = (threadIdx.y * blockDim.x + threadIdx.x) / warpSize;
// warp_id 0..(warpsPerBlock-1)

int warp_row = warp_id / warps_per_row;
int warp_col = warp_id % warps_per_row;

int C_warp_row0 = C_row0 + warp_row * WM;
int C_warp_col0 = C_col0 + warp_col * WN;
  • 하나의 warp 가 block tile 안에서 64 * 64 warp tile 담당
  • 이 warp tile 도 BK 루프 전체 동안 계속 누산해서 warp tile 을 완성.

(3) Warp 안 Thread 레벨

int lane = /* 0..31 */;
int thread_m = (lane / something); // TM tile 배치
int thread_n = (lane % something);

int C_thread_row0 = C_warp_row0 + thread_m * TM;
int C_thread_col0 = C_warp_col0 + thread_n * TN;

// 이 스레드는 C의 TM x TN 조각 담당 (예: 4x4)
float acc[TM][TN] = {0};  // 레지스터에 존재
  • 이 스레드는 C 의 작은 조각 thread tile ( 4 * 4 ) 을 담당
  • 이 acc[TM][TN] 이 K/BK 반복동안 계속 누산되는 대상

(4) BK 루프 - 모든 계층에서 동시에 진행되는 K 방향 누산

for (int k0 = 0; k0 < K; k0 += BK) {
    // 1) 이 block이 담당하는 A,B의 BK 타일을 shared에 올림
    // A_tile: [BM x BK], B_tile: [BK x BN]

    __syncthreads();
    // A_tile, B_tile load 완료 보장

    // 2) warp, thread가 A_tile/B_tile에서 자기 필요한 조각만 읽어서
    //    acc[TM][TN]에 누산
    for (int kk = 0; kk < BK; ++kk) {
        float a_frag[TM];  // 이 스레드가 담당하는 A의 열 방향 조각
        float b_frag[TN];  // 이 스레드가 담당하는 B의 행 방향 조각

        // a_frag, b_frag 로딩 후
        // 작은 매트릭스 곱셈 형태로 업데이트
        for (int i = 0; i < TM; ++i)
        for (int j = 0; j < TN; ++j)
            acc[i][j] += a_frag[i] * b_frag[j];
    }

    __syncthreads(); // 다음 BK 타일 로딩 준비
}

중요한 포인트

  • Block 레벨
    • BK 루프를 다 돌면 128 * 128 C_block tile 이 완성됨 ( warp/thread 들이 각자 맡은 tile 을 자기 레지스터에서 완성해 놓은 상태)
  • Warp 레벨
    • warp 단위로 보면 자기 64 * 64 duddurdl BK 루프 끝에서 완성됨
  • Thread 레벨
    • 각 thread 는 자기 acc[TM][TN] 가 BK 루프 끝에서 완성됨, 이걸 global C 에 store
C[(C_thread_row0 + i) * ldc + (C_thread_col0 + j)] = acc[i][j];

이 전 과정이 하나의 커널 함수 안에서 일어난다.