본문 바로가기

GPU-KERNEL

Micro_gemm_4x4 - cuBLAS micro GEMM skeleton

  • 블록당 처리 C 타일 크기 : 64 x 64 (BLOCK_M, BLOCK_N)
  • 한 번에 처리하는 K 타일 크기 : 8 (BLOCK_K)
  • 한 스레드가 계산하는 마이크로 타일 4 x 4 (THREAD_TILE_M, THREAD_TILE_N)
  • 스레드 블록 : (16 x 16) = 256 threads
  • doublue-buffered shared memory : As[2], Bs[2]

 

각기 대응되는 부분

1. FMA 패턴 최적화 포인트

#pragma unroll
for (int i = 0; i < THREAD_TILE_M; ++i) {
    #pragma unroll
    for (int j = 0; j < THREAD_TILE_N; ++j) {
        acc[i][j] = fmaf(a_frag[i], b_frag[j], acc[i][j]);
    }
}
  • acc[4][4] 를 완전 언롤해서
  • 컴파일러가 FFMA 들을 골고루 스케줄하게 만들기 좋은 형태

여기서도 패턴을 더 꼬아서 

acc00, acc01, acc10, acc11 순서 섞는 식으로 더 하드코딩 가능

 

2. Unroll 포인트

  • BLOCK_K 방향 루프에 #pragma unroll
  • micro-tile 루프에도 #praga unroll
#pragma unroll
for (int kk = 0; kk < BLOCK_K; ++kk) { ... }

BLOCK_K 템플릿/constexpr 로 잡았기 때문에, 실제 SASS 에선 루프가 풀린FFMA 스트림으로 나와야 함

 

3. Pipeline (double buffering)

int buf = 0;
// 첫 타일 로드
load_AB_tile(k0, buf, (k0 < K));
__syncthreads();

for (; k0 < K; k0 += BLOCK_K) {
    int next_k0 = k0 + BLOCK_K;
    int next_buf = buf ^ 1;

    if (next_k0 < K) {
        load_AB_tile(next_k0, next_buf, true); // 다음 타일 로드
    }

    // 현재 buf 로 compute
    ...

    __syncthreads(); // 다음 buf 로딩 완료 보장
    buf = next_buf;
}
  • 현재 버퍼로 FMA 하는 동안,
  • 다음 버퍼에 다음 K 타일을 로드 - latnecy 숨기기

Ampere cp.async 넣고 싶으면 load_AB_tile 안쪽을 바꿔서 cp.async 로 바꾸면 됨

 

1. 타일 파라미터

constexpr int THREAD_TILE_M = 4; // 스레드당 마이크로 타일 M
constexpr int THREAD_TILE_N = 4; // 스레드당 마이크로 타일 N
  • 한 스레드가 책임지는 C 부분 크기, 4 x 4
  • 즉, 하나의 스레드가 16개의 output element 를 누적해서 만든다.
constexpr int THREADS_PER_BLOCK_M = BLOCK_M / THREAD_TILE_M; // 64 / 4 = 16
constexpr int THREADS_PER_BLOCK_N = BLOCK_N / THREAD_TILE_N; // 64 / 4 = 16
  • M 방향 16개 스레드,
  • N 방향 16개 스레드, 총 16 x 16, 256 threads per block

 

2. 커널 시그니처

__global__ void micro_gemm_4x4_kernel(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    int M, int N, int K,
    float alpha, float beta)
  • 표준 GEMM 형태 : C = alpha * A + beta * C
  • __restrict__ 
    • 포인터들이 서로 alias 안 한다는 힌트, 컴파일 최적화 도움
  • M, N, K - 행렬 크가 (M, K), (K, N), (M, N)

 

