2023.01.29 - [분류 전체보기] - implement_MLP(parameter_update)
implement_MLP(parameter_update)
MLP 모델에서 delta 값을 이용하여 각 노드들의 입력값의 변화에 따른 오차 함수의 변화량을 계산했고 delta 값을 통해 각 가중치의 변화량에 따른 오차 함수의 변화량을 구할 수 있었다. 이렇게 구
teach-meaning.tistory.com
임의의 가중치에 대한 오차 함수의 변화량을 통해 가중치 변화량을 계산할 수 있었다.
가중치 변화량을 적용하여 가중치를 수정해가면 MLP 연산을 수행한다.
for i in range(len(w)):
w_narray = np.array(w[i])
print(w_narray)
w_diff_narray = np.array(w_diff[i]).T
print(w_diff_narray)
new_w = w_narray - w_diff_narray
print(new_w)
b_narray = np.array(b[i])
print(b_narray, "b")
b_diff_narray = np.array(b_diff[i])
print(b_diff_narray)
new_b = b_narray - b_diff_narray
print(new_b)
>>>
[[0.1 0.2]
[0.3 0.4]
[0.5 0.6]]
[[0.00747633 0.00174744]
[0.01495265 0.00349488]
[0.02242898 0.00524232]]
[[0.09252367 0.19825256]
[0.28504735 0.39650512]
[0.47757102 0.59475768]]
[0.5] b
0.0046118836415391445
[0.49538812]
[[0.7 0.8]
[0.9 0.1]]
[[0.04047892 0.11298352]
[0.04166268 0.11628758]]
[[ 0.65952108 0.68701648]
[ 0.85833732 -0.01628758]]
[0.3] b
0.04913278948297279
[0.25086721]
단순하게 기존의 가중치에 가중치 변화량을 빼서 새로운 가중치를 구해보았다.
def w_parameter_update(w, w_diff, b, b_diff):
new_w_array = []
new_b_array = []
for i in range(len(w)):
w_narray = np.array(w[i])
w_diff_narray = np.array(w_diff[i]).T
new_w = w_narray - w_diff_narray
new_w_array.append(new_w.tolist())
b_narray = np.array(b[i])
b_diff_narray = np.array(b_diff[i])
new_b = b_narray - b_diff_narray
new_b_array.append(new_b.tolist())
return new_w_array, new_b_array
new_w, new_b = w_parameter_update(w, w_diff, b, b_diff)
new_w
>>>
[[[0.09252367330531104, 0.19825255941161068],
[0.28504734661062203, 0.39650511882322137],
[0.4775710199159331, 0.594757678234832]],
[[0.659521079677311, 0.6870164775333953],
[0.8583373245161284, -0.01628758361206635]]]
new_b
>>>
[[0.49538811635846086], [0.2508672105170272]]
'implement_ml_models > MLP' 카테고리의 다른 글
implement_MLP(parameter_update)(4 cost function) (0) | 2023.02.13 |
---|---|
implement_MLP(parameter_update)(3) (0) | 2023.02.11 |
implement_MLP(parameter_update) (0) | 2023.01.29 |
implement_DNN_biasupdate (1) | 2022.12.02 |
implement_DNN_generalization(1) (0) | 2022.12.02 |