본문 바로가기

dev_AI_framework

backpropagate() 각 연산 별 미분 값 ( delta 연산을 위한 입력값의 변화량에 대한 비용 함수 변화량, 가중치 갱신을 위한 가중치 변화량에 대한 비용함수

Node.backpropagate(upstream_gradient=1.0)

이 함수는 backpropagation 중에 각 노드에서 자신의 입력값에 대한 비용 함수 변화량을 계산하고, 이를 자식 노드에게 전파하는 역할을 한다.

mean, sum 같은 특수 연산자는 직접 미분하지 않고 자식들에게 균등 분배한다.

일반 연산자에 대해서는,

  • 저장된 input_value, output, weight_value 등을 기반으로 미분을 계산 
  • grad_input, grad_weight 를 계산
  • grad_weight_total 에 누적 저장
  • grad_input 을 자식 노드들에게 계속 전파

 

def backpropagate(self, upstream_gradient=1.0):

upstream_gradient : 이 노드의 출력이 다음 노드로부터 얼마나 민감하게 영향을 주는가

 

    if self.operation in ["mean", "sum"]:
        split_grad = upstream_gradient
        if self.operation == "mean":
            split_grad = upstream_gradient / len(self.children)
        for child in self.children:
            child.backpropagate(split_grad)
        return

mean, sum : 자식 노드의 출력들을 단순히 합하거나 평균내는 구조,

따라서 역방향으로는 균등하게 나눠준다. 

 

    x = self.input_value if self.requires_input_value else 0.0
    out = self.output if self.requires_output_value else 0.0

노드에 저장된 값 중 역전파 연산에 필요한 것만 가져옴

 

    grad_input, grad_weight = self._gradient_func(self.operation)(x, self.weight_value, out, upstream_gradient)

연산자별 미분 규칙 정의된 곳에서 함수 호출,

x = 입력값, w = 가중치, out = 출력값

upstream_gradient = 현재 노드의 출력이 loss 에 미치는 영향

 

    self.grad_weight_total += grad_weight

가중치에 대한 변화량 추적, 누적

 

    for child in self.children:
        child.backpropagate(grad_input)

이 노드의 입력값에 대한 gradient 를 자신의 자식 노드들에게 역으로 전달

 

    # 미분 정의 (Backward)
    @staticmethod
    def _gradient_func(op):
        return {
            "add": lambda x, w, out, grad: (grad, grad),
            "subtract": lambda x, w, out, grad: (grad, -grad),
            "multiply": lambda x, w, out, grad: (grad * w, grad * x),
            "divide": lambda x, w, out, grad: (
                grad / (w if w != 0 else 1e-6),
                -grad * x / ((w if w != 0 else 1e-6) ** 2),
            ),
            "square": lambda x, w, out, grad: (2 * x * grad, 0.0),
            "exp": lambda x, w, out, grad: (out * grad, 0.0),
            "neg": lambda x, w, out, grad: (-grad, 0.0),
            "reciprocal": lambda x, w, out, grad: (-1.0 / (x ** 2 if x != 0 else 1e-6) * grad, 0.0),
            "const": lambda x, w, out, grad: (0.0, 0.0),
            "mean": lambda x, w, out, grad: (grad, 0.0),
            "sum": lambda x, w, out, grad: (grad, 0.0),
        }[op]

_gradient_func() 딕셔너리에 정의된 각 연산자에 대한 역전파 공식

 

통일된 구조

lambda x, w, out, grad: (grad_input, grad_weight)

x : input 값 

w : weight 값 

out : output 값

grad : upstream gradient

 

반환 : grad_input - 입력값의 변화에 대한 비용 함수 변화량, grad_weight - 가중치 변화량에 대한 비용 함수 변화량 ( 가중치 갱신 시 해당 값이 사용된다.)

 

add

"add": lambda x, w, out, grad: (grad, grad),

z = a + b 일 경우 

delta z / delta a = 1, delta z / delta b = 1

항상 gradient 를 그대로 전달

 

subtract

"subtract": lambda x, w, out, grad: (grad, -grad),

z = a - b

delta z / delta a = 1, 

delta z / delta b = -1 

감산의 방향성을 반영해 grad, -grad 로 전달

 

multiply

"multiply": lambda x, w, out, grad: (grad * w, grad * x),

z = x * w

delta z / delta x = w,

delta z / delta w = x

grad * w, grad * x

곱셈의 기본 미분 규칙, 입력값의 변화량의 경우 가중치와의 연산을, 가중치 변화량에 대해선 입력값과의 gradient 연산을 수행한다.

 

divide

"divide": lambda x, w, out, grad: (
    grad / (w if w != 0 else 1e-6),
    -grad * x / ((w if w != 0 else 1e-6) ** 2),
),

z = x / w

delta z / delta x = 1 / w

delta z / delta w = -x / w^2

0 으로의 나누는 것 방지

 

square

"square": lambda x, w, out, grad: (2 * x * grad, 0.0),

z = x^2

delta z / delta x = 2x

이 경우 가중치에 대한 연산이 아니므로 가중치 변화량에 대한 비용 함수 변화량은 0 이다. 

 

exp

"exp": lambda x, w, out, grad: (out * grad, 0.0),

z = exp(x)

delta z / delta x = exp(x)

 

neg

"neg": lambda x, w, out, grad: (-grad, 0.0),

z = -x

delta z / delta x = -1

간단한 부호 반전

 

reciprocal

"reciprocal": lambda x, w, out, grad: (-1.0 / (x ** 2 if x != 0 else 1e-6) * grad, 0.0),

z = 1 / x

delta z / delta x = -1 / x^2

분모 제곱이 있는 역함수 미분

 

const : 상수 미분값은 항상 0