본문 바로가기

dev_AI_framework

JAX - 함수의 수학적 객체 관찰, 분석 및 최적화

🔍 JAX vs PyTorch 개념적 차이 요약

개념PyTorchJAX (w/ XLA)
그래프 유형 동적 (define-by-run) 정적 (define-then-run via JAXPR)
미분 연산 중 각 노드에서 추적 함수 전체를 미분 함수로 변환
연산 최적화 연산자 수준의 최적화 (부분적) 전체 함수 단위 최적화 (XLA 컴파일)
디버깅 편의성 높음 (디버깅 쉬움) 낮음 (JAXPR이 생소함)
성능 최적화 유연성 낮음 (JIT 따로 없음) 높음 (jit, pmap 등 지원)

 

🔧 JAX의 기본 개념

JAX는 다음 두 가지 아이디어를 결합한 시스템입니다:

  1. NumPy처럼 동작하는 함수형 연산 (pure function style)
    → jax.numpy를 사용해서 NumPy처럼 코딩하지만 부작용 없는 함수로 구성
  2. 컴파일러 트랜스폼(transform)
    → grad, jit, vmap, pmap 등으로 함수를 변환해 그래프 최적화 및 병렬화 가능

 

🧠 "수학 함수처럼 정적 표현"이란?

의미:

  • 입력이 고정되면, 함수의 연산 그래프도 완전히 고정됨.
  • 동적 분기 (if, for 등)는 정적 그래프에서는 불변 조건으로 컴파일됨.
  • @jit, grad 등은 모두 정적인 함수 구조에만 적용될 수 있음.
def f(x):
    if x > 0:
        return x * 2
    else:
        return -x