본문 바로가기

AI Compiler framework

GEMM Transpose 처리 관련 수정 전

오해했던 부분

transpose 의 처리를 커널에서 shared memory load 레이아웃만 바꿔서 처리하는 줄 알았음, 

Python 이 같은 텐서를 넘겨주면 C++ 이 stride / shape + trans 플래그로 논리 view 를 만들고 그 view 로 indexing 만 달리해서 계산한다.

 

  • 물리 저장 physical storage : 실제 메모리에 어떻게 깔려있는가 data pointer + stride
  • 논리 view, logical view : 연산에서 이 텐서를 rows, cols 행렬로 어떻게 해석할 것인가

핵심 선택

  • transpose 를 데이터를 뒤집어서 새로 만드는 것을 하지 않고
  • view 변환으로 처리한다

 

python 에서는 transpose 를 만들지 않는다.

GEMM 에 transA, transB 플래그만 넘김

연산 의미만 전달하고, 실제 저장 접근 방식 결정은 CUDA 쪽이 한다.

 

CUDA 에서의 역할 : MatView2D 로 통일해서 해석

CUDA 런타임에서 텐서는 TensorDesc(data, shape, stride) 로 들어온다.

여기서 stride 는 elements 단위이고, torch 의 view stride 도 그대로 보존된다.

GEMM launcher 는 이런 view 를 만든다.

MatView2D A = make_view_2d(descA, transA);
MatView2D B = make_view_2d(descB, transB);
MatView2D C = make_view_2d(descC, false);

make_view_2d(T, trans) 는

  • trans = False 이면
    • rows = shape[0], cols = shape[1]
    • rs = stride[0], cs = stride[1]
  • trans = True 이면 논리적으로 축만 바꿔서
    • rows = shape[1], cols = shape[0]
    • rs = stride[1], cs = stride[0]

data 는 그대로고, stride 는 해석만 바뀐다.

 

NN / NT / TN 은 커널을 바꾸는 게 아니라 view 만 다르다.

NN : transA =False, transB = False

  • A : M, K
  • B : K, N
  • indexing
    • A(m,k) = A.data
    • B(k,n) = B.data

TN : transA = True, transB = False

A 의 저장을 뒤집어서 새 배열을 만들지 않고, A 접근 공식만 바뀐다.

NT : transA = False, transB = True

동일하게 B 의 인덱싱 규칙만 변경

 

shared memory load 레이아웃 변경으로 transpose 처리도 가능한 구현 방식 중 하나지만

고성능 GEMM 에서는

  • 글로벌에서 load 할 때부터 transpose 해서 shared memory 에 넣거나
  • smem layout 을 바꿔 bacnk conflict 감소
  • WMMA load_matrix 의 row / col_major 요구사항 때문에 smem 을 재배치

같은 테크닉이 들어간다