본문 바로가기

dev_AI_framework

행렬 곱 용어, 차원 표기 정의 - GEMM, (B, T, C, H, Dh) ...

1. GEMM (General Matrix-Matrix Multiplication)

  • alpha - 새 계산 값의 크기 조절
    • 스칼라 계수, 결과 스케일링에 사용
  • beta - 기존 C 값의 보존량
    • 스칼라 계수, 기존의 C 행렬에 곱해진 후 더해짐
    • beta = 0 - C 의 기존 값은 무시되고 덮어쓰기
    • beat = 1 - C 의 기존 값과 새 결과를 합산

 

2. gemm_rm_tf 의 네이밍 해석 - row-major 포맷 행렬들 사이에서 TF32 정밀도로 수행하는 GEMM

백엔드/프레임워크 내부 규현 규현 규약

  • rm -> row-major (행 우선 저장)
    • CPU 의 일반적인 배열 포맷
  • cm -> column-major
  • tf -> tensorFloat-32(TF32)
    • 연산 정밀도

 

Transformer 에서 자주 나오는 차원 표기

기존 심볼

  • B = Batch size (묶어서 학습하는 샘플 수)
  • T = Sequence length (토큰 길이, time step)
  • C = Embdding dimension (모델 차원)
  • H = Number of heads (multi-head attention 의 head 수)
  • Dh = Head dimension (각 head 의 벡터 차원)

 

입력 X 가 있다고 할 때:

  • X : (B, T, C) -> 문장 배치 (배치 32, 문장 길이 128, 임베딩 512 같은 구조)
  • Q, K, V 생성
    • 선형층 후 reshape : (B, T, C) -> (B, H, T, Dh)
    • 이제 각 head 마다 (T, Dh) 시퀀스를 따로 처리
  • Attention 가중치: (B, H, T, T)
  • Head merge 후: 다시 (B, T, C)