본문 바로가기

GPU-KERNEL

GEMM kernel 개선하기 (1)

블록이 C 에서 맡는 영역

const int TILE = GEMM_TILE; // 기본 32
const int TN   = GEMM_TN;   // 기본 4

dim3 block(TILE / TN, TILE); // (TILE/TN, TILE) = (8, 32)
dim3 grid((N + TILE - 1) / TILE,
          (M + TILE - 1) / TILE);
  • 한 블록이 C 의 TILE * TILE 서브 타일을 계산
  • 블록 인덱스 - C 타일 시작 위치 
int block_row = blockIdx.y * TILE;   // C의 row 타일 시작
int block_col = blockIdx.x * TILE;   // C의 col 타일 시작
  • 스레드 인덱스
int tx = threadIdx.x;              // 0 .. TILE/TN - 1   (예: 0..7)
int ty = threadIdx.y;              // 0 .. TILE-1        (예: 0..31)
int row  = block_row + ty;         // 이 스레드가 담당하는 행
int col0 = block_col + tx * TN;    // 이 스레드가 담당하는 첫 번째 열

한 스레드는 C 요소의 행 방향으로는 1개, 열 방향으로는 TN 개를 담당함 

그래서 이름이 regN, N 축으로 레지스터 타일링 

 

Shared Memory 타일 구조

__shared__ T As[TILE][TILE+1]; // [row][k_tile]
__shared__ T Bs[TILE][TILE+1]; // [k_tile][col]

TILE * (TILE+1) 로 +1 padding : bank conflict 회피 패턴

 

K 축 타일 루프 구조

int num_tiles = (K + TILE - 1) / TILE;

for (int t = 0; t < num_tiles; ++t) {
    int k0 = t * TILE;
    ...
}
  • K 축을 TILE 단위로 잘라서 여러 타일을 순차적으로 처리
  • 매 타일마다 
    • A/B 의 해당 K-슬라이스를 shared mem 에 로드
    • 그 타일에서의 matmul 부분 (accumulate)
    • 다음 타일로 이동 (K 축 누적)

 

As / Bs 로드 패턴 (Global - Shared)

공통 인덱스

for (int j = 0; j < TN; ++j) {
    int kk   = tx * TN + j;   // 0 .. TILE-1
    int k_gl = k0 + kk;
    ...
}
  • kk : 타일 내부 k 축 인덱스
  • k_gl : 글로벌 K 인덱스

 

A 타일 로드

// A tile
if (row < M && k_gl < K) {
    As[ty][kk] = A[row * K + k_gl];
} else {
    As[ty][kk] = T(0);
}
  • 행 인덱스 : row = block_row + ty
  • 열 인덱스 : k_gl = k0 + kk

즉, block 이 담당하는 C 행 타일에 대해, 각 행의 K 축 TILE 범위를 As 로 가져옴

각 thread 는

  • 같은 row 에서
  • 서로 다른 k_gl 를 TN 개씩 로드
  • 한 블록에서 A 타일 전체를 채움

메모리 관점

  • 같은 row 에서 k_gl 이 연속이기 때문에 A 에 대해서는 행 방향으로 꽤 coalesced 된 접근 패턴

 

B 타일 로드

// B tile: ty를 k index로 사용
int col = col0 + j;
int kB  = k0 + ty;
if (kB < K && col < N) {
    Bs[ty][kk] = B[kB * N + col];
} else {
    Bs[ty][kk] = T(0);
}

ty 를 K 축 인덱스로 사용 : kB = k0 + ty

col = col0 + j - 이 스레드가 담당하는 열 중 하나

그래서 Bs 는 논리적으로

row : k축

col : 타일 내부 열 인덱스 (kk)

 

계산 부분

타일당 K 축에서의 연산

// 레지스터 accumulator: 1 x TN
T acc[TN];
#pragma unroll
for (int j = 0; j < TN; ++j) acc[j] = T(0);
...
#pragma unroll
for (int kk = 0; kk < TILE; ++kk) {
    T a = As[ty][kk];
    #pragma unroll
    for (int j = 0; j < TN; ++j) {
        int n_idx = tx * TN + j;
        T b = Bs[kk][n_idx];
        acc[j] += a * b;
    }
}
  1. a = As[ty][kk]
    • 이 스레드의 row 에 해당하는 A 값 하나를 가져옴
    • 이 값은 TN 개의 다른 열에 대해 재사용
  2. b = Bs[kk][n_idx]
    • 같은 kk 에서
    • 이 스레드가 담당하는 열 인덱스에 해당하는 B 값
  3. acc[j] += a * b
    • acc 는 각 열에 대한 누적값
    • 이 스레드는
C[row, col0]     ... C[row, col0+TN-1]

위를 모드 레지스터에 들고 다니면서 K 축 전체에 대해 누적

A 는 shared -> register 로 한 번 읽어서

B 는 shared 에서 TN 개를 읽어서 acc 에 더함

 

스레드는 TN 개의 C 셀을 들고 있고 (reg tile)

각 C 셀은 K 축을 따라 누적되며 점점 완성된다.