🎯 아이디어의 문제의식과 출발점
✅ 문제: 역전파 연산의 비효율성
- 계산 그래프 기반의 역전파는 노드 단위로 수행됨
→ 재귀적 순회, 의존성 추적, 조건 분기 등으로 인해 구조가 동적이고 비효율적 - 연산 단위가 작고 분산되어 있음
→ CUDA에서는 커널 호출이 많아지고, 각 커널이 작은 계산만 수행하게 되어 GPU 자원 낭비 - **행렬 연산(CUDA의 강점)**을 잘 활용하지 못함
→ 특히 역전파는 non-uniform한 데이터 흐름으로 인해 병렬 최적화 어려움
🔍 아이디어의 핵심 구체화
1. 계산 그래프 → 행렬 형태로 표현
- 일반 계산 그래프는 트리 또는 DAG로 구성되지만,
- 당신은 이를 명시적으로 행렬로 표현하여 다음을 이끌어내려는 것:
- ✅ 연산 순서의 정렬
- ✅ gradient 흐름의 위치 고정
- ✅ 메모리 상 효율적인 데이터 참조
2. 행렬의 특수 구조 활용 (sparse, upper triangular 등)
- 역전파는 기본적으로 출력 → 입력 방향으로 진행되므로,
- 연결 관계는 상삼각 행렬(upper triangular) 또는
- 희소한(많은 0) sparse matrix 형태를 띨 수 있음
- 이 구조를 유지하면:
- gradient 계산 시 불필요한 연산 회피 가능
- 업데이트 위치가 고정되므로 in-place 연산 최적화 가능
🧠 이 방식의 근본적 전환점
기존 방식새로운 방식 (당신의 아이디어)
계산 노드 단위 역전파 (node.backward()) | 행렬 전체 단위 역전파 (grad_input = M.T @ grad_out) |
CUDA 커널 = 여러 작은 연산자 실행 | CUDA 커널 = 단일 대규모 행렬 곱 |
순회 중간중간 연산 수행 | 연산을 모두 모은 뒤, 벡터화된 연산으로 처리 |
memory access = 연결 따라 분산됨 | memory access = 고정 위치 (캐시 효율 ↑) |
'dev_AI_framework' 카테고리의 다른 글
계산 그래프의 구조적 패턴 인식 + 연산 최적화 컴파일러 개념 (0) | 2025.04.30 |
---|---|
계산 그래프 컴파일러 존재의 필요성 (0) | 2025.04.30 |
JAX - 함수의 수학적 객체 관찰, 분석 및 최적화 (0) | 2025.04.30 |
역전파 아이디어 계산 그래프 단축? (0) | 2025.04.30 |
역전파 사고 흐름 정리 (0) | 2025.04.30 |