본문 바로가기

Memory-Centric IR for AICF

AttentionRegion Transformation - Deriving FlashAttention via Memeory-Centric Compileation in AICF

1. Overview

Flashattention 은 일반적으로 특수 CUDA Kernel 로 제공된다

대부분의 딥러닝 시스템에서 FlashAttention 은 다음과 같이 사용된다. 

flash_attention(Q, K, V)

즉 Attention 계산을 라이브러리 커널로 교체하는 방식

하지만 이 방식은 다음과 같은한계를 가진다. 

  • 특정 Attention 형태에만 적용 가능
  • 컴파일러가 내부 구조를 이해하지 못함
  • 다른 attention 구조로 일반화하기 어려움

AICF 는 다른 접근을 취한다

Flash Attention 을 커널로 제공하는 대신 Semantic Attention Graph 로부터 FlashAttention과 유사한 실행 구조를 컴파일 과정에서 유도한다.

이를 위해 AICF 는 Attention Region Transformation 을 도입한다.

 

2. Transformation Pipeline

AttentionRegion transformation 은 다음 단계로구성된다.

Semantic IR
      │
      ▼
Attention Pattern Detection
      │
      ▼
AttentionRegion
      │
      ▼
StreamingAttentionRegion
      │
      ▼
Kernel IR
      │
      ▼
FlashAttention-like Kernel

  각 단계는execution structure 를 점진적으로 구체화한다.

 

3. Attention Pattern Detection

Sementic IR 에서 다음 패턴을 탐지한다

Scores = MatMul(Q, Kᵀ)
Scores = Mask(Scores)     (optional)
Scores = Softmax(Scores)
Output = MatMul(Scores, V)

.이 패턴은 Transformer attention 구조의 핵심이다.

패턴이 발견되면 다음과 같은 region 이 생성된다.

AttentionRegion {
    Q
    K
    V
}

이 region 은 semantic 의미를 유지한 채 하나의 logical computation unit 으로 취급된다.

 

4. Attention Region Representation

Attention Region 은 다음 정보를 포함한다.

AttentionRegion
{
    inputs:
        Q
        K
        V

    operations:
        score_matmul
        mask (optional)
        softmax
        value_matmul

    axes:
        sequence_length
        head_dim
}

여전히 semantic structure 가 유지된다.

아직 FlashAttention 구조가 아님

 

5. Streaming Transformation

Attention Region 은 다음 단계에서 Streaming Attention Region 으로 변환된다.

핵심 아이디는 다음과 같다

Attention score matric 를 생성하지 않는다.

즉 다음 intermediate tensor 를 제거한다.

Score Matrix
Softmax Probability Matrix

대신 streaming computation 을 사용한다.

 

6. Streaming Attention Region

Streaming Attentin Region 은 다음 구조를 가진다. 

StreamingAttentionRegion
{
    tile(Q)

    for kv_tile:
        load(K_tile)
        load(V_tile)

        compute partial_score
        update online_softmax
        accumulate output
}

이 구조는 다음 특징을 가진다.

Tile-based execution

연산은 전체 tensor 가 아니라 tile 단위로 수행된다.

Streaming K/V tiles

K 와 V 는 streaming 방식으로 처리된다

Online Softmax

Softmax 는 다음과 같은 online algorithm 으로 계산된다.

m_new = max(m_old, score_tile)
l_new = exp(m_old - m_new)*l_old + sum(exp(score_tile - m_new))

이 방식은 전체 score matrix 없이 softmax 를 계산하 ㄹ수 있다.

 

7. Tile Lifecycle

Streaming Attention Region 에서는 tile lifecycle 이 중요하다.

HBM
 │
 ▼
Shared Memory
 │
 ▼
Registers
 │
 ▼
Accumulation
 │
 ▼
Writeback

각 tile 은 다음 단계를 거친다.

  1. Q tile load
  2. K/V tile streaming
  3. score compute
  4. softmax update
  5. output accumulation

이 lifecycle 은 Flashattention kernel 의 핵심 구조이다.

 

8. Kernel Lowering

Streaming Attention Region 은 Kernel IR 로 lowering 된다.

예시 구조는 다음과 같다

for q_tile in sequence:

    load Q_tile

    init softmax state

    for kv_tile in sequence:

        load K_tile
        load V_tile

        compute score

        update softmax

        accumulate output

store output

이 kernel 은 다음 특징을 가진다.

  • score matrix never materialized
  • softmax computed online
  • memeory traffic minimized

이 구조는 Flash Attention kernl 과 동일한 execution strategy 를 가진다.

 

 9. Key Compiler Responsibilities

Attention Region transformation 에서 컴파일러는 다음 결정을 내려야 한다.

Tile size

tile_m
tile_n
tile_k

Memory residency

어떤 데이터가 다음에 위치할 것인가.

register
shared memory
global memory

Streaming order

Q-major
KV-major

Parallel mapping

warp
block
thread

 

10. Comparison with Library Approach

기존 방식

Attention Graph
      ↓
Call flash_attention kernel

AICF 방식

Attention Graph
      ↓
AttentionRegion detection
      ↓
Streaming transformation
      ↓
Kernel generation

차이는 다음과 같다

FlashAttention fixed kernel compiler generated
Execution structure hidden explicit
Optimization manual systematic
Extensibility limited high

 

11. Why This Matters

이 접근은 단순한 kernel replacement 가 아님

AICF 는 다음을 가능하게 한다

Execution structure synthesis

execution pipeline 을 컴파일러가 생성한다.

Architecture specialization

GPU architecture 에 따라 다른 streaming schedule 을 생성할 수 있다.

Generalization

FlashAttention 개념을 다른 연산으로 확장할 수 있다.