본문 바로가기

dev_AI_framework

연산 방식의 변경 - 임시 버퍼/전치 커널 , 커스텀 matmul backward 제거

1. 전치/커스텀 커널 제거 cuBLAS GEMM 으로 단일화

  • W^T, input_T 를 위한 임시 버퍼/전치 커널 모두 제거
  • 커스텀 matmul_backward 커널 제거
  • 대신 cuBLAS 로
    • d X = d Y @ W^T
    • d W += X^T @ d Y (배치 누적, beta = 1)
  • row-major <-> column-major 매핑을 감추는 얇은 래퍼 추가
  • cuBLAS 핸들 도입
  • GEMM 호출 형태 (중요 수식/치수)
    • 텐서 레이아웃은 row-major 기준, leading demension 은 열 개수
  • 메모리 관리 간소화/성능 개선
    • 전치용 버퍼 할당/해제 제거
    • d W 누적을 GEMM의 beta = 1 로 처리 - 별도 합산 커널 불필요
    • d X 버퍼는 배치 크기만큼 한 번만 cudaMalloc
  • ADD(편향) 역전파 경로 정리
    • dX = dY 는 경량 커널 add_backward_input 로 처리(원소 복사)
    • dB = sum_rows(dY) 는 add_backward_bias 로 배치별 부분합을 만들고, 배치 루프에서 add_inplace로 누적