본문 바로가기

AI Compiler framework

CUDA Graph Capture / Replay 설계 정리

torch.cuda.CUDAGraph 기반 실행 경로를 추가

  • execution plan 을 한 번 캡처하고
  • 이후 반복 실행에선 graph.replay 로 커널 런칭 오버헤드 최소화
  • tranining 모드에서는 optimizer step / param 이 feed 로 덮어써지지 않도록 안전한 포인터 / 복사 정책 보장

 

1. 전제 CUDA Graph 의 핵심 제약

캡처 시점에 사용된 GPU 메모리 포인터가 replay 에서도 동일해야 한다는 제약

매 step 마다 새로 생성되는 텐서를 그대로 입력으로 쓰면 포인터가 바뀌고 캡처 / 리플레이가 깨짐

외부 입력을 직접 쓰는 것이 아닌 고정 버퍼로 copy 해서 사용한다.

 

2. 역할 분리 input / param / state

Builder 에 externals 을 분리

  • input : 매 스텝마다 바뀌는 입력
  • param : W, b 같은 모델 파라미터 
  • state : optimizer 상태

이 셋을 합쳐 external_vids 로 관리해서 feed 로 들어오는 값의 전체 집합을 명확히 함

  • inference : param 은 보통 고정, input 만 바뀜
  • train : param / state 는 그래프 내부에서 누적 업데이트, 

 

3. 컴파일 캐시 : plan identity 를 고정해야 CUDA Graph 캐시가 의미가 있음

  • exe.run 이 매번 compile 을 다시 하면
  • plan 이 매번 새 객체가 됨
  • CUDA Graph 캐시 키가 매번 달라져서 캡처를 새로 하게 됨
  • 결과적으로 state 가 누적되지 않음 / step 이 리셋 되는 것처럼 보이는 현상이 발생

해결

  • compile_cached 를 두고 id 기준으로 compiled program 을 캐시
  • run 은 항상 compile_cached 결과 사용

그래프 캐시보다 먼저 플랜 캐시가 고정되어야 한다.

 

4. CUDA Graph 캡처 / 리플레이 구조

4.1 캡처 단계

입력 : Model, CompiledProgram, feed, 정책

핵심 로직

  • 모든 external 은 fixed buffer 를 가져야 함
    • static_roles 에 포함된 role 들의 vid 를 모아서 고정 포인터 대상을 결정
    • external_vids 중 static_roles 에 포함되지 않은게 있으면 에러
      • CUDA Graph 포인터 안정성을 강제
  • slots 생성
    • externals : spec 기반으로 alloc_from_spec() 로 고정 버퍼를 할당하고, feed 에서 copy
    • internals : spec 기반으로 할당
  • alias 적용
    • in-place 를 kernel runtime flag 로 직접 처리하지 않고 
    • 실행 전 slots = slots 로 동일 포인터를 공유하게 만들어 해결
  • wramup
    • 동일 버퍼를 사용하여 eager 로 몇 번 실행
    • 메모리 초기화 / 커널 JIT / 캐시 등 불안정 요소 제거
  • graph capture
    • with torch.cuda.graph : run_lowered_ops

결과는 GraphCaptureProgram 로 보관

 

4.2 리플레이 단계 ( replay_cuda_graph )

  • copy_roles 에 속한 role 만 feed 에서 ext_bufs 로 copy
  • graph.replay
  • outputs 는 b.outputs 에 등록된 vid 를 slots 에서 꺼내 반환

 

5. 모드별 정책 inference vs train

inference

  • static_roles = input, param
  • copy_roles = input .
  • input 은 매 run 마다 바뀌니 copy
  • param 은 보통 고정이므로 capture 때의 param 을 계속 사용

 

train

  • static_roles = input, param, state
  • copy_roles = input, 
  • param / state 는 포인터는 고정이어야 하므로 static 에는 포함
  • 값은 그래프 내부에서 누적 업데이트 되므로 feed 로 덮어쓰면 안 됨

 

6. Graph cache key 설계

  • mode
  • plan_key
  • static_roles / copy_roles
  • feed signature

 

Adam_step 이 다른 이유

다중 alias 가 필요한 구조