본문 바로가기

GPU-KERNEL

이상적인 Epilogue Kernel 구현 - FlashAttention 을 보며...

FlashAttention 의 구조는 GEMM - Epilogue 패턴을, Attention 전체까지 확장한 형태

 

Flash Attention = Attention 전체를 하나의 거대한 Epilogue 로 재정의한 구조

  • 메인 GEMM 은 오래 걸리는 대규모 연산
  • Epilogue 는 그 이후에 C 에 추가 작업하는 fused 연산

FlashAttention 은 이 패턴을 극단적으로 확장한다.

 

1) FlashAttention 관점에서의 Attention 단계 해부

일반 Attention 흐름

  • QK - score ( GEMM )
  • score - softmax ( non-GEMM )
  • softmax(score) @ V ( GEMM )

즉, 한 번의 Attention 을 두 개의 GEMM + 하나의 거대한 중간행렬로 구성한다.

QK^T 계산의 결과를 C 로 메모리에 store 하지 않고, 곧바로 softmax 계산과 PV 계산까지 이어붙여 버리는 Epilogue

 

2) FlashAttention 의 Tile Pipeline 구조 

메인 GEMM 의 결과를 받아서, 그 위에 후처리를 계속 덧붙이는 방식이기 때문에 확장된 epilogue 라고 표현 할 수 있음

 

3) Epilogue 개념에 부합하는 이유

중간 결과를 global memory 에 write 하지 않는다.

epilogue 의 가장 큰 목적 중 하나가 DRAM round-trip 제거

FlashAttention 도 같은 목표

 

계산이결과값 C 바로 옆에서 수행된다 

C 값이 레지스터에 있거나 shared memory 에서 살아있는 동안 바로 post-processing 을 수행한다.

 

kernel fusion 을 통해 latency 를 줄인다.

kernel 수 줄이기 - launch overhead 감소 + 메모리 병목 감소

분리된 커널들의 단일 attention fused kernel로 합침

 

epilogue friendly tile pipeline

동일한 tile pipeline 의 사