본문 바로가기

Memory-Centric IR for AICF

Streaming Weighted Reduction - FlashAttention generalization

Method Overview

Softmax-weighted reduction 은 다음 구조를 가진다.

softmax(QK^T)V

문제는 일반 구현이 다음을 materialize 한다는 것이다.

score matrix
probability matrix

이는 매우 큰 HBM traffic 을 발생시킨다.

Streaming weighted reductino 은 이를 tile streaming 방식으로 계산한다.

 

Mathematical Structure

핵심 수학적 성질은 rescaling invariant 이다.

Online softmax

m = max(x)
l = sum(exp(x-m))

새로운 block 이 들어올 때

m_new = max(m_old, m_block)
l_new = exp(m_old- m_new)*l_old + exp(m_block - m_new)*l_block

이 방식으로 이전 결과를 재정규화하며 누적할 수 있다.

 

Hardware Implication

Streaming attention kernel 은

Q,K,V tile load
→ score compute
→ normalization
→ weighted accumulation

을 하나의 kernel 에서 수행한다.

결과

score matrix 저장 제거
probability matrix 저장 제거

HBM traffic 이 크게 줄어든다.

 

Compiler Recognition (MCIR)

Property

weighted_streaming_reduction

Legality

rescaling inveriant
streaming merge 가능

Rewrite

materialized attention
-> streaming attention

Kernel mapping

tiled fused attention kernel