Intermediate Representation 은 다음의 목적을 가진다.
- Python step_fn 을 1회 실행해서 실행 경로를 연산 노드 시퀀스로 고정
- IR 은 모델 옵티마이저 데이터가 실제로 어떤 순서로 호출되는지를 표현한다
- 이후 IR 을 backend op 시퀀스로 lowering 해
- 실행 순서 검증
- CUDA graph capture replay 준비
- 향후 정식 컴파일 기반으로 확장
IR 은 정적 그래프 빌더가 아니라 한 번 실행된 step 의 연산 로그를 구조화한 실행 IR 이다.
IR 구성 요소
1.1 IRValue (Tensor / Scaler 의 메타)
IRValue 는 런타임에 존재하는 텐서, 스칼라를 shape, dtype, device 기준으로 표현한다.
@dataclass
class IRValue:
id: int
name: str
shape: Tuple[int, ...]
dtype: str
device: str
- id : IR 내부에서 참조하는 value id
- name : 디버깅용 이름
- shape, dtype, device : 실행 안정성, 검증을 위한 최소 메타
중요
- Tensor identity 를 캐시 키로 사용해서 IRValue 폭증을 막는다
- 즉, 같은 Tensro 객체는 항상 같은 IRValue 로 매핑된다.
1.2 IRNode
IRNode 는 연산 1개를 표현한다
@dataclass
class IRNode:
id: int
op: str
inputs: List[int] # IRValue ids
outputs: List[int] # IRValue ids
attrs: Dict[str, Any]
- op : high-level op 이름 ( linear, relu, adam step... )
- inputs, outputs : IRValue id 로 연결
- attrs : op 별 파라미터
1.3 IRGraph
IRGraph 는 IRValue / IRNode 를 모아둔 단일 그래프다
- value : Dict[int, IRValue]
- nodes : List[IRNode]
- emit() 로 node 추가
- new_value() 로 value 추가
2. IR 생성 : tracing 의 전체 흐름
2.1 compile_ir(step_fn)
def compile_ir(step_fn, *, name="train_step") -> IRGraph:
ir = IRGraph(name=name)
with tracing(ir):
step_fn()
return ir
compiler_ir 이 IR 생성 시작점
- step_fn 을 한 번 실행
- 실제 연산을 돌리는 실행이 아닌 각 functional / op 에서 is_tracing() True 분기를 타며 IR 를 emit 한다.
- 즉, tracing 은 실제 계산이 아니라 IR 노드 생성이 목적이다.
2.2 tracing(ir) 컨텍스트 매니저
aicf_fw / core / trace.py
- 전역 플래그
- 전역 IR 포인터
이 구간에서
- is_tracing() == True
- get_ir() 로 현재 그래프에 접근 가능
3. 어디서 IRNode 가 만들어지는지?
(A) Functional forward / training helper op 들
functional.py 에 tracing 분기가 있음
Linear
- tracing 시
- 입력 x / W 를 IRValue 로 확보
- 출력 y 를 symbolic Tensor 로 만들고 IRValue 로 연결
- ir.emit(op="Linear",...)
의미
- 이 IRNode 는 고수준 Linear 연산을 나타낸다.
- 런타임 lowering 에서는 gemm + bias_add 로 풀릴 수 있다.
(B) Autograd.backward 의 tracing path
autograd.py
if is_tracing():
ir.emit(op="Backward", inputs=[loss,(grad)], outputs=[], attrs={"accumulate": ...})
return
- tracing 에선backward 그래프를 만들지 않는다.
- opaque node 추가
- 실제 backward 에서 실행되는 연산들은 runtime trace 에만 나타나고 IR 에는 나타나지 않는다.
즉
- Forward 연산은 high-level 로 기록
- Backward 는 opaque marker
- Optim step 은 high-level 로 기록
현재 학습 step 을 컴파일하기 위한 최소 구조만 기록하는 실행 IR 형태
IR compiler 은 capture 전에 한다. compiler 은 실행이 아닌 IR 생성
warmup 은 capture-safe 를 위해 필요한 버퍼, grad, state 를 미리 materialize 하는 단계
capture 는 실제 backend ops 가 캡처
replay 는 캡처된 CUDA Graph 를 반복 실행
IR 은 캡처되어야 할 실행 의미를 추상화한 기록