본문 바로가기

dev_AI_framework

run_graph_backward의 각 변수 해석

  • E: std::vector<Opstruct> : 순전파에 사용했던 계산 그래프의 연산 리스트
  • tensors: unordered_map<string, float*> : 모든 순전파 텐서의 디바이스 포인터 저장소 ( tensors["dense_W"], ... )
  • shapes: unordered_map<string, Shape> : tensors 각 항목의 (rows, cols) 를 보관, 역전파에서 메모리 크기 계산과 커널 그리드 구성의 근거
  • gradients : unordered_map<string, float*> : 각 노드의 출력/입력/파라미터에 대한 gradient 버퍼(디바이스 포인터)를 보관
    • gradients[op.output_id] 는 그 노드 출력에 대한 d L / d output
    • gradients[op.input_id] 는 그 노드 입력에 대한 d L / d input
    • gradients[op.parma_id] 는 파라미터 gradient
  • y_pred/y_true, sz, d L_d y
    • y_pred, y_true 는 라벨과예측
    • sz = rows * cols 는 LOSS 입력 텐서 크기
    • dL_dy 는 LOSS 의 첫 gradient 버퍼
  • 루프 내 변수들
    • op : 현재 연산 노드
    • input = tensors[op.input_id], parm = tensors[op.param_id]
    • grad_out = gradients[op.output_id] : 현재 노드가 받는 gradient
    • grad_input : 이 노드가 이전 노드로 전파할 gradient
    • in_shape/out_shape : 각 op.input_id, op.output_id 의 shape

 

연산별 역전파 요약

  • ACTIVATION 
    • grad_input = grad_out @ f (output) 커널에 tensors[op.output_id] 전달해서 도함수 계산
  • ADD(bias add)
    • grad_input = grad_out
    • grad_bias = reduce_rows
  • MATMUL (Dense)
    • 입력 gradient : d X = d Y @ W^T
    • 가중치 gradient : d W = X^T @ d Y