torch.cuda.CUDAGraph 기반 실행 경로를 추가
- execution plan 을 한 번 캡처하고
- 이후 반복 실행에선 graph.replay 로 커널 런칭 오버헤드 최소화
- tranining 모드에서는 optimizer step / param 이 feed 로 덮어써지지 않도록 안전한 포인터 / 복사 정책 보장
1. 전제 CUDA Graph 의 핵심 제약
캡처 시점에 사용된 GPU 메모리 포인터가 replay 에서도 동일해야 한다는 제약
매 step 마다 새로 생성되는 텐서를 그대로 입력으로 쓰면 포인터가 바뀌고 캡처 / 리플레이가 깨짐
외부 입력을 직접 쓰는 것이 아닌 고정 버퍼로 copy 해서 사용한다.
2. 역할 분리 input / param / state
Builder 에 externals 을 분리
- input : 매 스텝마다 바뀌는 입력
- param : W, b 같은 모델 파라미터
- state : optimizer 상태
이 셋을 합쳐 external_vids 로 관리해서 feed 로 들어오는 값의 전체 집합을 명확히 함
- inference : param 은 보통 고정, input 만 바뀜
- train : param / state 는 그래프 내부에서 누적 업데이트,
3. 컴파일 캐시 : plan identity 를 고정해야 CUDA Graph 캐시가 의미가 있음
- exe.run 이 매번 compile 을 다시 하면
- plan 이 매번 새 객체가 됨
- CUDA Graph 캐시 키가 매번 달라져서 캡처를 새로 하게 됨
- 결과적으로 state 가 누적되지 않음 / step 이 리셋 되는 것처럼 보이는 현상이 발생
해결
- compile_cached 를 두고 id 기준으로 compiled program 을 캐시
- run 은 항상 compile_cached 결과 사용
그래프 캐시보다 먼저 플랜 캐시가 고정되어야 한다.
4. CUDA Graph 캡처 / 리플레이 구조
4.1 캡처 단계
입력 : Model, CompiledProgram, feed, 정책
핵심 로직
- 모든 external 은 fixed buffer 를 가져야 함
- static_roles 에 포함된 role 들의 vid 를 모아서 고정 포인터 대상을 결정
- external_vids 중 static_roles 에 포함되지 않은게 있으면 에러
- CUDA Graph 포인터 안정성을 강제
- slots 생성
- externals : spec 기반으로 alloc_from_spec() 로 고정 버퍼를 할당하고, feed 에서 copy
- internals : spec 기반으로 할당
- alias 적용
- in-place 를 kernel runtime flag 로 직접 처리하지 않고
- 실행 전 slots = slots 로 동일 포인터를 공유하게 만들어 해결
- wramup
- 동일 버퍼를 사용하여 eager 로 몇 번 실행
- 메모리 초기화 / 커널 JIT / 캐시 등 불안정 요소 제거
- graph capture
- with torch.cuda.graph : run_lowered_ops
결과는 GraphCaptureProgram 로 보관
4.2 리플레이 단계 ( replay_cuda_graph )
- copy_roles 에 속한 role 만 feed 에서 ext_bufs 로 copy
- graph.replay
- outputs 는 b.outputs 에 등록된 vid 를 slots 에서 꺼내 반환
5. 모드별 정책 inference vs train
inference
- static_roles = input, param
- copy_roles = input .
- input 은 매 run 마다 바뀌니 copy
- param 은 보통 고정이므로 capture 때의 param 을 계속 사용
train
- static_roles = input, param, state
- copy_roles = input,
- param / state 는 포인터는 고정이어야 하므로 static 에는 포함
- 값은 그래프 내부에서 누적 업데이트 되므로 feed 로 덮어쓰면 안 됨
6. Graph cache key 설계
- mode
- plan_key
- static_roles / copy_roles
- feed signature
Adam_step 이 다른 이유
다중 alias 가 필요한 구조