3. 블록 / 스레드가 담당하는 타일 좌표 계산

 

    int block_row = blockIdx.y * BLOCK_M;
    int block_col = blockIdx.x * BLOCK_N;
  • 현재 블록이 담당하는 C 타일의 왼쪽 위 global 좌표
  • blockIdx.y 방향으로 M, blockIdx.x 방향으로 N 진행
    int tid_row = threadIdx.y; // [0, THREADS_PER_BLOCK_M)
    int tid_col = threadIdx.x; // [0, THREADS_PER_BLOCK_N)
  • 블록 안에서 현재 스레드의 2D 인덱스
  • (y, x) = (0~15, 0~15)
    int thread_row = block_row + tid_row * THREAD_TILE_M;
    int thread_col = block_col + tid_col * THREAD_TILE_N;
  • 이 스레드가 담당하는 4x4 마이크로 타일의 왼쪽 위 global 좌표

 

4. 레지스터 accumulator 준비

float acc[THREAD_TILE_M][THREAD_TILE_N] = {0.0f}
  • 이 스레드가 계산할 4 x 4 구역의 누적 값
  • 전부 0 으로 초기화

 

5. shared memory double buffer 생성

extern __shared__ float shared_mem[];
  • 커널 런치 때 지정한 shared mem 전체를 shared_mem 로 받음
float *As = shared_mem;
float *Bs = As + 2 * BLOCK_M * BLOCK_K;
  • As : A 타일용 shared 메모리 시작 주소
    • 2 * BLOCK_M * BLOCK_K - double buffering ( buf = 1, 0 두 개 버퍼 )
  • Bs : 그 다음 공간부터 B 타일용 shared 메모리
int buf = 0;
  • 현재 사용 중인 버퍼 인덱ㄷ스
    • 한 타일 계산할 때마다 buf ^= 1 로 번갈아 사용

 

6. A/B 타일 로드 lambda

