본문 바로가기

dev_AI_framework

gemm 헬퍼 모듈 요약 및 graph_capture 정리

헬퍼설명
assert_f32_2d 입력이 float32 dtype, 2D 행렬 형태인지 검증 후 (rows, cols) 반환.
as_tensor_2d CuPy 배열을 내부 C++ 바인딩용 ai::Tensor 구조로 래핑.
empty_2d / empty_like_2d 지정된 (M, N) 크기 또는 기존 배열과 동일한 shape의 빈 float32 CuPy 배열 생성.
get_stream_ptr CuPy Stream 또는 None을 받아 CUDA stream의 원시 포인터(ptr) 반환.
to_voidp_capsule 정수 포인터를 PyCapsule 형태로 변환하여 pybind11 함수에 전달 가능하게 함.
ensure_cuda_dlls (Windows용) CUDA DLL이 로드되지 않았을 경우 경로를 추가하여 ImportError 방지.
_parse_act_to_kind 문자열 "relu", "gelu" 등을 C++ enum ActKind으로 매핑.
_mk_attrs act, with_bias, leaky_slope 등을 설정한 GemmAttrs 객체 생성.
GemmWorkspaces / ensure_workspaces Capture-safe backward용 dZ / Lt workspace 버퍼를 관리·캐싱.
clear_ws_cache 캐시된 워크스페이스를 비워 GPU 메모리 반환 유도.

🔹 CUDA Graph Capture 방법 정리

목적:
cudaGraph를 이용해 GEMM 연산(Forward+Backward)을 메모리 할당 없이 캡처하여
GPU 재사용 시 오버헤드를 최소화함.


✅ 절차 요약

  1. 모든 버퍼 사전 할당
  2.  
    A, B, bias = cp.ascontiguousarray(A), cp.ascontiguousarray(B), cp.ascontiguousarray(bias) Y = cp.empty((M, N), dtype=cp.float32) Z = cp.empty((M, N), dtype=cp.float32) gY = cp.random.randn(M, N).astype(cp.float32) gA, gB, gBias = cp.empty_like(A), cp.empty_like(B), cp.empty((1, N), dtype=cp.float32)
  3. 워크스페이스 준비 (dZ, lt_ws)
  4.  
    ws = gemm_ops.ensure_workspaces(M, N, lt_bytes=4<<20)
  5. 스트림 생성 및 캡처 시작
  6.  
    stream = cp.cuda.Stream(non_blocking=True) with stream: stream.begin_capture() # CUDA Graph 시작
  7. 캡처-safe 함수 호출
  8.  
    gemm_ops.forward_into(A, B, out=Y, bias=bias, with_bias=True, act="relu", save_z=True, z_out=Z) gemm_ops.backward_into(A, B, gY, Z, with_bias=True, act="relu", gA_out=gA, gB_out=gB, gBias_out=gBias, work_dZ=ws.dZ, lt_workspace=ws.lt_ws)
  9. 캡처 종료 → 그래프 객체 생성
  10.  
    graph = stream.end_capture() exec_graph = graph.instantiate() # CuPy 버전에 따라 생략 가능
  11. 그래프 실행
  12.  
    exec_graph.launch(stream) # or stream.ptr stream.synchronize()

⚙️ 핵심 포인트

  • *_into 계열은 할당 금지(strict no-allocation) 경로 — 모든 텐서/버퍼는 미리 준비해야 함.
  • ensure_workspaces()로 dZ (필수)Lt workspace (선택) 를 미리 생성.
  • capture 내에서는 cupy.empty / astype / reshape 금지, 모두 외부에서 수행.
  • forward_into → backward_into 호출 순서를 지켜야 그래프 캡처가 일관됨.
  • graph.instantiate() / graph.launch(stream)는 CuPy 버전별로 다를 수 있으므로 호환 헬퍼로 처리.

요약하자면:

“capture-safe 경로 = 모든 텐서·워크스페이스·스트림을 사전 준비 + forward_into / backward_into를 연속 호출 + stream.capture()로 캡처 후 실행.”