Node.backpropagate(upstream_gradient=1.0)
이 함수는 backpropagation 중에 각 노드에서 자신의 입력값에 대한 비용 함수 변화량을 계산하고, 이를 자식 노드에게 전파하는 역할을 한다.
mean, sum 같은 특수 연산자는 직접 미분하지 않고 자식들에게 균등 분배한다.
일반 연산자에 대해서는,
- 저장된 input_value, output, weight_value 등을 기반으로 미분을 계산
- grad_input, grad_weight 를 계산
- grad_weight_total 에 누적 저장
- grad_input 을 자식 노드들에게 계속 전파
def backpropagate(self, upstream_gradient=1.0):
upstream_gradient : 이 노드의 출력이 다음 노드로부터 얼마나 민감하게 영향을 주는가
if self.operation in ["mean", "sum"]:
split_grad = upstream_gradient
if self.operation == "mean":
split_grad = upstream_gradient / len(self.children)
for child in self.children:
child.backpropagate(split_grad)
return
mean, sum : 자식 노드의 출력들을 단순히 합하거나 평균내는 구조,
따라서 역방향으로는 균등하게 나눠준다.
x = self.input_value if self.requires_input_value else 0.0
out = self.output if self.requires_output_value else 0.0
노드에 저장된 값 중 역전파 연산에 필요한 것만 가져옴
grad_input, grad_weight = self._gradient_func(self.operation)(x, self.weight_value, out, upstream_gradient)
연산자별 미분 규칙 정의된 곳에서 함수 호출,
x = 입력값, w = 가중치, out = 출력값
upstream_gradient = 현재 노드의 출력이 loss 에 미치는 영향
self.grad_weight_total += grad_weight
가중치에 대한 변화량 추적, 누적
for child in self.children:
child.backpropagate(grad_input)
이 노드의 입력값에 대한 gradient 를 자신의 자식 노드들에게 역으로 전달
# 미분 정의 (Backward)
@staticmethod
def _gradient_func(op):
return {
"add": lambda x, w, out, grad: (grad, grad),
"subtract": lambda x, w, out, grad: (grad, -grad),
"multiply": lambda x, w, out, grad: (grad * w, grad * x),
"divide": lambda x, w, out, grad: (
grad / (w if w != 0 else 1e-6),
-grad * x / ((w if w != 0 else 1e-6) ** 2),
),
"square": lambda x, w, out, grad: (2 * x * grad, 0.0),
"exp": lambda x, w, out, grad: (out * grad, 0.0),
"neg": lambda x, w, out, grad: (-grad, 0.0),
"reciprocal": lambda x, w, out, grad: (-1.0 / (x ** 2 if x != 0 else 1e-6) * grad, 0.0),
"const": lambda x, w, out, grad: (0.0, 0.0),
"mean": lambda x, w, out, grad: (grad, 0.0),
"sum": lambda x, w, out, grad: (grad, 0.0),
}[op]
_gradient_func() 딕셔너리에 정의된 각 연산자에 대한 역전파 공식
통일된 구조
lambda x, w, out, grad: (grad_input, grad_weight)
x : input 값
w : weight 값
out : output 값
grad : upstream gradient
반환 : grad_input - 입력값의 변화에 대한 비용 함수 변화량, grad_weight - 가중치 변화량에 대한 비용 함수 변화량 ( 가중치 갱신 시 해당 값이 사용된다.)
add
"add": lambda x, w, out, grad: (grad, grad),
z = a + b 일 경우
delta z / delta a = 1, delta z / delta b = 1
항상 gradient 를 그대로 전달
subtract
"subtract": lambda x, w, out, grad: (grad, -grad),
z = a - b
delta z / delta a = 1,
delta z / delta b = -1
감산의 방향성을 반영해 grad, -grad 로 전달
multiply
"multiply": lambda x, w, out, grad: (grad * w, grad * x),
z = x * w
delta z / delta x = w,
delta z / delta w = x
grad * w, grad * x
곱셈의 기본 미분 규칙, 입력값의 변화량의 경우 가중치와의 연산을, 가중치 변화량에 대해선 입력값과의 gradient 연산을 수행한다.
divide
"divide": lambda x, w, out, grad: (
grad / (w if w != 0 else 1e-6),
-grad * x / ((w if w != 0 else 1e-6) ** 2),
),
z = x / w
delta z / delta x = 1 / w
delta z / delta w = -x / w^2
0 으로의 나누는 것 방지
square
"square": lambda x, w, out, grad: (2 * x * grad, 0.0),
z = x^2
delta z / delta x = 2x
이 경우 가중치에 대한 연산이 아니므로 가중치 변화량에 대한 비용 함수 변화량은 0 이다.
exp
"exp": lambda x, w, out, grad: (out * grad, 0.0),
z = exp(x)
delta z / delta x = exp(x)
neg
"neg": lambda x, w, out, grad: (-grad, 0.0),
z = -x
delta z / delta x = -1
간단한 부호 반전
reciprocal
"reciprocal": lambda x, w, out, grad: (-1.0 / (x ** 2 if x != 0 else 1e-6) * grad, 0.0),
z = 1 / x
delta z / delta x = -1 / x^2
분모 제곱이 있는 역함수 미분
const : 상수 미분값은 항상 0
'dev_AI_framework' 카테고리의 다른 글
역전파 아이디어 계산 그래프 단축? (0) | 2025.04.30 |
---|---|
역전파 사고 흐름 정리 (0) | 2025.04.30 |
계산 그래프의 노드 정보를 어디까지 저장할지 (0) | 2025.04.27 |
비용 함수 계산 그래프 부분 문제 해결 (0) | 2025.04.26 |
비용 함수 값 연산 및 계산 그래프 생성 부분 (0) | 2025.04.25 |