본문 바로가기

dev_AI_framework

Cross-Entropy 모듈 추가 / 수정

2025.09.22 - [dev_AI_framework] - Cross-Entropy (CUDA)

 

Cross-Entropy (CUDA)

개요Cross-Entropy 손실은 분류(classification) 문제에서 널리 사용되는 목적 함수다.모델 출력 logits와 타깃 인덱스 벡터를 입력으로 받아, 소프트맥스 확률을 통해 손실을 계산한다.현재 구현은 CUDA f32

teach-meaning.tistory.com

무엇이 바뀌었나 (요약)

  1. ignore_index 완전 지원
    • reduction="mean"일 때 평균의 분모를 전체 M이 아니라 **유효 샘플 수(Meff)**로 정확히 나누도록 수정.
    • FWD/BWD 모두 항상 target을 Host로 복사해 Meff = count(t != ignore_index)를 계산.
    • Meff == 0이면 FWD는 0으로 안전하게 처리, BWD는 inv_scale = 0로 하여 그라드 전체 0.
  2. Label Smoothing(ls_eps) 추가
    • 커널 런처 시그니처에 ls_eps(0이면 비활성) 추가.
    • FWD/BWD에서 q = (1−ε)*one_hot + ε/N를 반영하도록 구현.
  3. 런처와 커널 인터페이스 정리(일치)
    • 선언/정의/호출 완전 동일 시그니처로 정리:
      • FWD: ce_forward_logits_kernel_launcher(X, T, loss_vec, M, N, ignore_index, ls_eps, stream)
      • BWD: ce_backward_logits_kernel_launcher(X, T, dX, M, N, inv_scale, ignore_index, ls_eps, stream)
  4. 정밀 검증 추가
    • 텐서 검증 헬퍼: 장치/레이아웃/차원/shape/dtype(X,F32 / target,I32 / loss,F32) 체크 강화.
    • 파이썬 바인딩에서 target dtype을 int32로 강제(NumPy → int32 캐스팅)하여 타입 혼선 제거.
  5. 수치 안정성 유지
    • 커널에서 log-sum-exp 경로 사용(row_max 빼고 exp 합 → log 복원)으로 안정적.

 

API / 자료 구조

C++ Attributes

enum class Reduction : int { None = 0, Mean = 1, Sum = 2 };

struct CrossEntropyAttrs {
  bool   from_logits{true};   // (현재 구현은 logits 경로)
  Reduction reduction{Reduction::Mean};
  int    ignore_index{-1};    // -1이면 보통 무시 인덱스로 사용
  float  eps{1e-9f};          // from_logits=false일 때 log(p+eps)용 (보류)
  float  ls_eps{0.f};         // label smoothing 계수 (0=off)
};

런처 ( HOST ) 시그니처

Status CrossEntropyCudaLaunch(const Tensor& X, const Tensor& target, Tensor& loss,
                              const CrossEntropyAttrs& attrs, StreamHandle stream);

Status CrossEntropyCudaBackwardLaunch(const Tensor& X, const Tensor& target, Tensor& dX,
                                      const CrossEntropyAttrs& attrs, StreamHandle stream);

커널 런처 ( Wrapper ) 시그니처

// Forward: per-sample loss를 loss_vec[M]에 작성
void ce_forward_logits_kernel_launcher(const float* X, const int32_t* T,
                                       float* loss_vec, int M, int N,
                                       int ignore_index, float ls_eps,
                                       cudaStream_t);

// Backward: dX 계산 (inv_scale는 호스트에서 1/Meff 또는 1)
void ce_backward_logits_kernel_launcher(const float* X, const int32_t* T,
                                        float* dX, int M, int N,
                                        float inv_scale,
                                        int ignore_index, float ls_eps,
                                        cudaStream_t);

입출력/제약

  • 입력
    • X: [M,N] float32, RowMajor, CUDA
    • target: [M] int32, RowMajor, CUDA (주의: I64 미지원 → 바인딩에서 int32로 캐스팅)
  • 출력
    • Forward:
      • reduction=None → loss: [M]
      • reduction=Mean|Sum → loss: [1]
    • Backward:
      • dX: [M,N] float32 (X와 동일 shape)
  • 검증 실패시
    • 명확한 코드(Invalid, ShapeMismatch, DtypeMismatch 등)로 Status 반환 → 상위 ops는 음수 리턴.

수식/동작

Forward (from logits, label smoothing/ignore 포함)

  1. 확률 p = softmax(X) 는 커널에서 안정적으로 계산
  2. per-sample loss:
    • q = (1−ε)·one_hot(t) + ε/N
    • 무시 샘플(t == ignore_index)은 loss=0
    • 유효 샘플은 −∑_c q_c log p_c (one-hot이면 −log p_t)
  3. Reduction:
    • None: per-sample 벡터 그대로
    • Sum: per-sample 합
    • Mean: 유효 샘플 수 Meff로 나눔 (무시 샘플 제외). Meff==0 → 0

Backward

  • dX = (p − q) * scale
    • scale = 1/Meff (Mean), 1 (Sum/None)
    • 무시 샘플 row는 전부 0
    • label smoothing이면 q = (1−ε)·one_hot + ε/N

구현 포인트(수정 핵심)

  • FWD/REDUCE 단계와 BWD/inv_scale 계산에서 항상 Host로 target을 복사하여
    Meff = count(t != ignore_index) 을 산정 (부호 무관).
  • FWD Mean: out = sum(loss_vec) / Meff (Meff>0 ? Meff : 1)
    (Meff==0이면 loss_vec도 전부 0이므로 결과 0)
  • BWD Mean: inv_scale = (Meff>0) ? 1/Meff : 0
  • 커널에서는 t == ignore_index를 만나면 해당 row:
    • FWD: loss 0
    • BWD: dX 0

테스트 시나리오 (모두 True)

  • forward(mean), backward(mean)
  • forward(none), forward(sum)
  • forward(mean, ignore_index) ✅
  • backward(mean, ignore_index) ✅
  • forward/ backward(mean, label_smoothing) ✅

흔한 실수/주의사항

  • 타겟 dtype: 내부는 int32만 지원 → 파이썬 바인딩에서 np.int32로 변환.
  • ignore_index의 값 부호와 무관하게 Meff를 항상 타겟에서 세야 함.
  • Meff==0일 수 있음(모두 무시) → FWD는 0, BWD는 0스케일.