1. 커널 큰 그림
1) 런치 구성
dim3 block(WARP_SIZE, 1, 1);
dim3 grid(num_batches, 1, 1);
- 블록당 1 warp 만 사용
- grid.x = num_batches 이라서, batch 하나가 곧 blockIdx = 1 과 1:1 매핑
- 각 block 은 16 16 tile 하나에 대해 QK - softmax - PV 전체를 수행
2) 데이터 레이아웃
- Q
- K
- V
- O
한 batch 에서 하는 일
- S = QK
- WMMA 수행, 결과 저장
- P = softmax ( S * scale(
- s_scores -> s_probs
- 위 부분이 보고 싶은 softmax micro kernel
- O = P * V
- softmax 결과를 half 로 캐스팅해서 s_probs_h 에 저장
- 다시 WMMA 수행, O 에 저장
즉, Tensor Core 두 번 사이에 끼어 있는 warp softmax scalar reduction 이 병목 후보
2. 간단한 WMMA 부분, (QK, PV)
wmma::fragment<matrix_a, 16,16,16, __half, row_major> a_frag;
wmma::fragment<matrix_b, 16,16,16, __half, row_major> b_frag;
wmma::fragment<accumulator, 16,16,16, float> c_frag;
wmma::fill_fragment(c_frag, 0.0f);
wmma::load_matrix_sync(a_frag, Q_b, Kdim); // ld = 16
wmma::load_matrix_sync(b_frag, K_b, N); // ld = 16
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
wmma::store_matrix_sync(s_scores, c_frag, N, mem_row_major);
///////////////////////////////////////////////////////////////
wmma::load_matrix_sync(a_frag, s_probs_h, N); // P
wmma::load_matrix_sync(b_frag, V_b, Dv); // V
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
wmma::store_matrix_sync(O_b, c_frag, Dv, mem_row_major);
3. softmax warp 마이크로 커널 부분
__device__ void softmax_warp_16x16(float* s_scores, float* s_probs, float scale) {
int lane = threadIdx.x & (WARP_SIZE - 1);
for (int row = 0; row < M; ++row) {
float x = -INFINITY;
if (lane < N) {
float v = s_scores[row * N + lane];
x = v * scale;
}
// row-wise max
float maxv = warp_allreduce_max(x);
// exp + sum
float ex = 0.0f;
if (lane < N) {
ex = __expf(x - maxv);
}
float sumv = warp_allreduce_sum(ex);
// normalize
if (lane < N) {
float p = ex / (sumv + 1e-6f);
s_probs[row * N + lane] = p;
}
}
}
(1) 스레드 역할 매핑
- block = warp 1 개 = 32 threads
- lane = threadIdx.x % 32
- N = 16 이라서, 0~15 lane 만 실제 열 1개 담당
- lane = col index
- 각 row 마다 16 개 열을 16 개의 lane 이 동시에 처리
- 16~31 lane 은 softmax 에서 계산에 직접 쓰이지 않는 dummy
(2) 메모리 접근 패턴 ( 공유 메모리 관점 )
-
s_scores[ row * N + lane ], lane = 0...15 - s_score / s_probs 는 row-major 16, 16 float 배열
- warp 의 절반이 연속된 16 개의 float 를 동시에 읽음
- shared memory bank 관점에서
- float = 4 bytes
- bank 는 32 개, address / 4 mod 32 rlwns
- 16 개의 연속된 address 를 16 thread 가 읽으면, 서로 다른 bank 16 개에 매핑, bank conflict 없음
- shared memory bank 관점에서
- write 도 동일 패턴
위 구조 덕에 softmax 마이크로 커널은 모든 softmax 연산을 shared memory 에서, bank conflict 없이 수행
(3) row-wise max 단계
한 row 에 대해
- lane < 16 인 thread 만 x = s_scores[row * N + lane] * scale;
- warp_allreduce_max(x) 로 warp 전체에서 max 계산
- 5 번의 __shfl_xor_sync + fmaxf
- 마지막에 모든 lane 이 같은 maxv 값을 가지게 됨
의미상
maxv = max_j (S[row, j] * scale) 을 warp 1개가 책임지고 계산
(4) exp + sum 단계
- lane < 16 만 exp 계산
- ex = __expf(x - maxv);
- __expf 사용 - 빠른 근사 exp
- x - maxv 로 수치 안정성 확보
- warp_allreduce_sum(ex) 로 전체 합
- warp-wide all-to-all reduction 으로 sumv 가 모든 lane 에 복제
(5) normalize 단계
float p = ex / (sumv + 1e-6f);
s_probs[row * N + lane] = p;
- 최종적으로 s_probs 에 row-wise softmax 결과 저장
(6) 루프 구조
for (int row = 0; row < M; ++row)
- warp 하나가 row 16 개를 순차적 처리
- 한 row 당
- shared load 1
- max reduction 1
- exp + sum reduction 1
- shared store 1
합치면 softmax 전체는
- shared read : 256 elements
- shared write : 256 elements
- warp reduction (5 단계 * 2번 * 16 rows = 160 shuffle + 160 fmax/fadd 수준)
- expf 256 개
5. 이 softmax 마이크로커널의 특징/장단점
- 완전히 warp-local
- block 내 sync 는 row 루프 바깥에서만 필요, softmax 내부는 warp 단독으로 동작
- warp 내에서 __shfl_* 만 쓰니 latency 도 비교적 낮음
- shared memory 만 사용, global memory 접근 없음
- QK, PV 와 완전히 decoupled
- QK 결과를 shared 에 고정해두고,
- softmax
- 다시 shared - WMMA
- 디버깅/검증이 쉬운 형태
- QK, PV 와 완전히 decoupled
- N = 16 의 구조
- bank conflict free
단점
- Tensor Core vs scaler
- QK, PV 에서는 빠른 시간 연산
- softmax 의 경우 적은 연산이지만
- Tensor Core throughput 에 비교
- 상대적으로 scalar reduction + exp 가 느려보이는 구조
- 절반의 warp 만 사용
- row loop 순차 처리
- shared memory 왕복
6. 개선 아이디어
- online softmax 로 통합
- warp 2개 사용, row 병렬화
- 레지스터 활용 극대화
'GPU-KERNEL' 카테고리의 다른 글
| 여기서 다시 한 번 GPU 실행 단위 정리하기 (0) | 2025.12.11 |
|---|---|
| 각 warp 는 다른 일을 담당할 수 있다... (0) | 2025.12.10 |
| 이상적인 Epilogue Kernel 구현 - FlashAttention 을 보며... (1) | 2025.12.09 |
| Nsight Compute / ncu 권한 문제 해결 가이드 (Windows + GeForce) (0) | 2025.12.04 |
| Tensor Core 기반 GEMM ( 32, 32, 32 ) - Tile / MMA 구조 정리 문서 (0) | 2025.12.01 |