🔍 JAX vs PyTorch 개념적 차이 요약
개념PyTorchJAX (w/ XLA)
그래프 유형 | 동적 (define-by-run) | 정적 (define-then-run via JAXPR) |
미분 | 연산 중 각 노드에서 추적 | 함수 전체를 미분 함수로 변환 |
연산 최적화 | 연산자 수준의 최적화 (부분적) | 전체 함수 단위 최적화 (XLA 컴파일) |
디버깅 편의성 | 높음 (디버깅 쉬움) | 낮음 (JAXPR이 생소함) |
성능 최적화 유연성 | 낮음 (JIT 따로 없음) | 높음 (jit, pmap 등 지원) |
🔧 JAX의 기본 개념
JAX는 다음 두 가지 아이디어를 결합한 시스템입니다:
- NumPy처럼 동작하는 함수형 연산 (pure function style)
→ jax.numpy를 사용해서 NumPy처럼 코딩하지만 부작용 없는 함수로 구성 - 컴파일러 트랜스폼(transform)
→ grad, jit, vmap, pmap 등으로 함수를 변환해 그래프 최적화 및 병렬화 가능
🧠 "수학 함수처럼 정적 표현"이란?
의미:
- 입력이 고정되면, 함수의 연산 그래프도 완전히 고정됨.
- 동적 분기 (if, for 등)는 정적 그래프에서는 불변 조건으로 컴파일됨.
- @jit, grad 등은 모두 정적인 함수 구조에만 적용될 수 있음.
def f(x):
if x > 0:
return x * 2
else:
return -x
'dev_AI_framework' 카테고리의 다른 글
계산 그래프 컴파일러 존재의 필요성 (0) | 2025.04.30 |
---|---|
역전파를 위한 계산 그래프의 행렬 변환 - 아이디어 구체화 (0) | 2025.04.30 |
역전파 아이디어 계산 그래프 단축? (0) | 2025.04.30 |
역전파 사고 흐름 정리 (0) | 2025.04.30 |
backpropagate() 각 연산 별 미분 값 ( delta 연산을 위한 입력값의 변화량에 대한 비용 함수 변화량, 가중치 갱신을 위한 가중치 변화량에 대한 비용함수 (0) | 2025.04.30 |