본문 바로가기

dev_AI_framework

Scaled Dot-Product Attention (SDPA)

개요

CUDA 기반의 Scaled Dot-Product Attention (FWD) 제공, 

내부에서 Q @ K^T - softmax - dropout - @ V 순서로 계산

 

데이터 형식 & 제약

  • 텐서 타입: float32 (DType::F32)
  • 레이아웃: RowMajor
  • 장치: CUDA
  • 차원:
    • Q: [B, H, M, D]
    • K: [B, H, N, D]
    • V: [B, H, N, D]
    • Y: [B, H, M, D]

내부 처리

  • 임시 버퍼 S=[B,H,M,N], P=[B,H,M,N]를 할당하여 사용
  • K^T는 gemm_run의 trans_b 제약을 우회하기 위해 슬라이스마다 전치 커널연속 메모리에 만들어서 사용
  • softmax는 행(row) 단위로 계산
  • dropout은 확률 p에 따라 P에 적용(옵션)

 

속성(Attrs)

ai::SDPAAttrs:

  • float scale
    • 0이면 자동으로 1/sqrt(D) 사용
  • bool causal
    • (향후 지원) 현재 구현에서는 논리 경로만 전달되고 계산에는 미반영
  • float dropout_p
    • 0.0이면 dropout 스킵
  • bool scale_in_train
    • dropout 시 1/(1-p) 스케일 여부
  • uint64_t seed
    • dropout RNG seed

C++ 런처 (CUDA Backend)

파일: backends/cuda/ops/sdpa/launcher.cu

  • FWD: ai::Status SDPACudaLaunch(...)
    • Q,K,V,Y 형식 검증 → K 전치 → gemm_run(Q @ Kt) → softmax(scale) → dropout(옵션) → gemm_run(P @ V)
    • 실패 시 Status::RuntimeError/Status::ShapeMismatch 등 반환
  • BWD: ai::Status SDPACudaBackwardLaunch(...)
    • 스텁: Status::Unimplemented 반환

참고: gemm_run이 trans_b=true를 지원하지 않는 환경을 위해 transpose_rm_f32 커널로 K^T를 만들고 trans_b=false로 호출합니다.


C++ 디스패치(얇은 엔트리)

파일: src/ops/sdpa.cpp

  • int ai::ops::sdpa_run(...)
    • 형식 검증 후 SDPACudaLaunch 호출
    • 반환 매핑:
      • Status::Ok → 0
      • Status::Unimplemented → -8
      • 그 외 실패 → -7

 

향후 “GEMM → softmax → (dropout) → GEMM”을 개별 커널로 돌리는 대신, epilogue/fusion로 묶어서 메모리 왕복과 커널 런치 수를 크게 줄일 수 있다.