1. Goal
이 문서의 목표는 AICF 에 Memory-Centric IR (MCIR) 를 도입하기 위한 최소 구현 단위를 정의하는 것이다.
FlashAttention-like execution structure 를 완전 구현하는 것이 아니라. Attention semantic graph 를 MCIR 로 재구성할 수 있음을 증명하는 것
2. Minimal Scope
우선 제한됨 범위의 구현
- MCIR 기본 데이터 구조
- MCIR region hierarchy
- Attention pattern detection
- AttentionRegion 생성
- StreamingRegion 생성
- 간단한 TileRegion 생성
- MCIR pretty printer / dump
아직 안 할 것
- 실제 autotuning
- 복잡한 cost model
- cp.async modeling
- full kernel codegen
- architecture-specific scheduling
- backward pass
- generic all-op support
목표
Semantic Attention Graph
↓
AttentionPatternPass
↓
AttentionRegion
↓
StreamingLoweringPass
↓
StreamingRegion + TileRegion
↓
MCIR Dump
3. Implementation Principles
3.1 Semantic IR 는 그대로 둔다.
기존 AICF semantic / operator graph 를 버리지 않는다.
MCIR 은 그 위에 새로 생기는 중간 계층이다.
3.2 Python prototype 우선
처음에는 Python dataclass 기반으로 빠르게 구조를 고정한다.
3.3. IR dump 가 핵심 산출물
초기 단계에서는 실행보다 MCIR 이 잘 생성되었는지 눈으로 검증하는 것이 더 중요
3.4 Attention 하나만 먼저
Attention 전용 pass 를 먼저 만들고 일반화를 진행한다.
4. Proposed Directory Layout
AICF 내부에 최소한 구현
python/aicf_v2/src/aicf_v2/mcir/
├── __init__.py
├── types.py
├── values.py
├── nodes.py
├── regions.py
├── module.py
├── printer.py
├── builder.py
└── passes/
├── __init__.py
├── attention_pattern.py
└── streaming_lowering.py
역할 정리
- types.py
- 공통 enum, type alias
- values.py
- MCValue 정의
- nodes.py
- MCNode 정의
- regions.py
- ExecutionRegion, StreamingRegion, TileRegion
- module.py
- MCIR 전체 컨테이너
- printer.py
- dump / pretty print
- builder.py
- region/node/value 생성 helper
- passes/attention_pattern.py
- semantic graph에서 attention pattern 탐지
- passes/streaming_lowering.py
- AttentionRegion → StreamingRegion 변환
5. Core Object Model
최소 구현에서는 객체 수를 적게 가져가는 것이 좋다.
5.1 MCModule
전체 MCIR 를 담는 루트 객체
from dataclasses import dataclass, field
from typing import List
@dataclass
class MCModule:
regions: List["Region"] = field(default_factory=list)
역할
- top-level region 보관
- dump 진입점
- 향후 pass 입력 / 출력 컨테이너
5.2 MCValue
데이터 객체
from dataclasses import dataclass
from typing import Optional, List
@dataclass
class MCValue:
name: str
shape: tuple[int, ...]
dtype: str
residency: str = "global"
producer: Optional[str] = None
consumers: List[str] = None
def __post_init__(self):
if self.consumers is None:
self.consumers = []
초기 버전에서 필요한 필드
- name
- shape
- dtype
- residency : global, shared, register
- producer
- consumers
5.3 MCNode
데이터 객체
from dataclasses import dataclass, field
from typing import List, Dict, Any
@dataclass
class MCNode:
name: str
op: str
inputs: List[MCValue] = field(default_factory=list)
outputs: List[MCValue] = field(default_factory=list)
attrs: Dict[str, Any] = field(default_factory=dict)
초기 op 종류는 문자열로 시작한다
- load_tile
- compute_score
- update_softmax
- accumulate_output
- store_tile
5.4 Region Base
모든 region 의 공통 구조
from dataclasses import dataclass, field
from typing import List
@dataclass
class Region:
name: str
kind: str
inputs: List[MCValue] = field(default_factory=list)
outputs: List[MCValue] = field(default_factory=list)
nodes: List[MCNode] = field(default_factory=list)
subregions: List["Region"] = field(default_factory=list)
attrs: dict = field(default_factory=dict)
6. Minimal Region Kinds
초기 3 개의 region 종류
6.1 ExecutionRegion
semantic 단위 region
@dataclass
class ExecutionRegion(Region):
def __init__(self, name: str):
super().__init__(name=name, kind="execution")
예
- attention_region
6.2 StreamingRegion
streaming pipeline 표현
@dataclass
class ExecutionRegion(Region):
def __init__(self, name: str):
super().__init__(name=name, kind="execution")
예
- attention_stream
6.3 TileRegion
tile execution 단위
@dataclass
class TileRegion(Region):
def __init__(self, name: str, tile_m: int, tile_n: int, tile_k: int):
super().__init__(name=name, kind="tile")
self.attrs["tile_m"] = tile_m
self.attrs["tile_n"] = tile_n
self.attrs["tile_k"] = tile_k
예
- q_tile_loop
7. Optional Minimal Enum Types
얇은 enum
from enum import Enum
class Residency(str, Enum):
GLOBAL = "global"
SHARED = "shared"
REGISTER = "register"
class MCNodeOp(str, Enum):
LOAD_TILE = "load_tile"
COMPUTE_SCORE = "compute_score"
UPDATE_SOFTMAX = "update_softmax"
ACCUMULATE_OUTPUT = "accumulate_output"
STORE_TILE = "store_tile"
8. Builder Layer
간단한 builder 를 통해 dataclass 조립
class MCIRBuilder:
def value(self, name, shape, dtype, residency="global"):
return MCValue(name=name, shape=shape, dtype=dtype, residency=residency)
def node(self, name, op, inputs=None, outputs=None, **attrs):
return MCNode(
name=name,
op=op,
inputs=inputs or [],
outputs=outputs or [],
attrs=attrs,
)
def execution_region(self, name):
return ExecutionRegion(name)
def streaming_region(self, name, stream_axis="sequence"):
return StreamingRegion(name, stream_axis=stream_axis)
def tile_region(self, name, tile_m, tile_n, tile_k):
return TileRegion(name, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k)
- pass 내부 코드 단순화
- 테스트 코드 간소화
- 향후 validation hook 추가
9. Pass 1 : Attention Pattern Detection
첫 pass 는 기존 semantic IR 를 읽어서 attention 패턴을 찾는다.
찾는 패턴
MatMul(Q, K^T)
→ Mask? (optional)
→ Softmax
→ MatMul(..., V)
9.1 기대 입력
기존 AICF operator graph 의 노드 리스트 또는 DAG
9.2 기대 출력
ExecutionRegion(kind="execution", name="attention_region")
9.3 최소 구현 방식
선형 패턴 매칭에서 시작
def detect_attention_pattern(ops: list) -> list[ExecutionRegion]:
...
출력 region 에는 semantic 정보만 넣는다.
region.attrs["pattern"] = "scaled_dot_product_attention"
region.attrs["has_mask"] = True
region.attrs["q_name"] = "Q"
region.attrs["k_name"] = "K"
region.attrs["v_name"] = "V"
10. Pass 2 : Streaming Lowering
이 pass 가 핵심
입력
- ExecutionRegion(attention_region)
출력
- 그 아래 StreamingRegion
- 그 아래 TileRegion
- 내부 MCNode
10.1 lowering 결과 목표
ExecutionRegion(attention)
└── StreamingRegion(attention_stream)
└── TileRegion(q_tile_loop)
├── load_tile(Q)
├── load_tile(K)
├── load_tile(V)
├── compute_score
├── update_softmax
├── accumulate_output
└── store_tile(O)
10.2 최소 lowering 정책
처음에는 고정값으로 시작 가능
tile_m = 128
tile_n = 64
tile_k = 64
stream_axis = "sequence"
이 값들은 나중에 autotune / arch-specific logic 으로 대체 가능
11. Examle Lowering Code Shape
대략 이러한 형태
def lower_attention_to_streaming(region: ExecutionRegion) -> ExecutionRegion:
b = MCIRBuilder()
stream = b.streaming_region("attention_stream", stream_axis="sequence")
tile = b.tile_region("q_tile_loop", tile_m=128, tile_n=64, tile_k=64)
q_tile = b.value("Q_tile", (128, 64), "fp16", residency="shared")
k_tile = b.value("K_tile", (64, 64), "fp16", residency="shared")
v_tile = b.value("V_tile", (64, 64), "fp16", residency="shared")
score = b.value("score_frag", (128, 64), "fp32", residency="register")
softmax_m = b.value("softmax_max", (128,), "fp32", residency="register")
softmax_l = b.value("softmax_sum", (128,), "fp32", residency="register")
acc = b.value("output_acc", (128, 64), "fp32", residency="register")
out = b.value("O_tile", (128, 64), "fp16", residency="global")
tile.nodes.extend([
b.node("load_q", "load_tile", outputs=[q_tile], source="Q"),
b.node("load_k", "load_tile", outputs=[k_tile], source="K"),
b.node("load_v", "load_tile", outputs=[v_tile], source="V"),
b.node("score", "compute_score", inputs=[q_tile, k_tile], outputs=[score]),
b.node("softmax", "update_softmax", inputs=[score], outputs=[softmax_m, softmax_l]),
b.node("accumulate", "accumulate_output", inputs=[score, v_tile], outputs=[acc]),
b.node("store_o", "store_tile", inputs=[acc], outputs=[out], target="O"),
])
stream.subregions.append(tile)
region.subregions.append(stream)
return region
중요한 건 semantic op chain 이 streaming execution structure 로 바뀌는 것이다.
12. Printer / Dump Design
예시 출력
ExecutionRegion(attention_region)
attrs: pattern=scaled_dot_product_attention, has_mask=True
StreamingRegion(attention_stream)
attrs: stream_axis=sequence
TileRegion(q_tile_loop)
attrs: tile_m=128, tile_n=64, tile_k=64
Node(load_q, op=load_tile)
outputs: Q_tile[128,64]:fp16@shared
Node(load_k, op=load_tile)
outputs: K_tile[64,64]:fp16@shared
Node(load_v, op=load_tile)
outputs: V_tile[64,64]:fp16@shared
Node(score, op=compute_score)
inputs: Q_tile, K_tile
outputs: score_frag[128,64]:fp32@register
Node(softmax, op=update_softmax)
inputs: score_frag
outputs: softmax_max[128]:fp32@register, softmax_sum[128]:fp32@register
Node(accumulate, op=accumulate_output)
inputs: score_frag, V_tile
outputs: output_acc[128,64]:fp32@register
Node(store_o, op=store_tile)
inputs: output_acc
outputs: O_tile[128,64]:fp16@global
13. Suggestetd Files in Detail
values.py
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class MCValue:
name: str
shape: tuple[int, ...]
dtype: str
residency: str = "global"
producer: Optional[str] = None
consumers: list[str] = field(default_factory=list)
nodes.py
from dataclasses import dataclass, field
from typing import Any
@dataclass
class MCNode:
name: str
op: str
inputs: list = field(default_factory=list)
outputs: list = field(default_factory=list)
attrs: dict[str, Any] = field(default_factory=dict)
regions.py
from dataclasses import dataclass, field
@dataclass
class Region:
name: str
kind: str
inputs: list = field(default_factory=list)
outputs: list = field(default_factory=list)
nodes: list = field(default_factory=list)
subregions: list = field(default_factory=list)
attrs: dict = field(default_factory=dict)
@dataclass
class ExecutionRegion(Region):
def __init__(self, name: str):
super().__init__(name=name, kind="execution")
@dataclass
class StreamingRegion(Region):
def __init__(self, name: str, stream_axis: str = "sequence"):
super().__init__(name=name, kind="streaming")
self.attrs["stream_axis"] = stream_axis
@dataclass
class TileRegion(Region):
def __init__(self, name: str, tile_m: int, tile_n: int, tile_k: int):
super().__init__(name=name, kind="tile")
self.attrs["tile_m"] = tile_m
self.attrs["tile_n"] = tile_n
self.attrs["tile_k"] = tile_k
module.py
from dataclasses import dataclass, field
@dataclass
class MCModule:
regions: list = field(default_factory=list)
14. Pass API Design
함수형으로 시작
attention pattern pass
def run_attention_pattern_pass(semantic_graph) -> MCModule:
...
streaming lowering pass
def run_attention_pattern_pass(semantic_graph) -> MCModule:
...
향후에는 pass manager 로 감싸도 된다.
class MCIRPassManager:
def __init__(self, passes):
self.passes = passes
def run(self, obj):
for p in self.passes:
obj = p(obj)
return obj
15. Validation Rules
최소 validation 은 있어야 한다.
15.1 region validity
- region name non-empty
- region kind valid
15.2 node validity
- node op non-empty
- outputs count reasonable
15.3 value validity
- shape defined
- dtype defined
- residency valid
15.4 attention region validity
- Q/K/V 존재
- output 존재
- streaming region 1개 이상 생성
16. Test Strategy
초기 3개의 테스트
Test 1 : attention pattern detect
입력 semantic graph 에서 attention region 이 잡히는지
Test 2 : streaming lowering
attention region 아래 streaming/tile region 이 생기는지
Test 3 : dump snapshot
pinrter 출력이 기대 문자열과 대체로 일치하는지
17. Concrete First Milestone
Milestone A
- dataclass 기반 MCIR 추가
- printer 추가
- hand-built attention MCIR 생성 스크립트
Milestone B
- semantic graph pattern detect
- automatic AttentionRegion 생성
Milestone C
- AttentionRegion -> StreamingRegion lowering
- dump 비교 가능
18. Example End-to-End Prototype Flow
semantic_graph = build_attention_semantic_graph()
module = run_attention_pattern_pass(semantic_graph)
module = run_streaming_lowering_pass(module)
print(dump_module(module))
//////////////////////////////////////////////////////
ExecutionRegion(attention_region)
StreamingRegion(attention_stream)
TileRegion(q_tile_loop)
load_tile(Q)
load_tile(K)
load_tile(V)
compute_score
update_softmax
accumulate_output
store_tile(O)