본문 바로가기

AI Compiler framework

Adam_step 해결 완료, 현재 위치 ( IR - lowering - runtime capture / replay - parameter update - determinism 완료..!) - Adam까지 포함한 train-step CUDA Graph가 캡처/리플레이로 닫혔고, replay는 full-restore 기준 deterministic하다


=== PR3 verify: IR shapes + topo + lowering + runtime trace + determinism(full restore) ===
replay_n=20, seed=0, warmup_runs=2, print_loss_every=0
[OK] Torch forward loss = 1.016468406 (device=cuda)
[OK] Backend set: AICFBackend
[param] 0.W (8, 8) torch.float32 cuda:0
[param] 0.b (8,) torch.float32 cuda:0
[param] 2.W (8, 8) torch.float32 cuda:0
[param] 2.b (8,) torch.float32 cuda:0
=== IR DUMP ===
{
  "graph": "train_step_aicf_only",
  "values": {
    "0": {
      "id": 0,
      "name": "x",
      "shape": [
        64,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "1": {
      "id": 1,
      "name": "W",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "2": {
      "id": 2,
      "name": "b",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "3": {
      "id": 3,
      "name": "linear_out",
      "shape": [
        64,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "4": {
      "id": 4,
      "name": "linear_out",
      "shape": [
        64,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "5": {
      "id": 5,
      "name": "relu_out",
      "shape": [
        64,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "6": {
      "id": 6,
      "name": "relu_out",
      "shape": [
        64,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "7": {
      "id": 7,
      "name": "W",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "8": {
      "id": 8,
      "name": "b",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "9": {
      "id": 9,
      "name": "linear_out",
      "shape": [
        64,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "10": {
      "id": 10,
      "name": "t",
      "shape": [
        64,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "11": {
      "id": 11,
      "name": "mse_grad_out",
      "shape": [
        64,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "12": {
      "id": 12,
      "name": "linear_out",
      "shape": [
        64,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "13": {
      "id": 13,
      "name": "mse_grad_out",
      "shape": [
        64,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "14": {
      "id": 14,
      "name": "step",
      "shape": [],
      "dtype": "torch.int32",
      "device": "cuda:0"
    },
    "15": {
      "id": 15,
      "name": "step",
      "shape": [],
      "dtype": "torch.int32",
      "device": "cuda:0"
    },
    "16": {
      "id": 16,
      "name": "bc1_inv",
      "shape": [],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "17": {
      "id": 17,
      "name": "bc2_inv",
      "shape": [],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "18": {
      "id": 18,
      "name": "grad",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "19": {
      "id": 19,
      "name": "m",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "20": {
      "id": 20,
      "name": "v",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "21": {
      "id": 21,
      "name": "p",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "22": {
      "id": 22,
      "name": "m",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "23": {
      "id": 23,
      "name": "v",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "24": {
      "id": 24,
      "name": "grad",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "25": {
      "id": 25,
      "name": "m",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "26": {
      "id": 26,
      "name": "v",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "27": {
      "id": 27,
      "name": "p",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "28": {
      "id": 28,
      "name": "m",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "29": {
      "id": 29,
      "name": "v",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "30": {
      "id": 30,
      "name": "grad",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "31": {
      "id": 31,
      "name": "m",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "32": {
      "id": 32,
      "name": "v",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "33": {
      "id": 33,
      "name": "p",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "34": {
      "id": 34,
      "name": "m",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "35": {
      "id": 35,
      "name": "v",
      "shape": [
        8,
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "36": {
      "id": 36,
      "name": "grad",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "37": {
      "id": 37,
      "name": "m",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "38": {
      "id": 38,
      "name": "v",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "39": {
      "id": 39,
      "name": "p",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "40": {
      "id": 40,
      "name": "m",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    },
    "41": {
      "id": 41,
      "name": "v",
      "shape": [
        8
      ],
      "dtype": "torch.float32",
      "device": "cuda:0"
    }
  },
  "nodes": [
    {
      "id": 0,
      "op": "Linear",
      "inputs": [
        0,
        1,
        2
      ],
      "outputs": [
        3
      ],
      "attrs": {
        "bias": true,
        "layout": "y = x @ W^T + b"
      }
    },
    {
      "id": 1,
      "op": "ReLU",
      "inputs": [
        4
      ],
      "outputs": [
        5
      ],
      "attrs": {}
    },
    {
      "id": 2,
      "op": "Linear",
      "inputs": [
        6,
        7,
        8
      ],
      "outputs": [
        9
      ],
      "attrs": {
        "bias": true,
        "layout": "y = x @ W^T + b"
      }
    },
    {
      "id": 3,
      "op": "MseGrad",
      "inputs": [
        4,
        10
      ],
      "outputs": [
        11
      ],
      "attrs": {}
    },
    {
      "id": 4,
      "op": "Backward",
      "inputs": [
        12,
        13
      ],
      "outputs": [],
      "attrs": {
        "accumulate": false
      }
    },
    {
      "id": 5,
      "op": "StepInc",
      "inputs": [
        14
      ],
      "outputs": [
        15
      ],
      "attrs": {}
    },
    {
      "id": 6,
      "op": "BiasCorr",
      "inputs": [
        15
      ],
      "outputs": [
        16,
        17
      ],
      "attrs": {
        "beta1": 0.9,
        "beta2": 0.999
      }
    },
    {
      "id": 7,
      "op": "AdamStep",
      "inputs": [
        1,
        18,
        19,
        20,
        16,
        17
      ],
      "outputs": [
        21,
        22,
        23
      ],
      "attrs": {
        "lr": 0.001,
        "beta1": 0.9,
        "beta2": 0.999,
        "eps": 1e-08
      }
    },
    {
      "id": 8,
      "op": "AdamStep",
      "inputs": [
        2,
        24,
        25,
        26,
        16,
        17
      ],
      "outputs": [
        27,
        28,
        29
      ],
      "attrs": {
        "lr": 0.001,
        "beta1": 0.9,
        "beta2": 0.999,
        "eps": 1e-08
      }
    },
    {
      "id": 9,
      "op": "AdamStep",
      "inputs": [
        7,
        30,
        31,
        32,
        16,
        17
      ],
      "outputs": [
        33,
        34,
        35
      ],
      "attrs": {
        "lr": 0.001,
        "beta1": 0.9,
        "beta2": 0.999,
        "eps": 1e-08
      }
    },
    {
      "id": 10,
      "op": "AdamStep",
      "inputs": [
        8,
        36,
        37,
        38,
        16,
        17
      ],
      "outputs": [
        39,
        40,
        41
      ],
      "attrs": {
        "lr": 0.001,
        "beta1": 0.9,
        "beta2": 0.999,
        "eps": 1e-08
      }
    }
  ]
}
[OK] [ir] shape consistency OK: Linear/ReLU/MseGrad
[OK] [ir] topo OK: single-producer SSA + use-after-define
[WARN] [ir][links] Backward loss vid=12 is not produced by any node (might be okay in v0)
[WARN] [ir][links] Backward grad vid=13 is not produced by any node (might be okay in v0)
[OK] [ir] links OK: Backward(loss,grad) are connected to forward graph
=== LOWERED OPS ===
[lower 00] op=gemm attrs={'transB': True}
[lower 01] op=bias_add attrs={}
[lower 02] op=relu attrs={}
[lower 03] op=gemm attrs={'transB': True}
[lower 04] op=bias_add attrs={}
[lower 05] op=mse_grad attrs={}
[lower 06] op=step_inc attrs={}
[lower 07] op=bias_corr attrs={'beta1': 0.9, 'beta2': 0.999}
[lower 08] op=adam_step attrs={'lr': 0.001, 'beta1': 0.9, 'beta2': 0.999, 'eps': 1e-08}
[lower 09] op=adam_step attrs={'lr': 0.001, 'beta1': 0.9, 'beta2': 0.999, 'eps': 1e-08}
[lower 10] op=adam_step attrs={'lr': 0.001, 'beta1': 0.9, 'beta2': 0.999, 'eps': 1e-08}
[lower 11] op=adam_step attrs={'lr': 0.001, 'beta1': 0.9, 'beta2': 0.999, 'eps': 1e-08}
[warmup] functional buffers = {'buffers': 12}
[aicf] capture_begin entered (dedicated stream)
[OK] [capture] done
=== TRACE OPS (runtime) ===
[trace 00] op=grad_zero
[trace 01] op=grad_zero
[trace 02] op=grad_zero
[trace 03] op=grad_zero
[trace 04] op=gemm
[trace 05] op=bias_add
[trace 06] op=relu
[trace 07] op=copy
[trace 08] op=gemm
[trace 09] op=bias_add
[trace 10] op=mse_grad
[trace 11] op=gemm
[trace 12] op=gemm
[trace 13] op=reduce_sum
[trace 14] op=copy
[trace 15] op=copy
[trace 16] op=relu_bwd
[trace 17] op=gemm
[trace 18] op=gemm
[trace 19] op=reduce_sum
[trace 20] op=copy
[trace 21] op=copy
[trace 22] op=step_inc
[trace 23] op=bias_corr
[trace 24] op=adam_step
[trace 25] op=adam_step
[trace 26] op=adam_step
[trace 27] op=adam_step
[OK] [lowering] match: forward slice OK, optim slice OK (adam_step x4)
[OK] [adam] state mutation OK on replay: max_param_diff=1.003595e-03, max_m=5.584040e-03, max_v=4.756113e-06
[A 00] stepdiff=1.003595e-03
[A 01] stepdiff=1.006724e-03
[A 02] stepdiff=1.010684e-03
[A 03] stepdiff=1.015453e-03
[A 04] stepdiff=1.021156e-03
[A 05] stepdiff=1.027618e-03
[A 06] stepdiff=1.034924e-03
[A 07] stepdiff=1.042824e-03
[A 08] stepdiff=1.051333e-03
[A 09] stepdiff=1.060132e-03
[A 10] stepdiff=1.069583e-03
[A 11] stepdiff=1.079831e-03
[A 12] stepdiff=1.090150e-03
[A 13] stepdiff=1.101691e-03
[A 14] stepdiff=1.112819e-03
[A 15] stepdiff=1.126227e-03
[A 16] stepdiff=1.139897e-03
[A 17] stepdiff=1.153755e-03
[A 18] stepdiff=1.167716e-03
[A 19] stepdiff=1.181846e-03
[B 00] stepdiff=1.003595e-03
[B 01] stepdiff=1.006724e-03
[B 02] stepdiff=1.010684e-03
[B 03] stepdiff=1.015453e-03
[B 04] stepdiff=1.021156e-03
[B 05] stepdiff=1.027618e-03
[B 06] stepdiff=1.034924e-03
[B 07] stepdiff=1.042824e-03
[B 08] stepdiff=1.051333e-03
[B 09] stepdiff=1.060132e-03
[B 10] stepdiff=1.069583e-03
[B 11] stepdiff=1.079831e-03
[B 12] stepdiff=1.090150e-03
[B 13] stepdiff=1.101691e-03
[B 14] stepdiff=1.112819e-03
[B 15] stepdiff=1.126227e-03
[B 16] stepdiff=1.139897e-03
[B 17] stepdiff=1.153755e-03
[B 18] stepdiff=1.167716e-03
[B 19] stepdiff=1.181846e-03
[OK] Determinism OK: 20 replays stepdiff-sequence matches
OK

 

1. 실행 결과 해석

IR 단계가 올바르게 생성됨

  • IR Dump 에 stepInc + biasCorr - adamstep 가 존재
  • adamstep 노드에 4개의 파라미터 생성 확인

optimizer 노드 emit 가 실제로 검증되었고 ir level 에서 AdamStep 누락이 없음

 

Lowering 단계가 올바르게 생성

  • forward : gemm, bias_add, relu, gemm, bias_add, mse_grad
  • optim : step_inc, bias_corr, adam_step x 4

optimizer 확장 확인, 

 

Runtime capture 에서 실제로 adamstep 이 dispatch/capture 됨

  • step_inc
  • bias_corr
  • adam_step

이건 deterministic 하게

  • adam_step 커널이 runtime 경로로 안 가는 문제가 해결
  • 중요한 것은 op 들이 capture 구간에서 호출

 

Replay 에서 상태 변화가 발생함

[OK] [adam] state mutation OK on replay:
  max_param_diff=1.003595e-03, max_m=5.584040e-03, max_v=4.756113e-06

 adam_step 이 캡처에 들어갔다는 최종 증거

  • replay 돌렸는데 파라미터 변화량 
  • m/v 의 변화, adam state update 가 동작
  • step 의 증가 검증 코드 

파라미터가 안 변하는 문제 해결

 

Determinism OK

A-run, B-run 의 stepdiff 시퀀스가 완전 동일

  • restore 후 replay 20 회 실행 시 매 step 동일한 시퀀스 재현 확인

의미!

  • 캡처 그래프가 상태 기반으로 완전히 결정적
  • 복원 가능한 상태 집합이 충분
  • 런타임 변동 외부 요인이 없고 pointer 안정성 확보

 

현재 프로젝트에서 “완료된 것 / 남은 것” 위치 정리

✅ 지금 완료된 큰 마일스톤

  1. IR 생성: forward + backward(opaque) + optimizer 노드 emit ✅
  2. IR shape/topo/link 검증(v0)
  3. Lowering: backend op 시퀀스 생성 ✅
  4. Runtime: C++ dispatch + CUDA Graph capture/replay ✅
  5. AdamStep: 실제 dispatch/capture 포함 ✅
  6. Replay determinism: full restore 기준 동일 시퀀스 ✅

즉, “훈련 step을 하나의 CUDA Graph로 캡처해서 deterministic하게 반복 실행”이 Adam까지 포함해서 완성.

 

다음 단계에서 개선할 수 있는 포인트는:

  • Backward 입력을 “forward의 실제 value id”로 연결하거나
  • “Backward는 inputs를 강제하지 않는 노드”로 schema를 분리하거나
  • compile_ir()에서 backward-link를 더 엄밀히 만들기

 

 

Adam까지 포함한 train-step CUDA Graph가 캡처/리플레이로 닫혔고, replay는 full-restore 기준 deterministic하다