본문 바로가기

implement_ml_models/MLP

implement_MLP(parameter_update)(2)

2023.01.29 - [분류 전체보기] - implement_MLP(parameter_update)

 

implement_MLP(parameter_update)

MLP 모델에서 delta 값을 이용하여 각 노드들의 입력값의 변화에 따른 오차 함수의 변화량을 계산했고 delta 값을 통해 각 가중치의 변화량에 따른 오차 함수의 변화량을 구할 수 있었다. 이렇게 구

teach-meaning.tistory.com

임의의 가중치에 대한 오차 함수의 변화량을 통해 가중치 변화량을 계산할 수 있었다.

가중치 변화량을 적용하여 가중치를 수정해가면 MLP 연산을 수행한다.

for i in range(len(w)):
  w_narray = np.array(w[i])
  print(w_narray)

  w_diff_narray = np.array(w_diff[i]).T
  print(w_diff_narray)

  new_w = w_narray - w_diff_narray
  print(new_w)

  b_narray = np.array(b[i])
  print(b_narray, "b")

  b_diff_narray = np.array(b_diff[i])
  print(b_diff_narray)

  new_b = b_narray - b_diff_narray
  print(new_b)
>>>
[[0.1 0.2]
 [0.3 0.4]
 [0.5 0.6]]
[[0.00747633 0.00174744]
 [0.01495265 0.00349488]
 [0.02242898 0.00524232]]
[[0.09252367 0.19825256]
 [0.28504735 0.39650512]
 [0.47757102 0.59475768]]
[0.5] b
0.0046118836415391445
[0.49538812]
[[0.7 0.8]
 [0.9 0.1]]
[[0.04047892 0.11298352]
 [0.04166268 0.11628758]]
[[ 0.65952108  0.68701648]
 [ 0.85833732 -0.01628758]]
[0.3] b
0.04913278948297279
[0.25086721]

단순하게 기존의 가중치에 가중치 변화량을 빼서 새로운 가중치를 구해보았다.

def w_parameter_update(w, w_diff, b, b_diff):
  new_w_array = []
  new_b_array = []

  for i in range(len(w)):
    w_narray = np.array(w[i])
    w_diff_narray = np.array(w_diff[i]).T

    new_w = w_narray - w_diff_narray
    
    new_w_array.append(new_w.tolist())

    b_narray = np.array(b[i])
    b_diff_narray = np.array(b_diff[i])
  
    new_b = b_narray - b_diff_narray

    new_b_array.append(new_b.tolist())

  return new_w_array, new_b_array
  
new_w, new_b = w_parameter_update(w, w_diff, b, b_diff)

new_w
>>>
[[[0.09252367330531104, 0.19825255941161068],
  [0.28504734661062203, 0.39650511882322137],
  [0.4775710199159331, 0.594757678234832]],
 [[0.659521079677311, 0.6870164775333953],
  [0.8583373245161284, -0.01628758361206635]]]
  
new_b
>>>
[[0.49538811635846086], [0.2508672105170272]]

 

'implement_ml_models > MLP' 카테고리의 다른 글

implement_MLP(parameter_update)(4 cost function)  (0) 2023.02.13
implement_MLP(parameter_update)(3)  (0) 2023.02.11
implement_MLP(parameter_update)  (0) 2023.01.29
implement_DNN_biasupdate  (1) 2022.12.02
implement_DNN_generalization(1)  (0) 2022.12.02