본문 바로가기

GPU-KERNEL

타일링 방식을 통한 GEMM 의 실제 연산 방식 이해

kernel hierarchy 에 이어서 GEMM 의 실제 연산 방식 + 타일링 알고리즘에 대해 알아야함

누산, acc 가 어떻게 이뤄지는지에 대해 궁금해서...

 

GEMM 의 전체 구조를 3개의 큰 그림으로 나누어 잡기 

  1. GEMM 수식 자체가 무엇을 의미하는지
  2. GEMM 을 GPU 에서 효율적으로 돌리기 위해 왜 타일링을 하는지
  3. 최적화된 타일링 GEMM 에서 계층 구조가 어떻게 설계되는가

 

1) GEMM = "K축 누산" 구조

최적화 GEMM 을 이해하기 위해

C[i,j] = Σ_k (A[i,k] * B[k,j])

즉,

  • A 의 행 (row i)
  • B 의 열 (col j)

이 둘을 k 방향으로 계속 곱하고 더하는 누산(accumulate)

GEMM 은 K 를 돌면서 누산하는 구조가 본질

중요한 것은

  • C[i, j] 하나는 스레드 하나가 단독으로 계산함 (다른 스레드와 partial sum 공유 없음)
  • K 가 크면
    • K 를 chunk(BK) 단위로 잘라서 반복
    • A_tile(BM * BK), B_tile(BK * BN) 를 shared memory 에 올림
    • tiles 끼리 곱해서 자기 C_tile 에 누산

즉, GEMM = K 를 타일 단위로 쪼개서 반복적으로 자기 C 타일에 누산하는 것

 

2) GPU 에서 타일링이 필요한 이유

Global memory bandwidth 는 매우 느리고 shared memory 는 매우 빠르다.

따라서

매번 A[i,k], B[k,j]를 global memory에서 읽으면 → 병목 100%

이를 해결하기 위해 재사용 가능한 범위를 잡아 tile 단위로 묶는다.

 

예: block tile ( 128 * 128 )

block 전체가 C 의 128 * 128 영역을 계산한다. 

이 block 은 

  • A 의 128 * BK 타일
  • B 의 BK * 128 타일

을 shared memory 에 올리고, warp/thread 는  이 shared tile 을 반복해 읽어서 연산한다. 

타일링의 목적 = global memory 접근 최대한 줄이고 shared memory 재사용 극대화

 

3) 최적화 GEMM 의 진짜 구조 : 3 단계 타일링 계층

(A) Block tile ( BM x BN )

  • 예 : 128 * 128
  • block 하나가 C 의 알짜배기 128 * 128 조각을 책임짐
  • block 내부 warp 들이 협동해서 이걸 완성함
  • shared memory 에 A_tile, B_tile 을 올리는 단위

(B) Warp tile ( WM x WN )

  • 예 : 64 * 64
  • 공유 메모리에 올라온 데이터 중 warp 하나가 담당할 중간 크기의 C tile
  • warp 레벨의 연산 = Tensor Core / mma.sync 단위

(C) Thread tile ( TM x TN )

  • 예 : 2 x 2, 4 x 4, 4 x 8
  • warp tile 안에서 각 thread 가 담당하는 가장 작은 타일
  • thread 는 이 조각을 레지스터에 들고 K 루프 동안 누산한다.
BLOCK TILE 128x128
 ├── warp0 tile 64x64
 │     ├── thread0 tile 4x4
 │     ├── thread1 tile 4x4
 │     └── ...
 ├── warp1 tile 64x64
 └── ...

 

누산 방식에 대해 이해는 했음, 이걸 이제 GPU 의 각 계층에 대입해서 이해해보자