- FlashAttention (대부분의 고성능 ops) 은 연산이 아니라 데이터가 언제/어떤 형태로/누가/어디서 접근하느냐가 성능을 결정한다
- 그걸 결정하는 게 shared memory layout + 그에 맞춘 thread/warp 매핑이다.
- 그래서 layout 이 바뀌면 접근 패턴, 동기화, 파이프라인 깊이, 수치 알고리즘 구현 방식까지 함께 바뀐다.
1) smem 을 캐시라고 생각하면 왜 실패하나
L1/L2 캐시는 하드웨어가 알아서
- 어떤 라인이 들어오고,
- 누가 공유하고
- 충돌을 어느 정도 숨겨주고
- 실패해도 그냥 느려질 뿐
하지만 smem 은 직접 설계해야 한다
- 어떤 데이터가 smem 에 들어오는가
- 언제 올리는가
- 누가 쓰는가
- 어떤 stride 로 접근하는가?
- 얼마나 오래 들고 있는가
- 언제 덮어쓰는가
저장소가 아닌 스케줄 테이블이다.
2) layout 이 알고리즘인 이유, 바꾸면 수학은 같아도 실행이 달라짐
FlashAttention 의 본질은
- S = QK^T
- softmax(S)
- O += P*V
- 마지막에 normalize
여기서 성능/구현 난이도를 결정하는 건 수식이 아닌
- Q / K / V 를 smem 에 어떤 형태로 배치하느냐
- warp 들이 그걸 어떤 순서로 소비하느냐
- softmax 의 intermediate 를 어디에 두고 누가 읽느냐
예시 1) K 를 transpose + pad 하는 순간 알고리즘이 변한다
K 를 글로벌에서 그대로 읽으면
- warp 가 K 를 읽을 때 stride 가 꼬여서 smem bank conflict / 비정렬 load / 재사용 감소가 생김
K 를 smem 에 transpose 해서 저장하면
- 이후 모든 접근이 행 방향 연속이 되어
- load 벡터화 + bank conflict 제거 + 재사용 상승이 가능해짐
pad 를 넣으면
- 특정 stride 에서 발생하는 정기적 bank aliasing 자체가 사라짐
이건 단순 최적화가 아니라
- 이 타일을 어떤 비용 구조로 곱할 것인가 라는 연산 방법이 바뀐 것
예시 2) exp(score) 를 smem 에 half 로 저장하는 순간 softmax 구현이 바뀐다.
exp 를 매번 재계산하면
- 연산량 증가
- 하지만 smem 저장 / 로드는 줄어듦
smem 에 half 로 저장할 시
- warp1 이 V 누적할 때 재계산 없이 바로 사용 가능
- 대신 smem 트래픽 / 용량 / bank 패턴이 요구됨
- softmax stat 업데이트와 동기화 시점이 고정됨
즉 exp 를 어디에 두냐는 선택은
- online softmax 의 데이터 플로우를 결정
- warp specializtion 가능/불가능을 결정
- barrier 필요 여부를 결정
3) FlashAttention 에서 smem layout 이 스케줄을 결정하는 3가지 축
A. 접근 패턴이 warp 역할을 강제한다.
- warp0 가 score+stats 를 만든다면, 그 결과를 warp1 이 소비해야 함
- 그 소비 위치가 reg 인지 smem 인지에 따라
- 통신 방식
- barrier 필요 여부
- 파이프라인 깊이 가 고정된다.
B. stageing (2, 3 - stage) 이 smem 주소 계획으로 굳어진다
pipeline 을 한다는 건 결국
- smem_stage[s] 에 load
- smem_stage[s-1] 로 compute
- 다음 iteration 에 overwirtie 이 덮어쓰기 규칙이 layout 에 박힌다.
그래서 stage 수를 바꾸면
- smem footprint 가 변하고
- 주소 계산이 바뀌고
- bank conflict 패턴이 바뀌고
- 레지스터 생존 기간도 바뀐다.
즉, 파이프라인으 옵션이 아니라 layout 과 한 몸이다.
c. 벡터화 단위가 thread mapping 을 결정한다
half2 로 K 를 저장하면
- lane 들이 2-element 단위로 책임지게 되고
- index 계산이 그 단위로 바뀌고
- 일부 lane 은 쓸모 없는 일을 안 하게 재배치해야 한다.
그래서 벡터화는 단순히 더 빠르다가 아니라
- thread responsibility 를 다시 짜는 것 = 알고리즘 재설계다
4) 그럼 어떻게 설계하냐 : 실전 가이드
Step 1) smem 에 올릴 데이터를 계산 단위로 정의해라
- 타일에서 재사용되는 최소 단위가 뭔지 먼저 정한다
step 2) 그 단위를 warp 가 읽기 좋은 형태로 재배치해라
- transpos/pad 는 대부분 여기서 나옴
- 목표는 3 개
- bank conflict 제거/완화
- 연속 접근
- warp 간 중복 로드 제거
Step 3) pipeline stage 를 정하고 smem 주소를 stage 로 분할해라
- base + stage * stride 형태로 덮어쓰기 안전을 설계
- stage 수는 숨기려는 latency 가 아니라 smem footprint 와 동기화 비용으로 결정
Step 4) intermediate 를 reg vs smem 중 어디에 둘 지 결정
- reg 에 두면 : 빠름, 대신 reg pressuer 증가, occupancy 감소
- smem 에 두면 : 통신 쉬움, 대신 smem traffic 증가, bank conflict 위험 증가
FlashAttention 에선 warp specialization 때문에 intermediate 를 smem 으로 보내는 선택이 자연스럽게 나오기도 함
5) AI framework ops 구현으로의 일반화
- LayerNorm / Softmax / RMSNorm
- online reduction + normalize
- stats 를 어디에 두느냐가 알고리즘
- GEMM epilogue
- output tile 을 smem 으로 staging 하느냐 vs 바로 store 하느냐
- vectorized store 를 위해 layout 을 강제함
- Conv2D
- smem 에 올리는 건 입력이 아니라 사실상 타일링된 가공 데이터
- im2col 을 명시적으로 만들지 않더라도 layout 이 im2ccol 역할을 함
'GPU-KERNEL' 카테고리의 다른 글
| AICF Kernel Engineering Report ( GEMM & BiasAdd ) (0) | 2026.02.16 |
|---|---|
| 서로 다른 role 을 가지는 warp, (0) | 2025.12.16 |
| Warp-Specializaed Pipeline & cp.async Multi-Stage Overlap 개념 (0) | 2025.12.15 |
| GPU memory transaction, Byte 단위 사고 정리 (0) | 2025.12.15 |
| warp 에 이은 lane specialization?? (0) | 2025.12.15 |