본문 바로가기

Memory-Centric IR for AICF

MCIR Minimal Implementation Design (Python)

1. Goal

이 문서의 목표는 AICF 에 Memory-Centric IR (MCIR) 를 도입하기 위한 최소 구현 단위를 정의하는 것이다.

FlashAttention-like execution structure 를 완전 구현하는 것이 아니라. Attention semantic graph 를 MCIR 로 재구성할 수 있음을 증명하는 것

 

2. Minimal Scope

우선 제한됨 범위의 구현

  • MCIR 기본 데이터 구조
  • MCIR region hierarchy
  • Attention pattern detection
  • AttentionRegion 생성
  • StreamingRegion 생성
  • 간단한 TileRegion 생성
  • MCIR pretty printer / dump

아직 안 할 것

  • 실제 autotuning
  • 복잡한 cost model
  • cp.async modeling
  • full kernel codegen
  • architecture-specific scheduling
  • backward pass
  • generic all-op support

목표

Semantic Attention Graph
    ↓
AttentionPatternPass
    ↓
AttentionRegion
    ↓
StreamingLoweringPass
    ↓
StreamingRegion + TileRegion
    ↓
MCIR Dump

 

3. Implementation Principles

3.1 Semantic IR 는 그대로 둔다.

기존 AICF semantic / operator graph 를 버리지 않는다. 

MCIR 은 그 위에 새로 생기는 중간 계층이다.

3.2 Python prototype 우선

처음에는 Python dataclass 기반으로 빠르게 구조를 고정한다.

3.3. IR dump 가 핵심 산출물

초기 단계에서는 실행보다 MCIR 이 잘 생성되었는지 눈으로 검증하는 것이 더 중요

3.4 Attention 하나만 먼저

Attention 전용 pass 를 먼저 만들고 일반화를 진행한다.

 

4. Proposed Directory Layout

AICF 내부에 최소한 구현

python/aicf_v2/src/aicf_v2/mcir/
├── __init__.py
├── types.py
├── values.py
├── nodes.py
├── regions.py
├── module.py
├── printer.py
├── builder.py
└── passes/
    ├── __init__.py
    ├── attention_pattern.py
    └── streaming_lowering.py

역할 정리

 

  • types.py
    • 공통 enum, type alias
  • values.py
    • MCValue 정의
  • nodes.py
    • MCNode 정의
  • regions.py
    • ExecutionRegion, StreamingRegion, TileRegion
  • module.py
    • MCIR 전체 컨테이너
  • printer.py
    • dump / pretty print
  • builder.py
    • region/node/value 생성 helper
  • passes/attention_pattern.py
    • semantic graph에서 attention pattern 탐지
  • passes/streaming_lowering.py
    • AttentionRegion → StreamingRegion 변환

 

 

5. Core Object Model

최소 구현에서는 객체 수를 적게 가져가는 것이 좋다.

5.1 MCModule

전체 MCIR 를 담는 루트 객체

from dataclasses import dataclass, field
from typing import List

@dataclass
class MCModule:
    regions: List["Region"] = field(default_factory=list)

역할

  • top-level region 보관
  • dump 진입점
  • 향후 pass 입력 / 출력 컨테이너

 

5.2 MCValue

데이터 객체

from dataclasses import dataclass
from typing import Optional, List

@dataclass
class MCValue:
    name: str
    shape: tuple[int, ...]
    dtype: str
    residency: str = "global"
    producer: Optional[str] = None
    consumers: List[str] = None

    def __post_init__(self):
        if self.consumers is None:
            self.consumers = []

초기 버전에서 필요한 필드

  • name
  • shape
  • dtype
  • residency  : global, shared, register
  • producer
  • consumers

 

5.3 MCNode

데이터 객체

from dataclasses import dataclass, field
from typing import List, Dict, Any

@dataclass
class MCNode:
    name: str
    op: str
    inputs: List[MCValue] = field(default_factory=list)
    outputs: List[MCValue] = field(default_factory=list)
    attrs: Dict[str, Any] = field(default_factory=dict)

초기 op 종류는 문자열로 시작한다

  • load_tile
  • compute_score
  • update_softmax
  • accumulate_output
  • store_tile

 

5.4 Region Base

모든 region 의 공통 구조

from dataclasses import dataclass, field
from typing import List

@dataclass
class Region:
    name: str
    kind: str
    inputs: List[MCValue] = field(default_factory=list)
    outputs: List[MCValue] = field(default_factory=list)
    nodes: List[MCNode] = field(default_factory=list)
    subregions: List["Region"] = field(default_factory=list)
    attrs: dict = field(default_factory=dict)

 

6. Minimal Region Kinds

초기 3 개의 region 종류

6.1 ExecutionRegion

semantic 단위 region

@dataclass
class ExecutionRegion(Region):
    def __init__(self, name: str):
        super().__init__(name=name, kind="execution")

  • attention_region

 

6.2 StreamingRegion

streaming pipeline 표현

@dataclass
class ExecutionRegion(Region):
    def __init__(self, name: str):
        super().__init__(name=name, kind="execution")

  • attention_stream

 