auto load_AB_tile = [&] (int k0, int buf_idx, bool valid) {
	if (!valid) return;
  • k0 : 이 타일이 시작하는 K 인덱스
  • buf_idx : 0 또는 1
  • valid : 경계 넘어가는지 체크용, 유효하지 않으면 바로 return 

 

A 타일 로드 부분

for (int i = tid_row; i < BLOCK_M; i += THREADS_PER_BLOVK_M) {
	for (int kk = tid_col; kk < BLOCK_K; kk += THREADS_PER_BLOCK_N) {
		int global_row = block_roww + i;
		int global_k = k0 + kk;
        
		float val = 0.0f;
		if (global_row < M & global_k < K) {
			val = A[global_row * K + globak_k];
		}
	As[buf_idx * BLOCK_M * BLOCK_K + i * BLOCK_K + kk] = val;
	}
}
  • BLOCK_M x BLOCK_K 크기의 A 타일을 shared 로 옮긴다.
  • 2중 for 에서
    • i : 행 방향
    • kk : K 방향
  • 현재 스레드가 (i, kk) 패턴으로 조금씩 나눠서 로드
    • i += THREAD_PER_BLOCK_M / kk += THREADS_PER_BLOCK_N 해서 모든 스레드가 골고루 영역 분담
  • out-of-bounds 시 0 채워넣기 - 경계 블록에서도 안전

 

B 타일 로드 동일

 

7. 첫 번째 타일 pre-load

int k0 = 0;
load_AB_tile(k0, buf, (k0 < K));
__syncthreads();
  • 처음 K=0 부터 시작하는 타일을 buf=0 에 로드
  • __syncthreads() 로 모든 스레드가 load 끝나길 기다림 - 이후 모두 shared 데이터 사용 가능

 

8. K 루프 (파이프라인 + FMA + 언롤 핵심)

for (; k0 < K; k0 += BLOCK_K) {
	int next_k0 + BLOCK_K;
	int next_buf = buf ^ 1;
  • K 방향으로 BLOCK_K 씩 증가하면서 슬라이스를 하나씩 처리
  • 다음 슬라이스 시작 인덱스 next_k0, 다음 버퍼 next_buf

 

다음 타일 prefectch

if (next_k0 < K){
	load_AB_tile(next_k0, next_buf, true);
}
  • 현재 타일로 계산하는 동시에, 다음 타일을 다른 버퍼에 로드
  • 나중에 cp.async 넣으면 이 부분에 들어감

 

현재 타일로 FMA 준비

float a_frag[THREAD_TILE_M];
float b_frag[THREAD_TILE_N];
  • 현재 스레드의 4 행, 4 열을 레지스터에 담을 버퍼

 

BLOCK_K 방향 완전 언롤 + FMA

 

#pragma unroll
for (int kk = 0; kk < BLOCK_K; ++kk) {
  • kk 는 현재 타일 내 K 인덱스 (0 ~ 7)
  • #pragma unroll -  BLOCK_K 가 상수라서 실제 코드에선 루프가 완전히 풀림
#pragma unroll
for (int i = 0; i < THREAD_TILE_M; ++i) {
    int row_in_block = tid_row * THREAD_TILE_M + i;
    a_frag[i] = As[buf * BLOCK_M * BLOCK_K + row_in_block * BLOCK_K + kk];
}
  • 이 스레드가 담당하는 4개 row에 대해
  • 현재 kk 위치의 A 값을 shared 에서 읽어와서 a_frag 에 넣는다.
#pragma unroll
for (int j = 0; j < THREAD_TILE_N; ++j) {
    int col_in_block = tid_col * THREAD_TILE_N + j;
    b_frag[j] = Bs[buf * BLOCK_K * BLOCK_N + kk * BLOCK_N + col_in_blocck];
}
  • B 도 마찬가지로, 담당하는 4개 col 의 kk 위치 값을 b_frag 에 저장

 

FMA 패턴 구간

#pragma unroll
for (int i = 0; i < THREAD_TILE_M; ++i) {
    #pragma unroll
    for (int j = 0; j < THREAD_TILE_N; ++j) {
        acc[i][j] = fmaf(a_frag[i], b_frag[j], acc[i][j]);
    }
}
  • 이제 한 번의 kk 에 대해
    • i, k : 0 ... 3 돌면서
    • acc[i][j] += a_frag_i * b_frag_j 수행
  • 이중 루프도 언롤
    • 컴파일러는 16개 FFMA 펼쳐서 스케줄
    • 이것이 FMA 패턴 최적화 + 언롤이 들어가는 부분

 

다음 버퍼 사용 준비 

        }

        __syncthreads();  // 다음 buf 로딩 끝났는지 보장
        buf = next_buf;
    }
  • 루프 한 바퀴 끝나면
    • next_buf 에 로드 해둔 다음 타일이 준비됐는지 __syncthreads() 로 보장
    • 이후부터는 buf 를 next_buf 로 전환해서 다음 슬라이스 계산

 

파이프라인 구조 요약

  1. 첫 타일 로드
  2. 루프 안
    1. buf 타일로 계산하는 동안
    2. next_buf 에 다음 타일 로드
  3. sync 후 buf - next_buf 교체

 

9. 최종 C 쓰기 (alpha, beta 포함)

    #pragma unroll
    for (int i = 0; i < THREAD_TILE_M; ++i) {
        int global_row = thread_row + i;
        if (global_row >= M) continue;
  • 4 개 row 에 대해
    • 행 전체가 범위 밖이면 skip

 

        #pragma unroll
        for (int j = 0; j < THREAD_TILE_N; ++j) {
            int global_col = thread_col + j;
            if (global_col >= N) continue;
  • 각 row 에서 4개의 col에 대해
    • 범위를 넘어가면 skip

 

            float c_val = acc[i][j] * alpha;

            if (beta != 0.0f) {
                c_val += beta * C[global_row * N + global_col];
            }

            C[global_row * N + global_col] = c_val;
  • 최종 GEMM 포맷
    • c_val = alpha * (acc) + beta * existing_C
  • beta == 0 인 경우에는 기존 C 읽지 않고 덧셈 스킵