💡 현재 역전파 방식의 구조
- 각 노드는 자식 노드를 가지고 있고,
- 연산자 종류에 따라 각각의 자식으로부터 값을 받고,
- 수동으로 재귀적으로 backward 순회를 수행하며,
- 각 노드의 .grad_input, .grad_weight 등이 업데이트됨.
🔧 단점
- 순회 비용이 큼 (재귀 호출, 의존성 따라 순서 결정)
- 메모리 상 위치가 불규칙함 (캐시 효율 ↓)
🚀 당신의 아이디어: 고정된 행렬 기반 그래프 표현
계산 그래프를 행렬로 표현해서, 역전파도 행렬 곱 또는 인덱싱 연산으로 처리한다면?
✅ 장점
- 고정된 메모리 구조 → 메모리 접근 속도 향상 (캐시 친화적)
- 행렬 기반 연산은 GPU에서 병렬화가 용이
- 연산 순서를 동적으로 탐색하지 않아도 됨
- grad 값의 위치 = 곧 업데이트 대상 → 포인터/참조 이동 최소화
🔬 실현 가능성
이 아이디어는 크게 두 가지 방향으로 확장될 수 있습니다:
1. 그래프 구조 → 인접행렬 or 스파스 매트릭스 변환
- 노드 간의 연결을 인접 행렬로 저장
- 연산자 종류를 코드 번호로 매핑 (예: add=0, mul=1)
- 각 노드의 grad 값을 하나의 벡터로 저장
기존 frameworks 의 계산 그래프 최적화 전략
📈 2. 공통적인 최적화 전략
최적화 기법설명
Operation Fusion | 여러 연산을 하나로 결합하여 커널 호출 수 감소 → GPU 효율성 향상 |
Memory Reuse | 중간 결과 캐시 재활용, 필요 없는 노드 제거 |
Static Scheduling | 연산 순서를 컴파일 타임에 결정해서 런타임 오버헤드 제거 |
Gradient Accumulation | 동일 연산에서 반복적으로 쓰이는 gradient를 한 번에 처리 |
Sparse Optimization | 희소성(Sparsity)을 이용한 연산 최적화 (ex. Pruned 모델) |
Common Subexpression Elimination | 동일 연산 중복 제거 |
Buffer Coalescing | 메모리 접근 위치를 인접하게 하여 캐시 효율을 향상 |
Dead Node Elimination | 미사용 결과를 출력하는 노드 제거 |
'dev_AI_framework' 카테고리의 다른 글
역전파를 위한 계산 그래프의 행렬 변환 - 아이디어 구체화 (0) | 2025.04.30 |
---|---|
JAX - 함수의 수학적 객체 관찰, 분석 및 최적화 (0) | 2025.04.30 |
역전파 사고 흐름 정리 (0) | 2025.04.30 |
backpropagate() 각 연산 별 미분 값 ( delta 연산을 위한 입력값의 변화량에 대한 비용 함수 변화량, 가중치 갱신을 위한 가중치 변화량에 대한 비용함수 (0) | 2025.04.30 |
계산 그래프의 노드 정보를 어디까지 저장할지 (0) | 2025.04.27 |