6.3 TileRegion

tile execution 단위

@dataclass
class TileRegion(Region):
    def __init__(self, name: str, tile_m: int, tile_n: int, tile_k: int):
        super().__init__(name=name, kind="tile")
        self.attrs["tile_m"] = tile_m
        self.attrs["tile_n"] = tile_n
        self.attrs["tile_k"] = tile_k

  • q_tile_loop

 

7. Optional Minimal Enum Types

얇은 enum

from enum import Enum

class Residency(str, Enum):
    GLOBAL = "global"
    SHARED = "shared"
    REGISTER = "register"
    
    
class MCNodeOp(str, Enum):
    LOAD_TILE = "load_tile"
    COMPUTE_SCORE = "compute_score"
    UPDATE_SOFTMAX = "update_softmax"
    ACCUMULATE_OUTPUT = "accumulate_output"
    STORE_TILE = "store_tile"

 

8. Builder Layer

간단한 builder 를 통해 dataclass 조립

class MCIRBuilder:
    def value(self, name, shape, dtype, residency="global"):
        return MCValue(name=name, shape=shape, dtype=dtype, residency=residency)

    def node(self, name, op, inputs=None, outputs=None, **attrs):
        return MCNode(
            name=name,
            op=op,
            inputs=inputs or [],
            outputs=outputs or [],
            attrs=attrs,
        )

    def execution_region(self, name):
        return ExecutionRegion(name)

    def streaming_region(self, name, stream_axis="sequence"):
        return StreamingRegion(name, stream_axis=stream_axis)

    def tile_region(self, name, tile_m, tile_n, tile_k):
        return TileRegion(name, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k)
  • pass 내부 코드 단순화
  • 테스트 코드 간소화
  • 향후 validation hook 추가

 

9. Pass 1 : Attention Pattern Detection

첫 pass 는 기존 semantic IR 를 읽어서 attention 패턴을 찾는다.

찾는 패턴

MatMul(Q, K^T)
→ Mask? (optional)
→ Softmax
→ MatMul(..., V)

9.1 기대 입력

기존 AICF operator graph 의 노드 리스트 또는 DAG

9.2 기대 출력

ExecutionRegion(kind="execution", name="attention_region")

9.3 최소 구현 방식

선형 패턴 매칭에서 시작

def detect_attention_pattern(ops: list) -> list[ExecutionRegion]:
    ...

출력 region 에는 semantic 정보만 넣는다.

region.attrs["pattern"] = "scaled_dot_product_attention"
region.attrs["has_mask"] = True
region.attrs["q_name"] = "Q"
region.attrs["k_name"] = "K"
region.attrs["v_name"] = "V"

 

10. Pass 2 : Streaming Lowering

이 pass 가 핵심

입력

  • ExecutionRegion(attention_region)

출력

  • 그 아래 StreamingRegion
  • 그 아래 TileRegion
  • 내부 MCNode

10.1 lowering 결과 목표

ExecutionRegion(attention)
└── StreamingRegion(attention_stream)
    └── TileRegion(q_tile_loop)
        ├── load_tile(Q)
        ├── load_tile(K)
        ├── load_tile(V)
        ├── compute_score
        ├── update_softmax
        ├── accumulate_output
        └── store_tile(O)

10.2 최소 lowering 정책

처음에는 고정값으로 시작 가능

tile_m = 128
tile_n = 64
tile_k = 64
stream_axis = "sequence"

이 값들은 나중에 autotune / arch-specific logic 으로 대체 가능

 

11. Examle Lowering Code Shape

대략 이러한 형태

def lower_attention_to_streaming(region: ExecutionRegion) -> ExecutionRegion:
    b = MCIRBuilder()

    stream = b.streaming_region("attention_stream", stream_axis="sequence")
    tile = b.tile_region("q_tile_loop", tile_m=128, tile_n=64, tile_k=64)

    q_tile = b.value("Q_tile", (128, 64), "fp16", residency="shared")
    k_tile = b.value("K_tile", (64, 64), "fp16", residency="shared")
    v_tile = b.value("V_tile", (64, 64), "fp16", residency="shared")
    score = b.value("score_frag", (128, 64), "fp32", residency="register")
    softmax_m = b.value("softmax_max", (128,), "fp32", residency="register")
    softmax_l = b.value("softmax_sum", (128,), "fp32", residency="register")
    acc = b.value("output_acc", (128, 64), "fp32", residency="register")
    out = b.value("O_tile", (128, 64), "fp16", residency="global")

    tile.nodes.extend([
        b.node("load_q", "load_tile", outputs=[q_tile], source="Q"),
        b.node("load_k", "load_tile", outputs=[k_tile], source="K"),
        b.node("load_v", "load_tile", outputs=[v_tile], source="V"),
        b.node("score", "compute_score", inputs=[q_tile, k_tile], outputs=[score]),
        b.node("softmax", "update_softmax", inputs=[score], outputs=[softmax_m, softmax_l]),
        b.node("accumulate", "accumulate_output", inputs=[score, v_tile], outputs=[acc]),
        b.node("store_o", "store_tile", inputs=[acc], outputs=[out], target="O"),
    ])

    stream.subregions.append(tile)
    region.subregions.append(stream)
    return region

