본문 바로가기

dev_AI_framework

optimizer, model.compile 구현을 위해..

dev/optimizers/__init__.py

from dev.optimizers import SGD


ALL_OPTIMIZERS = {
    SGD,
}

ALL_OPTIMIZERS_DICT = {cls.__name__.lower(): cls for cls in ALL_OPTIMIZERS}

def get(identifier):
    if isinstance(identifier, str):
        obj = ALL_OPTIMIZERS_DICT.get(identifier, None)

    return obj

dev/optimizers/sgd.py

class SGD():
    def __init__(
        self,
        learning_rate = 0.01,
        momentum = 0.0,
        name = "SGD"
    ):
        self.momentum = momentum

    # optimizer 변수 초기화
    def build(self, variables):
        pass

    # 주어진 그래디언트와 모델 변수에 대해 업데이트 수행
    def update_step(self, gradient, variable, learning_rate):
        pass

    def get_config(self):
        pass

해당 인스턴스가 저장이 되도록