핵심 아이디어는 전역 메모리의 원본 레이아웃의 변경 없이, SMEM 에 필요한 방향으로 재배치해서 타일의 생성
transpose 는 저장소 변형이 아닌, 타일 로딩 / 배치의 한 부분
1) 전제 : GEMM 의 계산 요구 형태
각 출력 타일, C[m, n] 을 만드려면
- A 에서 한 타일
- B 에서 한 타일
두 타일이 필요, 대부분의 커널은 A / B 타일을 SMEM 에 올려녾고, 그 타일을 반복 reuse 하면서 FMA 를 돌린다.
2) forward : C = A(M K) B(K N)
전역 메모리 로드 패턴
- A : row-major 에서 A[m, k] 는 k 가 연속
- B : row-major 에서 B[k, n] 는 n 이 연속
SMEM 배치
- smemA 는 M K 형태로 행 단위 연속이 좋음
- smemB 는 K N 형태로 열 (혹은 n 축) 연속이 좋음
3) backward1 : dW = X^T dY ( A^T B )
수식상 표현은 A^T 지만, 전역 메모리의 X 는 여전히 row-major 로 저장되어 있음
목표 타일 형태
- X^T 에서 필요한 타일은 K M 처럼 보이지만,
- 실제로는 dW 타일이 K N 이니까
- A 쪽은 K 차원 기준으로 묶인 타일
- B 쪽은 dY 의 K 차원 묶인 타일
중요한 포인트
X 를 전역에서 진짜 transpose 해서 만들지 않고 커널은 이와 같이 한다.
A) 전역에서 읽을 때 인덱싱만 바꿔서 SMEM 에 K-major로 쌓기
- 전역: X[m, k] 를 읽되,
- SMEM 에 쓸 때 : smemA[k, m] 형태로 저장 ( SMEM 에서 transpose )
이러면
- 원본 X 는 그대로
- SMEM 에서만 X 타일이 준비됨
- 이후 compute 는 동일한 방식
이것이 SMEM 올릴 때 알고리즘 변경의 형태
B) compute 단계에서 transpose 처럼 접근
- SMEM 저장은 그대로 두고
- compute 에서 smemA[k, m] 로 읽기
하지만 이건
- warp 가 A 타일을 읽는 패턴이 깨지기 쉬움
- bank conflict 의 증가 위험
- vectorized load 의 이점 감소
실전에선, A 의 SMEM 에 올릴 때 형태를 맞춰두는 쪽이 더 흔함
4) backward2 : dX = dY W^T ( A B^T )
이 경우는 B 가 transpose, 동일하게 처리
5) 원본 변형 없이 실행의 의미
- 전역 메모리는 재배치하지 않음
- 전역에서 coalesced load 가 가능한 방식으로 데이터 획득
- SMEM 에 저장할 때 인덱스를 바꿔서 타일을 원하는 레이아웃으로 만듦
- compute 단계는 해당 layout 에 맞춘 고정된 루프 / WMMA / SIMIT 구조
6) 구현 방식
(1) GEMM 커널을 하나로 유지하려면
- transA / transB 를 attrs 로 받고
- load-to-smem 단계만 분기하는 방식
(2) dispatch 관점
- TT / TN / NT / NN 을 각기 다른 KernelVariant 로 두되
- 공통 코드를 최대한 공유
7) 고려해야 할 점
- 전역 로드는coalesced 인데 SMEM store 가 scatter 가 될 수 있음
- 로드는 coalesced + SMEM 에는 transpose store 를 하더라도 store 는 SMEM 이니까 비용이 덜하고, 대신 bank conflict 만 고려
- bank conflict
- transpose store 시 padding 이 필요할 때가 많음