중요한 건 semantic op chain 이 streaming execution structure 로 바뀌는 것이다.

 

12. Printer / Dump Design

예시 출력

ExecutionRegion(attention_region)
  attrs: pattern=scaled_dot_product_attention, has_mask=True

  StreamingRegion(attention_stream)
    attrs: stream_axis=sequence

    TileRegion(q_tile_loop)
      attrs: tile_m=128, tile_n=64, tile_k=64

      Node(load_q, op=load_tile)
        outputs: Q_tile[128,64]:fp16@shared

      Node(load_k, op=load_tile)
        outputs: K_tile[64,64]:fp16@shared

      Node(load_v, op=load_tile)
        outputs: V_tile[64,64]:fp16@shared

      Node(score, op=compute_score)
        inputs: Q_tile, K_tile
        outputs: score_frag[128,64]:fp32@register

      Node(softmax, op=update_softmax)
        inputs: score_frag
        outputs: softmax_max[128]:fp32@register, softmax_sum[128]:fp32@register

      Node(accumulate, op=accumulate_output)
        inputs: score_frag, V_tile
        outputs: output_acc[128,64]:fp32@register

      Node(store_o, op=store_tile)
        inputs: output_acc
        outputs: O_tile[128,64]:fp16@global

 

13. Suggestetd Files in Detail

values.py

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class MCValue:
    name: str
    shape: tuple[int, ...]
    dtype: str
    residency: str = "global"
    producer: Optional[str] = None
    consumers: list[str] = field(default_factory=list)

nodes.py

from dataclasses import dataclass, field
from typing import Any

@dataclass
class MCNode:
    name: str
    op: str
    inputs: list = field(default_factory=list)
    outputs: list = field(default_factory=list)
    attrs: dict[str, Any] = field(default_factory=dict)

regions.py

from dataclasses import dataclass, field

@dataclass
class Region:
    name: str
    kind: str
    inputs: list = field(default_factory=list)
    outputs: list = field(default_factory=list)
    nodes: list = field(default_factory=list)
    subregions: list = field(default_factory=list)
    attrs: dict = field(default_factory=dict)

@dataclass
class ExecutionRegion(Region):
    def __init__(self, name: str):
        super().__init__(name=name, kind="execution")

@dataclass
class StreamingRegion(Region):
    def __init__(self, name: str, stream_axis: str = "sequence"):
        super().__init__(name=name, kind="streaming")
        self.attrs["stream_axis"] = stream_axis

@dataclass
class TileRegion(Region):
    def __init__(self, name: str, tile_m: int, tile_n: int, tile_k: int):
        super().__init__(name=name, kind="tile")
        self.attrs["tile_m"] = tile_m
        self.attrs["tile_n"] = tile_n
        self.attrs["tile_k"] = tile_k

module.py

from dataclasses import dataclass, field

@dataclass
class MCModule:
    regions: list = field(default_factory=list)

 

14. Pass API Design

함수형으로 시작

attention pattern pass

def run_attention_pattern_pass(semantic_graph) -> MCModule:
    ...

streaming lowering pass

def run_attention_pattern_pass(semantic_graph) -> MCModule:
    ...

향후에는 pass manager 로 감싸도 된다.

class MCIRPassManager:
    def __init__(self, passes):
        self.passes = passes

    def run(self, obj):
        for p in self.passes:
            obj = p(obj)
        return obj

 

15. Validation Rules

최소 validation 은 있어야 한다.

15.1 region validity

  • region name non-empty
  • region kind valid

15.2 node validity

  • node op non-empty
  • outputs count reasonable

15.3 value validity

  • shape defined
  • dtype defined
  • residency valid

15.4 attention region validity

  • Q/K/V 존재
  • output 존재
  • streaming region 1개 이상 생성

 

16. Test Strategy

초기 3개의 테스트

Test 1 : attention pattern detect

입력 semantic graph 에서 attention region 이 잡히는지

Test 2 : streaming lowering

attention region 아래 streaming/tile region 이 생기는지

Test 3 : dump snapshot

pinrter 출력이 기대 문자열과 대체로 일치하는지

 

17. Concrete First Milestone

Milestone A

  • dataclass 기반 MCIR 추가
  • printer 추가
  • hand-built attention MCIR 생성 스크립트

Milestone B

  • semantic graph pattern detect
  • automatic AttentionRegion 생성

Milestone C

  • AttentionRegion -> StreamingRegion lowering
  • dump 비교 가능

 

18. Example End-to-End Prototype Flow

semantic_graph = build_attention_semantic_graph()

module = run_attention_pattern_pass(semantic_graph)
module = run_streaming_lowering_pass(module)

print(dump_module(module))

//////////////////////////////////////////////////////

ExecutionRegion(attention_region)
  StreamingRegion(attention_stream)
    TileRegion(q_tile_loop)
      load_tile(Q)
      load_tile(K)
      load_tile(V)
      compute_score
      update_softmax
      accumulate_output
      store_tile(O)