본문 바로가기

Memory-Centric IR for AICF

Attention MCIR Example - Full IR Walkthough in AICF

1. Overview

이 문서는 AICF 의 Memory-Centric IR 이 실제 실제 Attention computation 을 어떻게 표현하는지 단계적으로 설명한다.

목표는 다음을 보여주는 것이다.

Attention Semantic Graph
        ↓
AttentionRegion
        ↓
StreamingAttentionRegion
        ↓
Tile-based MCIR
        ↓
Kernel Structure

이를 통해 FlashAttention 과 유사한 execution structure 가 어떻게 compiler 에서 생성되는지 설명한다.

 

2. Initial Semantic Graph

Transformer Attention 은 다음 연산으로 구성된다.

Scores = Q × Kᵀ
Scores = Mask(Scores)
Scores = Softmax(Scores)
Output = Scores × V

이를 Semantic IR 로 표현하면 다음과 같다

Node0: MatMul(Q, Kᵀ)
Node1: Mask(Node0)
Node2: Softmax(Node1)
Node3: MatMul(Node2, V)

 Graph 구조

Q ----┐
      │
      ▼
    MatMul ----> Mask ----> Softmax ----> MatMul ----> Output
      ▲                                      ▲
      │                                      │
      K                                      V

이 구조는 수학적으로 정확하지만 실행 효율 측면에서는 비효울적이다.

문제

Score Matrix Materialization
Softmax Matrix Materialization

 두 개의 큰 intermediate tensor 가 생성된다.

 

3. Attention Region Detection

컴파일러는 Semantic IR 에서 Attention 패턴을탐지한다.

 Pattern

MatMul(Q, Kᵀ)
→ Mask (optional)
→ Softmax
→ MatMul(..., V)

패턴이 발견되면 다음 Region 이 생성된다.

Execution Region : Attention Region

구조

AttentionRegion
{
    inputs:
        Q
        K
        V

    outputs:
        O

    ops:
        score_matmul
        mask
        softmax
        value_matmul
}

이 단계에서는 아직 memory optimization 이 적용되지 않았다.

 

4. Streaming Transformation

다음 단계에서 Attention Region 은 Streaming Attention Region 으로 변환된다.

핵심 아이디어

Score Matrix 생성 제거

대신 다음 실행 구조를 사용한다.

for q_tile:

    load Q_tile

    for kv_tile:

        compute partial score
        update softmax
        accumulate output

이 구조는 intermediate tensor 를 생성하지 않는다.

 

5. Streaming Region Representation

Streaming Attention Region 은 다음과 같이 표현된다.

StreamingRegion
{
    stream_axis: sequence

    state:
        softmax_max
        softmax_sum
        accumulator
}

state 변수

softmax_max
softmax_sum
accumulator

이 변수들은 streaming softmax 알고리즘에서 사용된다.

 

6. TileRegion Construction

Streaming Region 내부에서 TileRegion 이 생성된다

TileRegion
{
    tile_shape:
        M: 128
        N: 64
        K: 64
}

.Execution structure

StreamingRegion

    for q_tile:

        TileRegion

            LoadTile(Q)

            for kv_tile:

                LoadTile(K)
                LoadTile(V)

                Compute(score)

                UpdateState(softmax)

                Compute(accumulate)

 

7. MCIR Graph

전체 MCIR graph 는 다음과 같다

ExecutionRegion (Attention)

    inputs:
        Q
        K
        V

    outputs:
        O

    StreamingRegion

        loop q_tile

            TileRegion

                Node0: LoadTile(Q)

                loop kv_tile

                    Node1: LoadTile(K)
                    Node2: LoadTile(V)

                    Node3: ComputeScore
                    Node4: UpdateSoftmax
                    Node5: AccumulateOutput

                Node6: StoreTile(O)

 

8. Memory Residency

MCIR 은 각 데이터의 memory location 을 명시한다.

Q_tile → shared
K_tile → shared
V_tile → shared

score → register

softmax_state → register

accumulator → register

 

 

Q_tile shared memory
K_tile shared memory
V_tile shared memory
score register
softmax state register
output accumulator register

이 구조는 HBM 접근을 최소화한다.

 

9. Materialization Analysis

Semantic IR 에서는 다음 tensor 가 존재한다.

Score Matrix
Softmax Matrix

하지만 MCIR 에서는 다음이 된다.

Score MAtrix - eliminated
Softmax Matrix - eliminated

Intermediate tensor 는 register state 로 대체된다.

 

10. Lowering to Kernel IR

MCIR 은 다음 Kernel structure 로 변횐된다.

for block_q_tile:

    load Q_tile to shared

    init softmax state

    for block_kv_tile:

        load K_tile
        load V_tile

        compute score

        update online softmax

        accumulate output

store output tile

이 구조는 FlashAttention kernel 과 동일한 실행 전략을 사용한다.

 

11. Resulting Kernel Properties

이 커널의 특징

Score matrix never materialized
Softmax computed online
Tile-based streaming execution
  • memory traffic 감소
  • kernel launch 감소
  • on-chip reuse 증가

 

12. Comparison

Semantic execution

MatMul kernel
Softmax kernel
MatMul kernel

MCIR Execution

Single streaming kernel

 

Kernel Count 3 1
Score Matrix materialized eliminated
Softmax full matrix streaming
Memory Traffic high low

 

13. Key Insight

이 예제에서 중요한 점은

FlashAttention 은 특별한 연산이 아니다.

MatMul
Softmax
MatMul

같은 연산이지만 실행 구조가 다르다

AICF 는 다음을 수행한다

semantic graph
-> memory-centirc execution

즉 FlashAttention 은 특수 커널이 아니라 컴파일 결과물이 된다.

 

14. Summary

이 문서는 AICF 의 Memory-Centric IR 이 Attention computation 을 어떻게 표현하는지 설명했다.

핵심 흐름

Semantic Attention Graph
        ↓
AttentionRegion
        ↓
StreamingAttentionRegion
        ↓
TileRegion
        ↓
MCIR Graph
        ↓
Kernel Structure

 이를 통해 AICF 는 semantic graph 로부터 FlashAttetnion-like execution 을 자동 생성할 수 있다.