import math


class Node:
    # A node in a scalar computational graph supporting forward and backward passes.
    valid_operations = {
        "add", "subtract", "multiply", "divide", "square", "exp", "neg", "reciprocal"
    }

    def __init__(self, operation, input_value=0.0, weight_value=0.0, output=0.0, bias=0.0):
        if operation not in self.valid_operations:
            raise ValueError(f"Invalid operation: {operation}. Allowed: {self.valid_operations}")
        self.operation = operation
        self.input_value = input_value
        self.weight_value = weight_value
        self.output = output
        self.bias = bias
        self.grad_weight_total = 0.0  # weight gradient accumulated across backward passes
        self.parents = []   # nodes whose outputs feed into this node
        self.children = []  # nodes this node feeds into

    def add_parent(self, parent):
        # Maintain both sides of the parent/child relationship.
        if parent not in self.parents:
            self.parents.append(parent)
            parent.add_child(self)

    def add_child(self, child):
        if child not in self.children:
            self.children.append(child)
            if self not in child.parents:
                child.add_parent(self)

    def remove_parent(self, parent):
        if parent in self.parents:
            self.parents.remove(parent)
            if self in parent.children:
                parent.children.remove(self)

    def remove_child(self, child):
        if child in self.children:
            self.children.remove(child)
            if self in child.parents:
                child.parents.remove(self)

    def compute(self):
        # Forward pass: gather parent outputs and apply this node's operation.
        inputs = [p.output for p in self.parents]
        if inputs:
            # Cache the first input so backpropagate() differentiates against
            # the value actually used in the forward pass.
            self.input_value = inputs[0]
        self.output = self._operation_func(self.operation)(inputs, self.weight_value, self.bias)
        return self.output

    def backpropagate(self, upstream_gradient=1.0):
        # Backward pass: accumulate the weight gradient, then push the input
        # gradient up to every parent. Note that all parents receive the same
        # grad_input, so binary ops such as "subtract" cannot give each
        # operand its own gradient under this scheme.
        grad_input, grad_weight = self._gradient_func(self.operation)(
            self.input_value, self.weight_value, self.output, upstream_gradient
        )
        self.grad_weight_total += grad_weight
        for parent in self.parents:
            parent.backpropagate(grad_input)

    def update_weights(self, learning_rate):
        # Plain gradient-descent step, then reset the accumulator.
        self.weight_value -= learning_rate * self.grad_weight_total
        self.grad_weight_total = 0.0

    def find_leaf_nodes(self):
        # DFS from this node; nodes with no children are leaves.
        leaf_nodes, visited = [], set()

        def dfs(node):
            if node in visited:
                return
            visited.add(node)
            if not node.children:
                leaf_nodes.append(node)
            for child in node.children:
                dfs(child)

        dfs(self)
        return leaf_nodes

    def print_tree(self, depth=0, visited=None):
        # Print the subtree below this node; the visited set guards against
        # infinite recursion when the graph contains shared nodes or cycles.
        if visited is None:
            visited = set()
        if self in visited:
            print(" " * depth + f"↳ Node({self.operation}) (already visited)")
            return
        visited.add(self)
        print(" " * depth + f"Node({self.operation}) → output={self.output}, weight={self.weight_value}, grad_total={self.grad_weight_total}")
        for child in self.children:
            child.print_tree(depth + 2, visited)

    def __hash__(self):
        # Identity-based hashing/equality so nodes work as set/dict keys.
        return id(self)

    def __eq__(self, other):
        return id(self) == id(other)

    @staticmethod
    def _operation_func(op):
        # Forward functions: (inputs, weight, bias) -> output.
        return {
            "add": lambda inputs, w, b: sum(inputs) + b,
            "subtract": lambda inputs, w, b: inputs[0] - inputs[1] + b,
            "multiply": lambda inputs, w, b: inputs[0] * w + b,
            "divide": lambda inputs, w, b: inputs[0] / (w if w != 0 else 1e-6) + b,
            "square": lambda inputs, w, b: inputs[0] ** 2 + b,
            "exp": lambda inputs, w, b: math.exp(inputs[0]) + b,
            "neg": lambda inputs, w, b: -inputs[0] + b,
            "reciprocal": lambda inputs, w, b: 1.0 / (inputs[0] if inputs[0] != 0 else 1e-6) + b,
        }[op]

    @staticmethod
    def _gradient_func(op):
        # Backward functions: (x, w, out, upstream_grad) -> (grad_input, grad_weight).
        # Operations whose forward pass never uses the weight return 0.0 for
        # grad_weight, so update_weights() leaves their weight untouched.
        return {
            "add": lambda x, w, out, grad: (grad, 0.0),
            # "subtract" propagates grad for the first operand only; the -grad
            # for the second operand cannot be expressed with a single grad_input.
            "subtract": lambda x, w, out, grad: (grad, 0.0),
            "multiply": lambda x, w, out, grad: (grad * w, grad * x),
            "divide": lambda x, w, out, grad: (
                grad / (w if w != 0 else 1e-6),
                -grad * x / ((w if w != 0 else 1e-6) ** 2),
            ),
            "square": lambda x, w, out, grad: (2 * x * grad, 0.0),
            # d(exp(x) + b)/dx = exp(x); "out" includes the bias, so recompute from x.
            "exp": lambda x, w, out, grad: (math.exp(x) * grad, 0.0),
            "neg": lambda x, w, out, grad: (-grad, 0.0),
            "reciprocal": lambda x, w, out, grad: (-1.0 / (x ** 2 if x != 0 else 1e-6) * grad, 0.0),
        }[op]
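
To sanity-check the class, here is a minimal usage sketch (my addition, not from the original code). It wires a source node into a multiply node, runs a forward pass, backpropagates a unit gradient, and takes one gradient-descent step; the names x and mul are hypothetical, and "neg" is used only as a stand-in operation for a source whose output is set directly.

# Minimal usage sketch (illustrative): y = x * w + b with x = 2.0, w = 3.0, b = 0.5.
x = Node("neg")             # stand-in source node; we set its output directly
x.output = 2.0

mul = Node("multiply", weight_value=3.0, bias=0.5)
mul.add_parent(x)

print(mul.compute())        # forward pass -> 2.0 * 3.0 + 0.5 = 6.5
mul.backpropagate(1.0)      # dy/dw = x = 2.0 accumulates into grad_weight_total
mul.update_weights(0.1)     # w <- 3.0 - 0.1 * 2.0 = 2.8
print(mul.weight_value)     # 2.8

x.print_tree()              # visualize the graph from the root
print([n.operation for n in x.find_leaf_nodes()])  # ['multiply'] — mul has no children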

Summary of this revision:

- `__hash__` / `__eq__` defined: makes `Node` objects safe to use in `set` and `dict`.
- Bidirectional relationship made explicit: the automatic linking in `add_parent` / `add_child` is now kept consistent on both sides.
- Operation/gradient functions: split out clearly as `@staticmethod`s instead of plain class attributes.
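
Since the forward and backward tables are now side by side, the analytic gradients can be checked numerically. Below is a finite-difference sketch of my own (the helper numeric_grad_weight and the nodes src and m are hypothetical, assuming the class exactly as defined above).

# Finite-difference check for the "multiply" weight gradient (illustrative).
def numeric_grad_weight(node, eps=1e-6):
    w0 = node.weight_value
    node.weight_value = w0 + eps
    f_plus = node.compute()
    node.weight_value = w0 - eps
    f_minus = node.compute()
    node.weight_value = w0  # restore the original weight
    return (f_plus - f_minus) / (2 * eps)

src = Node("neg")
src.output = 1.7                 # source with a fixed output
m = Node("multiply", weight_value=0.9)
m.add_parent(src)

m.compute()
m.backpropagate(1.0)             # analytic: dy/dw = x = 1.7
print(m.grad_weight_total)       # 1.7
print(numeric_grad_weight(m))    # ≈ 1.7, matching the analytic gradient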