본문 바로가기

AI Compiler framework

naive GEMM 을 TC GEMM 으로 수정하자 ( NN, TN, NT 지원 - T 는 transpose 야 )

목표

학습용 backward 확장 가능한 TC GEMM 뼈대 - Ampere (sm_86) 기준

 

0) 최소 TC GEMM 의 스펙

  • 입력 : half(A), half(B)
  • 누적 : float(accum)
  • 출력 : float, half
  • epilogue : 일단 나중에
  • 지원 : NN / NT / TN ( 이것이 backward gemm 지원 )
  •  block : warp 단위로 시작

성능이 아닌 구조 / 확장성이 목적

 

1) 가장 중요한 설계 : transpose 는 SMEM 에서 만든다

전역 메모리 원본은 그대로 두고

  • 전역에서 읽을 때는 가능한 coalesced 유지
  • SMEM 에 저장할 때 레이아웃을 바꿔서 transpose store
  • compute 는 항상 WMMA 가 좋아하는 타일 레이아웃을 보게 만든다

 

2) WMMA 가 요구하는 레이아웃 ( Ampere, row_major / col_major )

WMMA 는 fragment 를 만들 때 이러한 조합을 쓴다

  • A fragment : matrix_a
  • B fragment : matrix_b

핵심은 global 이 row-major 든 뭐든 상관없이, SMEM 에 어떤 레이아웃으로 담을지를 결정

가장 단순한 선택

  • A tile 을 row_major 로 SMEM 에 저장
  • B tile 을 col_major 로 SMEM 에 저장

 

3) NN / TN / NT 에서 SMEM 에 어떻게 담나 ( 핵심 )

공통 목표

  • WMMA compute 는 동일하게
    • load_matrix_sync(a_frag, smemA_ptr, ldA)
    • load_matrix_sync(b_frag, smemB_ptr, ldB)
    • mma_sync(acc, a_frag, b_frag, acc)

case NN

  • smem A : A 를 그대로 row_major 로 
  • smem B : B 를 col_major 로 담기
    • global B[k, n] 를 읽어서 smemB[n, k] 형태로 저장 ( 실질적 transpose store )

case TN

  • 수식상 A^T 이 필요 - global A 는 A[m, k]로 그대로 있음
  • smemA : global A 에서 읽은 것을 transpose store 해서 row_major 타일로 만들기
    • 필요한 것은 A^T 의 K M 타일
    • 즉 global A [m, k] 를 읽어 smemA[k, m] 형태로 저장
  • smemB : NN 과 동일하게 B 를 col_major 로 저장

case NT

  • smemB : B^T 의 사용, global B 를 row_major 그대로 담아도 되지만 WMMA 가 기대하는 형태로 맞추는 것이 최선?
    • B^T 가 (N x K)
    • WMMA 에서 b_frag 이 col_major 를 받는다고 하면
    • smemB 는 결국 K x N col_major 가 필요
    • global B 를 읽을 때 인덱싱을 바꿔서 B^T 타일을 col_major 로 담으면 된다. 

 

5) 1-warp / block WMMA GEMM 뼈대

처음엔 block 이 1 warp 가 16 x 16 C 타일 하나를 만들게 함

  • grid : ceil ( M / 16 ) x ceil ( N / 16 )
  • 각 block 은 (tile_m, tile_n) 을 담당
  • K 루프(16)를 돌며
    • A/B 타일을 SMEM 에 채움
    • WMMA load
    • mma_sync accumulate

마지막에 acc 를 global C 에 store

 

더보기

global A, B 를 row_major / col_major 로 담는다는 표현은 물리적 메모리 레이아웃이 아닌, 인덱싱 규약과 논리적 행렬 의미 기준의 말

 

Global memory 에서의 row_major / col_major

CUDA 에서 global memory 는 그냥 1차원 주소 공간, row_major, col_major 는 하드웨어 개념이 아님

// row_major
A[i][j] -> A[i * lda + j]

// col_major
A[i][j] -> A[j * lda + i]
  • 같은 메모리
  • 다른 인덱싱 해석

 

Shared memory 로 들어오면

row, col_major 가 아닌 bank mapping 이 전부...

여기서 말하는 row, col_major 로 SMEM 에 저장, 에 대한 의미는 단순 인덱싱이라고 생각해도 됨