목표
학습용 backward 확장 가능한 TC GEMM 뼈대 - Ampere (sm_86) 기준
0) 최소 TC GEMM 의 스펙
- 입력 : half(A), half(B)
- 누적 : float(accum)
- 출력 : float, half
- epilogue : 일단 나중에
- 지원 : NN / NT / TN ( 이것이 backward gemm 지원 )
- block : warp 단위로 시작
성능이 아닌 구조 / 확장성이 목적
1) 가장 중요한 설계 : transpose 는 SMEM 에서 만든다
전역 메모리 원본은 그대로 두고
- 전역에서 읽을 때는 가능한 coalesced 유지
- SMEM 에 저장할 때 레이아웃을 바꿔서 transpose store
- compute 는 항상 WMMA 가 좋아하는 타일 레이아웃을 보게 만든다
2) WMMA 가 요구하는 레이아웃 ( Ampere, row_major / col_major )
WMMA 는 fragment 를 만들 때 이러한 조합을 쓴다
- A fragment : matrix_a
- B fragment : matrix_b
핵심은 global 이 row-major 든 뭐든 상관없이, SMEM 에 어떤 레이아웃으로 담을지를 결정
가장 단순한 선택
- A tile 을 row_major 로 SMEM 에 저장
- B tile 을 col_major 로 SMEM 에 저장
3) NN / TN / NT 에서 SMEM 에 어떻게 담나 ( 핵심 )
공통 목표
- WMMA compute 는 동일하게
- load_matrix_sync(a_frag, smemA_ptr, ldA)
- load_matrix_sync(b_frag, smemB_ptr, ldB)
- mma_sync(acc, a_frag, b_frag, acc)
case NN
- smem A : A 를 그대로 row_major 로
- smem B : B 를 col_major 로 담기
- global B[k, n] 를 읽어서 smemB[n, k] 형태로 저장 ( 실질적 transpose store )
case TN
- 수식상 A^T 이 필요 - global A 는 A[m, k]로 그대로 있음
- smemA : global A 에서 읽은 것을 transpose store 해서 row_major 타일로 만들기
- 필요한 것은 A^T 의 K M 타일
- 즉 global A [m, k] 를 읽어 smemA[k, m] 형태로 저장
- smemB : NN 과 동일하게 B 를 col_major 로 저장
case NT
- smemB : B^T 의 사용, global B 를 row_major 그대로 담아도 되지만 WMMA 가 기대하는 형태로 맞추는 것이 최선?
- B^T 가 (N x K)
- WMMA 에서 b_frag 이 col_major 를 받는다고 하면
- smemB 는 결국 K x N col_major 가 필요
- global B 를 읽을 때 인덱싱을 바꿔서 B^T 타일을 col_major 로 담으면 된다.
5) 1-warp / block WMMA GEMM 뼈대
처음엔 block 이 1 warp 가 16 x 16 C 타일 하나를 만들게 함
- grid : ceil ( M / 16 ) x ceil ( N / 16 )
- 각 block 은 (tile_m, tile_n) 을 담당
- K 루프(16)를 돌며
- A/B 타일을 SMEM 에 채움
- WMMA load
- mma_sync accumulate
마지막에 acc 를 global C 에 store
더보기
global A, B 를 row_major / col_major 로 담는다는 표현은 물리적 메모리 레이아웃이 아닌, 인덱싱 규약과 논리적 행렬 의미 기준의 말
Global memory 에서의 row_major / col_major
CUDA 에서 global memory 는 그냥 1차원 주소 공간, row_major, col_major 는 하드웨어 개념이 아님
// row_major
A[i][j] -> A[i * lda + j]
// col_major
A[i][j] -> A[j * lda + i]
- 같은 메모리
- 다른 인덱싱 해석
Shared memory 로 들어오면
row, col_major 가 아닌 bank mapping 이 전부...
여기서 말하는 row, col_major 로 SMEM 에 저장, 에 대한 의미는 단순 인덱싱이라고 생각해도 됨