본문 바로가기

dev_AI_framework

From FlashAttention to AI Framework Ops 문서

GPU Kernel Design & Analysis Knowledge Transfer

단순히 FlashAttention 구현 기록이 아닌, 하나의 고난도 op 를 구현하면서 얻은 GPU 커널 설계 원칙을 모든 AI ops 구현에 재사용하기 위한 기준 문서

 

1. FlashAttention 이 특별한 이유

FlashAttention 은 단순한 attention op 가 아니다

  • GEMM + reduction + exp + normalize
  • tile streaming
  • online statistics
  • strict numerical order
  • extreme memory pressure

즉, 대부분의 딥러닝 ops 가 겪는 문제를 한 커널에 압축한 사례

이를 통해 얻은 지식은 AI ops 전반에 일반화 가능

 

2. ops 커널 구현의 공통 문제 -> FlashAttention 에서의 해법

2.1 메모리 왕복 vs 연산 재계산

문제

  • intermediate tensor 저장은 느리다
  • 하지만 재계산은 연산 비용 증가

FlashAttention 에서의 해법

  • tile-local online softmax
  • 통계를 유지하면서 score 재사용

일반화

  • LayerNorm, Softmax, RMSNorm
  • Reduce + Normalize 계열 ops 전부 동일한 구조

AI framework 적용

  • ops 설계 시 global intermediate 를 만들지 ㅁ라 것
  • IR 에 requires_online_reduction 같은 semantic 필요

 

2.2 Kernel Fusion 의 진짜 기준

오해

  • 붙일 수 있으면 붙인다

FlashAttention 에서 배운 것

  • fusion 의 기준은 메모리 공유 + reduction 순ㅅ너
  • 단순히 연산 나열이 아님

일반화

  • bias + activation
  • dorpout + scale
  • norm + affine

AI framework 적용

  • ops kernel 구현 시
    • fusion 가능 조건을 코드가 아니라 메타데이터로 명시
  • rule-based matcher 의 근거가 됨

 

2.3 Warp 는 병렬 단위가 아니라 역할 단위다

FlashAttention 에서의 핵심 인식 전환

  • warp specialization = 연산 분업
  • warp 간 동기화 비용은 설계 비용

일반 ops 에 적용

  • redution-heavy ops
  • epilogue-heavy ops
  • mixed precision ops

AI framework 적용

  • 커널 capability 에
    • warp role
    • inter-warp dependency
  • IR 은 몰라도 됨, kernel selector 가 안다

 

3. Shared Memory 를 캐시로 생각하면 실패한다

FlashAttention 에서의 교훈

  • smem 은 저장소가 아니라 연산 스케줄의 일부
  • layout 이 알고리즘 그 자체

일반 ops 로의 확장

  • Conv2D im2col
  • GEMM tiling
  • LayerNorm row-wise reduction

AI framework 적용

  • ops 구현 가이드에
    • smem layout 명시 없으면 커널 불완전
  • layout 은 튜닝 옵션이 아니라 설계 요소

 

4. cp.async / pipeline 에 대한 현실적인 이해

FlashAttention 에서 얻은 결론

  • cp.async 는 latency hiding 수단
  • 구조가 없으면 효과 없음
  • pipeline 깊이는 warp 구조에 종속

일반화

  • 모든 load-heavy ops
  • embedding lookup
  • attention 외 streaming ops

AI framework 적용

  • 커널 capability
    • async stages
    • requires explicit pipeline
  • autotune 이전에 rule 기반으로 선택 가능

 

5. Nsight Compute 분석이 설계 도구인 이유

5.1 ncu 는 결과 분석기가 아니다

FlashAttention 에서 ncu 는

  • 성능 확인이 아닌
  • 설계 검증 도구이다.

5.2 실제로 신뢰하게 된 지표들

  • Scheduler stats - warp 구조 검증
  • Barrier stalls - 설계 결함 탐지
  • L1/TEX throughput - tile locality 확인

5.3 신뢰하지 않게 된 지표

  • SM Busy 단독
  • Occupancy 단독

AI framework 적용

  • ops 커널 PR 기준
    • ncu snapshot 첨부
    • 왜 이 수치가 나왔는지 설명 필수

 

6. IR 설계에 직접 반영되는 지식들

FlashAttention 구현을 통해 IR 에 필요하다고 확정된 것들

  • semantic constraints
    • causal
    • streaming
    • reduction order
  • determinism requiresments
  • fusion 가능성 힌트

IR 은

  • 커널을 생성하지 않는다
  • 커널을 고르는 조건만 제공한다?

 

7. Ops 구현 흐름 (FlashAttention 에서 추출한 표준 절차)

  1. 연산을 tile-local / global 로 분해
  2. reduction 순서 고정
  3. fusion 가능 범위 결정
  4. warp 역할 분해
  5. smem layout 설계
  6. pipeline 설계
  7. ncu로 구조 검증

이 순서가 깨지면 무조건 다시 만든다