본문 바로가기

GPU-KERNEL

SMEM 에 대한 접근 - 저장소가 아닌 연산 스케줄의 일부, layout 이 알고리즘 그 자체

  • FlashAttention (대부분의 고성능 ops) 은 연산이 아니라 데이터가 언제/어떤 형태로/누가/어디서 접근하느냐가 성능을 결정한다
  • 그걸 결정하는 게 shared memory layout + 그에 맞춘 thread/warp 매핑이다.
  • 그래서 layout 이 바뀌면 접근 패턴, 동기화, 파이프라인 깊이, 수치 알고리즘 구현 방식까지 함께 바뀐다.

 

1) smem 을 캐시라고 생각하면 왜 실패하나

L1/L2 캐시는 하드웨어가 알아서

  • 어떤 라인이 들어오고,
  • 누가 공유하고
  • 충돌을 어느 정도 숨겨주고
  • 실패해도 그냥 느려질 뿐

하지만 smem 은 직접 설계해야 한다

  • 어떤 데이터가 smem 에 들어오는가
  • 언제 올리는가
  • 누가 쓰는가
  • 어떤 stride 로 접근하는가?
  • 얼마나 오래 들고 있는가
  • 언제 덮어쓰는가

저장소가 아닌 스케줄 테이블이다.

 

2) layout 이 알고리즘인 이유, 바꾸면 수학은 같아도 실행이 달라짐

FlashAttention 의 본질은

  1. S = QK^T
  2. softmax(S)
  3. O += P*V
  4. 마지막에 normalize

여기서 성능/구현 난이도를 결정하는 건 수식이 아닌

  •  Q / K / V 를 smem 에 어떤 형태로 배치하느냐
  • warp 들이 그걸 어떤 순서로 소비하느냐
  • softmax 의 intermediate 를 어디에 두고 누가 읽느냐

 

예시 1) K 를 transpose + pad 하는 순간 알고리즘이 변한다

K 를 글로벌에서 그대로 읽으면

  • warp  가 K 를 읽을 때 stride 가 꼬여서 smem bank conflict / 비정렬 load / 재사용 감소가 생김

K 를 smem 에 transpose 해서 저장하면

  • 이후 모든 접근이 행 방향 연속이 되어
  • load 벡터화 + bank conflict 제거 + 재사용 상승이 가능해짐

pad 를 넣으면

  • 특정 stride 에서 발생하는 정기적 bank aliasing 자체가 사라짐

이건 단순 최적화가 아니라

  • 이 타일을 어떤 비용 구조로 곱할 것인가 라는 연산 방법이 바뀐 것

 

예시 2) exp(score) 를 smem 에 half 로 저장하는 순간 softmax 구현이 바뀐다.

exp 를 매번 재계산하면

  • 연산량 증가 
  • 하지만 smem 저장 / 로드는 줄어듦

smem 에 half 로 저장할 시

  • warp1 이 V 누적할 때 재계산 없이 바로 사용 가능
  • 대신 smem 트래픽 / 용량 / bank 패턴이 요구됨
  • softmax stat 업데이트와 동기화 시점이 고정됨

즉 exp 를 어디에 두냐는 선택은

  • online softmax 의 데이터 플로우를 결정
  • warp specializtion 가능/불가능을 결정
  • barrier 필요 여부를 결정

 

3) FlashAttention 에서 smem layout 이 스케줄을 결정하는 3가지 축

A. 접근 패턴이 warp 역할을 강제한다.

  • warp0 가 score+stats 를 만든다면, 그 결과를 warp1 이 소비해야 함
  • 그 소비 위치가 reg 인지 smem 인지에 따라
    • 통신 방식
    • barrier 필요 여부
    • 파이프라인 깊이 가 고정된다.

B. stageing (2, 3 - stage) 이 smem 주소 계획으로 굳어진다

pipeline 을 한다는 건 결국

  • smem_stage[s] 에 load
  • smem_stage[s-1] 로 compute
  • 다음 iteration 에 overwirtie 이 덮어쓰기 규칙이 layout 에 박힌다.

그래서 stage 수를 바꾸면

  • smem footprint 가 변하고
  • 주소 계산이 바뀌고
  • bank conflict 패턴이 바뀌고
  • 레지스터 생존 기간도 바뀐다.

즉, 파이프라인으 옵션이 아니라 layout 과 한 몸이다.

c. 벡터화 단위가 thread mapping 을 결정한다

half2 로 K 를 저장하면

  • lane 들이 2-element 단위로 책임지게 되고
  • index 계산이 그 단위로 바뀌고
  • 일부 lane 은 쓸모 없는 일을 안 하게 재배치해야 한다.

그래서 벡터화는 단순히 더 빠르다가 아니라

  • thread responsibility 를 다시 짜는 것 = 알고리즘 재설계다

 

4) 그럼 어떻게 설계하냐 : 실전 가이드

Step 1) smem 에 올릴 데이터를 계산 단위로 정의해라

  • 타일에서 재사용되는 최소 단위가 뭔지 먼저 정한다

step 2) 그 단위를 warp 가 읽기 좋은 형태로 재배치해라

  • transpos/pad 는 대부분 여기서 나옴
  • 목표는 3 개
    • bank conflict 제거/완화
    • 연속 접근
    • warp 간 중복 로드 제거

Step 3) pipeline stage 를 정하고 smem 주소를 stage 로 분할해라

  • base + stage * stride 형태로 덮어쓰기 안전을 설계
  • stage 수는 숨기려는 latency 가 아니라 smem footprint 와 동기화 비용으로 결정

Step 4) intermediate 를 reg vs smem 중 어디에 둘 지 결정

  • reg 에 두면 : 빠름, 대신 reg pressuer 증가, occupancy 감소
  • smem 에 두면 : 통신 쉬움, 대신 smem traffic 증가, bank conflict 위험 증가

FlashAttention 에선 warp specialization 때문에 intermediate 를 smem 으로 보내는 선택이 자연스럽게 나오기도 함

 

5) AI framework ops 구현으로의 일반화

  • LayerNorm / Softmax / RMSNorm
    • online reduction + normalize
    • stats 를 어디에 두느냐가 알고리즘
  • GEMM epilogue 
    • output tile 을 smem 으로 staging 하느냐 vs 바로 store 하느냐
    • vectorized store 를 위해 layout 을 강제함
  • Conv2D
    • smem 에 올리는 건 입력이 아니라 사실상 타일링된 가공 데이터
    • im2col 을 명시적으로 만들지 않더라도 layout 이 im2ccol 역할을 함