본문 바로가기

AI Compiler framework

IR Graph 생성 과정과 그 의미

Intermediate Representation 은 다음의 목적을 가진다.

  • Python step_fn 을 1회 실행해서 실행 경로를 연산 노드 시퀀스로 고정
  • IR 은 모델 옵티마이저 데이터가 실제로 어떤 순서로 호출되는지를 표현한다
  • 이후 IR 을 backend op 시퀀스로 lowering 해 
    • 실행 순서 검증
    • CUDA graph capture replay 준비
    • 향후 정식 컴파일 기반으로 확장

IR 은 정적 그래프 빌더가 아니라 한 번 실행된 step 의 연산 로그를 구조화한 실행 IR 이다.

 

IR 구성 요소

1.1 IRValue (Tensor / Scaler 의 메타)

IRValue 는 런타임에 존재하는 텐서, 스칼라를 shape, dtype, device 기준으로 표현한다.

@dataclass
class IRValue:
    id: int
    name: str
    shape: Tuple[int, ...]
    dtype: str
    device: str
  • id : IR 내부에서 참조하는 value id
  • name : 디버깅용 이름
  • shape, dtype, device : 실행 안정성, 검증을 위한 최소 메타

중요

  • Tensor identity 를 캐시 키로 사용해서 IRValue 폭증을 막는다
  • 즉, 같은 Tensro 객체는 항상 같은 IRValue 로 매핑된다.

 

1.2 IRNode

IRNode 는 연산 1개를 표현한다

@dataclass
class IRNode:
    id: int
    op: str
    inputs: List[int]   # IRValue ids
    outputs: List[int]  # IRValue ids
    attrs: Dict[str, Any]
  • op : high-level op 이름 ( linear, relu, adam step... )
  • inputs, outputs : IRValue id 로 연결
  • attrs : op 별 파라미터

 

1.3 IRGraph

IRGraph 는 IRValue / IRNode 를 모아둔 단일 그래프다

  • value : Dict[int, IRValue]
  • nodes : List[IRNode]
  • emit() 로 node 추가
  • new_value() 로 value 추가

 

2. IR 생성 : tracing 의 전체 흐름

2.1 compile_ir(step_fn)

def compile_ir(step_fn, *, name="train_step") -> IRGraph:
    ir = IRGraph(name=name)
    with tracing(ir):
        step_fn()
    return ir

compiler_ir 이 IR 생성 시작점

  • step_fn 을 한 번 실행
  • 실제 연산을 돌리는 실행이 아닌 각 functional / op 에서 is_tracing() True 분기를 타며 IR 를 emit 한다.
  • 즉, tracing 은 실제 계산이 아니라 IR 노드 생성이 목적이다.

 

2.2 tracing(ir) 컨텍스트 매니저

aicf_fw / core / trace.py

  • 전역 플래그
  • 전역 IR 포인터

이 구간에서

  • is_tracing() == True
  • get_ir() 로 현재 그래프에 접근 가능

 

3. 어디서 IRNode 가 만들어지는지?

(A) Functional forward / training helper op 들

functional.py 에 tracing 분기가 있음

Linear

  • tracing 시
    • 입력 x / W 를 IRValue 로 확보
    • 출력 y 를 symbolic Tensor 로 만들고 IRValue 로 연결
    • ir.emit(op="Linear",...)

의미

  • 이 IRNode 는 고수준 Linear 연산을 나타낸다.
  • 런타임 lowering 에서는 gemm + bias_add 로 풀릴 수 있다.

 

(B) Autograd.backward 의 tracing path

autograd.py

if is_tracing():
    ir.emit(op="Backward", inputs=[loss,(grad)], outputs=[], attrs={"accumulate": ...})
    return
  • tracing 에선backward 그래프를 만들지 않는다.
  • opaque node 추가
  • 실제 backward 에서 실행되는 연산들은 runtime trace 에만 나타나고 IR 에는 나타나지 않는다.

  • Forward 연산은 high-level 로 기록
  • Backward 는 opaque marker
  • Optim step 은 high-level 로 기록

 

현재 학습 step 을 컴파일하기 위한 최소 구조만 기록하는 실행 IR 형태

 

IR compiler 은 capture 전에 한다. compiler 은 실행이 아닌 IR 생성

warmup 은 capture-safe 를 위해 필요한 버퍼, grad, state 를 미리 materialize 하는 단계

capture 는 실제 backend ops 가 캡처

replay 는 캡처된 CUDA Graph 를 반복 실행

IR 은 캡처되어야 할 실행 의미를 추상화한 기록