dev_AI_framework

optimizer의 C++ Backend 구현, Adam - Adaptive momentum estimation

명징직조지훈 2024. 9. 3. 15:25

Adaptive Momentum Estimation 은 모멘텀 최적화와 RMSProp 의 아이디어를 합친 것, 

모멘텀 최적화처럼 지난 그레이디언트의 지수 감소 평균 (exponential decaying average) 를 따르고 

RMSProp 처럼 지난 그레이디언트 제곱의 지수 감소된 평균을 따른다.

 

모멘텀과 적응적 학습률을 결합한 방식,

과거의 기울기 값을 지수적으로 가중합하여 기울기 계산

C++ 코드

class Adam {
public:
    Adam(double learning_rate, double beta1 = 0.9, double beta2 = 0.999, double epsilon = 1e-8)
        : learning_rate(learning_rate), beta1(beta1), beta2(beta2), epsilon(epsilon), t(0) {}

    std::vector<double> update(std::vector<double>& weights, const std::vector<double>& gradients) {
        t += 1;
        for (size_t i = 0; i < weights.size(); ++i) {
            m[i] = beta1 * m[i] + (1.0 - beta1) * gradients[i];
            v[i] = beta2 * v[i] + (1.0 - beta2) * gradients[i] * gradients[i];

            double m_hat = m[i] / (1.0 - std::pow(beta1, t));
            double v_hat = v[i] / (1.0 - std::pow(beta2, t));

            weights[i] -= learning_rate * m_hat / (std::sqrt(v_hat) + epsilon);
        }
        return weights;
    }

private:
    double learning_rate;
    double beta1;
    double beta2;
    double epsilon;
    int t;
    std::map<size_t, double> m; // 1st moment vector
    std::map<size_t, double> v; // 2nd moment vector
};