본문 바로가기

GPU-KERNEL

Softmax Micro kernel Design

1. 커널 큰 그림

1) 런치 구성

dim3 block(WARP_SIZE, 1, 1);
dim3 grid(num_batches, 1, 1);
  • 블록당 1 warp 만 사용
  • grid.x = num_batches 이라서, batch 하나가 곧 blockIdx = 1 과 1:1 매핑
  • 각 block 은 16 16 tile 하나에 대해 QK - softmax - PV 전체를 수행

 

2) 데이터 레이아웃

  • K
  • V
  • O

한 batch 에서 하는 일

  1.  S = QK 
    1. WMMA 수행, 결과 저장
  2. P = softmax ( S * scale(
    1. s_scores -> s_probs
    2. 위 부분이 보고 싶은  softmax micro kernel
  3. O = P * V
    1. softmax 결과를 half 로 캐스팅해서 s_probs_h 에 저장
    2. 다시 WMMA 수행, O 에 저장

즉, Tensor Core 두 번 사이에 끼어 있는 warp softmax scalar reduction 이 병목 후보

 

2. 간단한 WMMA 부분, (QK, PV)

wmma::fragment<matrix_a, 16,16,16, __half, row_major> a_frag;
wmma::fragment<matrix_b, 16,16,16, __half, row_major> b_frag;
wmma::fragment<accumulator, 16,16,16, float>          c_frag;

wmma::fill_fragment(c_frag, 0.0f);
wmma::load_matrix_sync(a_frag, Q_b, Kdim); // ld = 16
wmma::load_matrix_sync(b_frag, K_b, N);    // ld = 16

wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
wmma::store_matrix_sync(s_scores, c_frag, N, mem_row_major);


///////////////////////////////////////////////////////////////

wmma::load_matrix_sync(a_frag, s_probs_h, N); // P
wmma::load_matrix_sync(b_frag, V_b, Dv);      // V
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
wmma::store_matrix_sync(O_b, c_frag, Dv, mem_row_major);

 

 

3. softmax warp 마이크로 커널 부분

__device__ void softmax_warp_16x16(float* s_scores, float* s_probs, float scale) {
    int lane = threadIdx.x & (WARP_SIZE - 1);

    for (int row = 0; row < M; ++row) {
        float x = -INFINITY;

        if (lane < N) {
            float v = s_scores[row * N + lane];
            x = v * scale;
        }

        // row-wise max
        float maxv = warp_allreduce_max(x);

        // exp + sum
        float ex = 0.0f;
        if (lane < N) {
            ex = __expf(x - maxv);
        }
        float sumv = warp_allreduce_sum(ex);

        // normalize
        if (lane < N) {
            float p = ex / (sumv + 1e-6f);
            s_probs[row * N + lane] = p;
        }
    }
}

(1) 스레드 역할 매핑

  • block = warp 1 개 = 32 threads
  • lane = threadIdx.x % 32
  • N = 16 이라서, 0~15 lane 만 실제 열 1개 담당
    • lane = col index
    • 각 row 마다 16 개 열을 16 개의 lane 이 동시에 처리
  • 16~31 lane 은 softmax 에서 계산에 직접 쓰이지 않는 dummy

 

(2) 메모리 접근 패턴 ( 공유 메모리 관점 )

  • s_scores[ row * N + lane ], lane = 0...15
  • s_score / s_probs 는 row-major 16, 16 float 배열
  • warp 의 절반이 연속된 16 개의 float 를 동시에 읽음 
    • shared memory bank 관점에서
      • float = 4 bytes
      • bank 는 32 개, address / 4 mod 32 rlwns
      • 16 개의 연속된 address 를 16 thread 가 읽으면, 서로 다른 bank 16 개에 매핑, bank conflict 없음
  • write 도 동일 패턴

위 구조 덕에 softmax 마이크로 커널은 모든 softmax 연산을 shared memory 에서, bank conflict 없이 수행

 

(3) row-wise max 단계

한 row 에 대해

  1. lane < 16 인 thread 만 x = s_scores[row * N + lane] * scale;
  2. warp_allreduce_max(x) 로 warp 전체에서 max 계산
  • 5 번의 __shfl_xor_sync + fmaxf
  • 마지막에 모든 lane 이 같은 maxv 값을 가지게 됨

의미상

maxv  = max_j (S[row, j] * scale) 을 warp 1개가 책임지고 계산

 

(4) exp + sum 단계

  • lane < 16 만 exp 계산
    • ex = __expf(x - maxv);
    • __expf 사용 - 빠른 근사 exp
    • x - maxv 로 수치 안정성 확보
  • warp_allreduce_sum(ex) 로 전체 합
    • warp-wide all-to-all reduction 으로 sumv 가 모든 lane 에 복제

 

(5) normalize 단계

float p = ex / (sumv + 1e-6f);
s_probs[row * N + lane] = p;
  • 최종적으로 s_probs 에 row-wise softmax 결과 저장

 

(6) 루프 구조

for (int row = 0; row < M; ++row)
  • warp 하나가 row 16 개를 순차적 처리
  • 한 row 당 
    • shared load 1
    • max reduction 1
    • exp + sum reduction 1
    • shared store 1

합치면 softmax 전체는 

  • shared read : 256 elements
  • shared write : 256 elements
  • warp reduction (5 단계 * 2번 * 16 rows = 160 shuffle + 160 fmax/fadd 수준)
  • expf 256 개

 

5. 이 softmax 마이크로커널의 특징/장단점

  • 완전히 warp-local
    • block 내 sync 는 row 루프 바깥에서만 필요, softmax 내부는 warp 단독으로 동작
    • warp 내에서 __shfl_* 만 쓰니 latency 도 비교적 낮음
  • shared memory 만 사용, global memory 접근 없음
    • QK, PV 와 완전히 decoupled
      • QK 결과를 shared 에 고정해두고,
      • softmax
      • 다시 shared - WMMA
    • 디버깅/검증이 쉬운 형태
  • N = 16 의 구조 
  • bank conflict free

단점

  • Tensor Core vs scaler
    • QK, PV 에서는 빠른 시간 연산
    • softmax 의 경우 적은 연산이지만 
      • Tensor Core throughput 에 비교
      • 상대적으로 scalar reduction + exp 가 느려보이는 구조
  • 절반의 warp 만 사용
  • row loop 순차 처리
  • shared memory 왕복

 

6. 개선 아이디어

  1. online softmax 로 통합
  2. warp 2개 사용, row 병렬화
  3. 레지스터 활용 극대화