본문 바로가기

dev_AI_framework

어느 단위 레벨까지 구현해야 할까 에 대한 고민 - 기본 레벨 연산들은 끝도 없

1) 레이어-우선 로드맵 (필수만)

이미 갖춘 것

  • ✅ GEMM(+Bias+Act), Conv2D(+bwd), LayerNorm/RMSNorm(+bwd), Softmax(+bwd), CrossEntropy(+bwd)
  • ✅ Dropout(+bwd), SDPA(FWD), Pool2D(Max/Ave, +bwd)
  • ✅ Ewise(활성/산술), Reduction(max/min 등), Slice/Concat

이걸로 MLP/Conv/Norm/Attention의 큰 줄기는 이미 가능합니다.

바로 다음 최소 세트

이 6가지만 추가하면 Transformer/ResNet-계열 학습·추론이 거의 다 돕니다.

  1. Embedding / EmbeddingBag
    • 필요 커널: gather (FWD), scatter_add (BWD, grad wrt weight)
    • 난이도: 중(atomicAdd 필요), 범용 인덱싱 계열 중 가성비 최고.
  2. Padding (constant/replicate/reflect 중 constant만 우선)
    • 필요 커널: 간단 scatter/gather 1개
    • Conv/Pool/Attention mask 전처리에 필수. (반대로 crop도 같은 커널로 처리 가능)
  3. Permute/Transpose(2D/ND) – view가 불가할 때만 copy
    • 필요 커널: ND index 계산 복사 1개
    • 레이아웃 정리/contiguous 강제 시 꼭 필요.
  4. Expand(broadcast_to) – view만
    • 필요사항: 0-stride 허용(읽기 전용) + ewise/reduction이 0-stride를 처리
    • 커널 추가 X, 브로드캐스트 엔진 규칙만 통일.
  5. TopK (values, indices) – 어텐션/빔서치/리랭킹 등
    • 축별 블록 TopK(소~중 k)부터. (대형 정렬은 후순위)
  6. Upsample2D (nearest/bilinear, FWD+BWD)
    • Vision 모델 다수에 필요. AlignCorners 옵션은 나중에.

보류(나중에): BatchNorm/GroupNorm(이미 LayerNorm 대체 가능), TransposedConv, CTC Loss, 고급 pad 모드(reflect/replicate), stable sort 등.


2) “기본 레벨”은 이렇게 스킵/대체

  • reshape/squeeze/unsqueeze/contiguous: 전부 메타 연산(view/스트라이드 갱신) + 필요 시 한 방 복사. 별도 런처 불필요.
  • logical/compare/where/isfinite: ewise 프레임에 op 테이블만 추가해서 일괄 처리.
  • cumsum/cumprod/argsort: 지금은 필수 아님. 나중에 학습 루프/통계에 꼭 필요해지면 도입.
  • adaptive pooling: Avg/Max 풀링 커널 재사용으로 나중에.

요지는 “모델 실행에 바로 영향 없는 기본기”는 잠깐 미루고, 레이어가 요구하는 최소 primitive만 뚫자는 것.


3) 레이어별 의존성 표 (최소 커널만)

레이어/기능필요한 커널(최소)
Linear/MLP ✅ GEMM(+bias/act) 완료
Conv2D (ResNet류) ✅ Conv2D(+bwd), ✅ Pad(constant)
Norm (LN/RMSNorm) ✅ 완료
Dropout ✅ 완료
Softmax + CE ✅ 완료
Multi-Head Attention ✅ SDPA(FWD), ➕ TopK(선택), ➕ Mask용 Pad/Expand
Embedding/EmbeddingBag ➕ Gather/ScatterAdd
Skip-connection & head ✅ ewise(+), concat/slice 완료
Upsample(UNet류) ➕ Upsample2D
Tensor layout 정리 ➕ Permute/Transpose, Expand(view)

4) 구현 우선순위 & 팁

  1. Gather / ScatterAdd
    • dtype: FP32(grad), idx: I32
    • 축 고정 버전부터(axis=1 같은) → 범용 ND 인덱스는 후속
    • scatter_add는 atomicAdd(FP32/FP16) 분기 필요
  2. Padding (constant)
    • FWD: out→in 좌표 역매핑, 범위 밖은 constant
    • BWD: in→out 범위 안의 grad만 누적
  3. Permute / Contiguous 복사
    • ND index 계산 헬퍼 1개로 통일
    • 큰 텐서는 coalesced 읽기/쓰기 우선(가능하면 innermost가 연속되도록)
  4. Expand(view)
    • shape·stride 컴퓨팅 규칙을 코어에 통일
    • ewise/reduce에서 0-stride 안전하게 사용(쓰기 금지)
  5. TopK(작은 k 우선)
    • 각 block 내 선택정렬/힙 + 병합(초판은 block-local로도 충분)
    • 반환: (values, indices), stable 옵션은 보류
  6. Upsample2D
    • nearest 먼저(FWD+BWD), bilinear는 다음
    • BWD는 기울기 분배(atomicAdd)로 간단

5) 테스트 전략(간단하지만 잘 잡히는 것)

  • Embedding: PyTorch 참조값과 FWD/BWD 비교(중복 인덱스 케이스 포함).
  • Pad: NCHW 랜덤 텐서로 FWD/BWD 비교(패딩 영역 grad=0 확인).
  • Permute: x → permute → contiguous → inverse-permute 동일성.
  • Expand: expand + ewise/reduction 조합(0-stride 안전성).
  • TopK: 동률값 포함 케이스, sorted=True/False 동작 점검.
  • Upsample: PyTorch와 FWD/BWD 비교(작은 텐서, 모서리/중앙 픽셀 체크).

결론

기본 레벨 전부를 구현하려 들지 말고, 레이어 구동에 필요한 최소 primitive(Gather/ScatterAdd, Pad, Permute/Expand, TopK, Upsample)만 치고 나가면, 지금 코드베이스로 Transformer/ResNet/UNet 대부분을 바로 커버할 수 있습니다.