본문 바로가기

dev_AI_framework

LLM 과 접목, 최적화 커널 구성과 빌드, 모듈 생성의 과정을 자동화

완전 가능해요—그리고 그게 이 프로젝트가 진짜 빛나는 지점이에요.
LLM을 서브그래프 인식 → 커널 설계 선택 → CUDA 코드 생성 → 빌드/바인딩 → 테스트/튜닝 → 아티팩트 배포까지 자동화 파이프라인에 녹이면 됩니다.

전체 파이프라인 스케치

  1. 그래프 분석
    • 모델에서 패턴 탐지 (예: [QKᵀ→softmax→dropout→PV], [LN→GEMM→GELU→GEMM] 등).
    • 텐서 크기/정밀도/장치 capability(SM, Tensor Core) 수집.
  2. LLM 플래너 (정책)
    • 입력: 서브그래프 IR + 하드웨어/정밀도 제약 + 성능 힌트.
    • 출력: “전략” 선택
      • composite vs epilogue-fused vs flash-like
      • dtype(FP16/BF16/FP32), accumulation(FP32), 타일 크기, 스케줄(CTA/warp) 등.
  3. LLM 코드젠
    • 선택된 전략 템플릿에 파라미터를 채워 CUDA/CUTLASS/regemm epilogue functor 생성.
    • 마스크/causal/dropout·PRNG·온라인 softmax 등 동작 포함.
  4. 빌드 & 바인딩
    • nvcc/clang-cuda 컴파일, 아카이브/링킹.
    • pybind 모듈에 자동 등록(서명 일치 보장), 심볼 검증.
  5. 검증 & 튜닝
    • 정확도: 참조 구현과 allclose(eps/atol 자동 조정).
    • 마이크로벤치: 크기 스윕, CUDA events로 ms 측정.
    • 오토튜닝: 타일/블록 크기, 전치/프리페치 전략 탐색 → 최적 해시 키로 캐시.
  6. 아티팩트 캐싱/배포
    • (shape, dtype, sm, strategy) → 빌드 캐시 키.
    • 성능 메타데이터/리그레션 기준 저장.
    • 재현 가능한 seed/컴파일 플래그 함께 기록.

간단한 워크플로 예시 (의사 YAML)

- detect_subgraph:
    pattern: "SDPA"
    inputs: [Q:BHMD, K:BHND, V:BHND]
- plan_with_llm:
    constraints: {sm: 90, dtype: fp16, accum: fp32, mem_cap: high}
    decide:
      strategy: "flash_like"
      tile: {M: 128, N: 64, K: 64}
      dropout: true
      causal: maybe
- codegen_with_llm:
    templates:
      - kernels/sdpa_flash_like.cu.j2
    render_to: build/sdpa_flash_like_128x64x64.cu
- build:
    nvcc: ["--use_fast_math", "-arch=sm_90", "-O3"]
    outputs: ["libsdpa_flash_like.a", "_core.pyd"]
- validate:
    ref_impl: "numpy"
    atol: 1e-4
- autotune:
    sweep: [{M:64,N:64},{M:128,N:64},{M:128,N:128}]
    pick: "best_latency"
- cache_publish:
    key: "sdpa:sm90:fp16:128x64x64:flash"
    store: ["binary","bindings","bench.json"]

LLM 프롬프트 설계 포인트

  • 제약 우선: SM, shared mem, reg pressure, occupancy 상한 제공.
  • 안전 가드: 메모리 경계, 동기화, warp divergence 최소화 규칙 명시.
  • 테스트 자동 생성: 랜덤 텐서/마스크/causal 케이스, 에지(비정방) 포함.
  • 성능 힌트: coalesced, LDS tiling, 온라인 softmax 수식, Philox RNG 스킴.

모듈/레지스트리 자동 등록

  • 생성된 커널이 ops_registry.add("sdpa", impl=flash_like, guard=shape/dtype/sm) 형태로 런타임 디스패치에 붙도록 스텁/매크로 생성.
  • 실패 시 composite 경로로 폴백.

위험 요소 & 대응

  • 정확성: mixed-precision/exp 안정성 → 온라인 softmax, FP32 누적 강제.
  • 재현성: dropout/PRNG → Philox key/offset 규약 고정, seed 기록.
  • 빌드 플래그 편차: nvcc 옵션/드라이버 버전 캐싱 → 메타데이터 고정.
  • 성능 회귀: 벤치 스냅샷/임계값 비교로 PR 차단.

기대 이득

  • 커널 런치 수/글로벌 메모리 왕복 횟수 대폭 감소.
  • 문제 크기/하드웨어별 자동 커널 특화.
  • 개발자가 “패턴만” 정의하면, 뒤는 LLM+파이프라인이 자동 생성/튜닝.