본문 바로가기

Memory-Centric IR for AICF

Mathematical Properties Behind FlashAttention - Streaming Transformations for Memory-Efficient Computation

1. Intro duction

최근 딥러닝 모델의 확장은 GPU 의 연산 능력보다 메모리 트래픽, Memory Traffic 을 더 큰 병목으로 만들고 있음

특히 Transformer 구조에서 Attention 연산은 다음과 같은 계산 형태를 갖는다.

Attention(Q, K, V) = Softmax(QK^T)V

이 연산을 직접 계산하면 중간 행렬

S = QK^T

의 크기가

O(N^2)

이 되며, 이는 대규모 메모리 접근을 발생시킨다.

 

FlashAttention 은 이 문제를 해결하기 위해 연산의 수학적 성질을 이용해 계산 구조를 재구성하였다. 

핵심 아이디어

어떤 연산이 전역계산을 요구하는 것처럼 보이지만, 실제로는 부분 상태를 이용한 스트리밍 계산으로 변환할 수 있는가?

 

FlashAttention 은 이러한 변환이 가능함을 보여준 대표적인 사례

FlashAttention 을 가능하게 하는 수학적 성질들을 분석하는 것을 목표로 한다.

 

2. Global Reduction vs Streaming Reduction

많은 딥러닝 연산은 다음과 같은 형태의 전역 Reduction 을 포함한다.

이러한 연산은 일반적으로 모든 데이터를 한 번에 접근해야 계산 가능하다.

그러나 특정한 연산의 경우에는 다음과 같은 Streaming 형태의 계산이 가능하다.

state = initial

for block in data:
    state = update(state, block)

이러한 계산 구조는 다음 조건을 만족해야 한다.

  • 상태가 유한한 크기로 유지될 것
  • 새로운 데이터가 들어올 때 기존 상태를 업데이트할 수 있을 것

이러한 구조를 Streaming Reduction 이라고 부른다. 

 

3. Online Normalization

Softmax 는 직관적으로 보면 전체 데이터를 모두 알아야 계산할 수 있는 연산처럼 보인다.

그러나 FlashAttention 은 Softmax 가 Streaming 방식으로 계산 가능함을 이용한다.

 

3.1 Numerical Stability Trick

Softmax 는 overflow 를 방지하기 위한 수치적 안정성 기법이 적용됨

 

3.2 Online Softmax

데이터를 블록 단위로 처리한다고 가정

각 상태의 정의

m_old : 현재까지의 max
l_old : 지수 합

새로운 블록 B 가 들어오면

max 탐색과 기존 합의 재조정 수행

이를 통해 Softmax 를 전체 데이터를 저장하지 않고 계산할 수 있게 만든다.

 

핵심 수학적 성질

이 알고리즘이 가능한 이유는 다음 성질 때문

e^(a+b) = e^a x e^b

즉 지수 함수의 곱셈 구조 덕분에 기존 값을 재조정할 수 있다.

 

4. Streaming Weighted Sum

Attention 의 마지막 연산은 단순한 가중 합 구조이다.

이는 결합 법칙, Associativity 으로 분해 가능

이 때문에 계산을 여러 블록으로 나누어 수행할 수 있다.

 

5. Re-materialization: Compute vs Memory Trade-off

중간 행렬을 저장하지 않는다는 점이 다르다. ( recompute QK^T )

계산을 다시 하는 것이 메모리 접근보다 더 빠르다는 성질이 존재

 

6. Tiling-Compatible Computation

모든 계산을 Tile 단위로 수행한다.

이를 통해 shared memory 를 활용할 수 있게 한다.

 

7. Stateful Streaming Computation

FlashAttention 알고리즘은 다음과 같은 구조를 갖는다.

for each query tile:

    initialize state

    for each key/value tile:

        compute partial attention

        update normalization state

        accumulate output

여기서 유지되는 상태는 다음과 같다

  • running max
  • exponential sum
  • partial output

이러한 구조는 Stateful Streaming Computation 의 대표적인 예이다.

 

8. Generalization

FlashAttention 에 사용된 수학적 구조는 Attention 에만 적용되는 것이 아니다.

다음과 같은 연산들도 비슷한 구조를 갖는다.

  • softmax 기반 routing
  • Mixture-of_experts gating
  • LayerNorm
  • GroupNorm

이들은 모두 Streaming Reduction 구조로 변환 가능

 

9. Conclusion

핵심은 새로운 커널을 만드는 것이 아닌 연산의 수학적 성질을 이용해 계산 구조를 재구성한 것이다.

이를 통해 메모리 중심 계산 모델 구현 가능