본문 바로가기

dev_AI_framework

CUDA 커널 호출 구조 개선 - 한 번의 호출로 학습 완료!

지금 구조( forward → forward+loss → loss_bw+layer_bw )는 “잘게 나뉜” 파이프라인이라 이해하기 쉽지만, 런치/호스트-디바이스 왕복/메모리 할당 등에서 오버헤드가 크다.

 

이를 통합하여 사용할 수 있는 방법을 고민,

 

먼저 현재 python 에서 forward+loss -> backward -> optimizer_update 를 반복 호출하고 있음, CUDA 쪽에서 한 번에 처리하는 엔트리를 만들어 1회 호출로 사용할 수 있도록

// ✅ 한 번에 학습(Forward+Loss → Backward → Optimizer Update)
float train_step_entry(
    const std::vector<OpStruct>& E,
    const std::unordered_map<std::string, uintptr_t>& tensor_ptrs,
    const std::unordered_map<std::string, Shape>& shapes_in,
    const std::string& final_output_id,
    const std::string& label_tensor_id,
    const std::string& loss_type,
    int batch_size,
    // Optimizer
    OptimizerType opt_type,
    float lr = 0.01f,
    float beta1 = 0.9f,
    float beta2 = 0.999f,
    float eps = 1e-8f,
    int timestep = 1,
    // (선택) 옵티마이저 상태 버퍼들: 파라미터 이름 → 디바이스 포인터
    const std::unordered_map<std::string, uintptr_t>& velocity_ptrs = {},
    const std::unordered_map<std::string, uintptr_t>& m_ptrs = {},
    const std::unordered_map<std::string, uintptr_t>& v_ptrs = {})
{
    // 1) 텐서/셰이프 맵 준비
    std::unordered_map<std::string, float*> tensors;
    for (const auto& kv : tensor_ptrs)
        tensors[kv.first] = reinterpret_cast<float*>(kv.second);

    // const_cast: 내부 커널 시그니처 유지
    auto& shapes = const_cast<std::unordered_map<std::string, Shape>&>(shapes_in);

    // 2) Forward + Loss
    float loss = run_graph_with_loss_cuda(
        E, tensors, shapes, final_output_id, label_tensor_id, loss_type, batch_size);

    // 3) Backward (gradients에 파라미터/중간 그래디언트들이 채워짐)
    std::unordered_map<std::string, float*> gradients;
    run_graph_backward(E, tensors, shapes, gradients, final_output_id, batch_size);

    // 4) 학습 대상(파라미터) 집합 구성: E에서 param_id를 가진 연산만 추림
    std::set<std::string> trainable_params;
    for (const auto& op : E) {
        // 연산 종류별 파라미터 존재 여부
        const bool uses_param =
            (op.op_type == OpType::MATMUL) ||
            (op.op_type == OpType::ADD)    ||
            (op.op_type == OpType::CONV2D);
        if (uses_param && !op.param_id.empty())
            trainable_params.insert(op.param_id);
    }

    // 5) Optimizer update
    for (const auto& name : trainable_params) {
        // 파라미터/그래디언트 포인터와 사이즈 확인
        auto t_it = tensors.find(name);
        auto g_it = gradients.find(name);
        auto s_it = shapes.find(name);
        if (t_it == tensors.end() || g_it == gradients.end() || s_it == shapes.end()) {
            // 그래디언트가 없을 수 있음(해당 step에서 사용X 등) → 건너뜀
            continue;
        }

        float* param_ptr = t_it->second;
        float* grad_ptr  = g_it->second;
        const Shape& shp = s_it->second;
        int size = shp.rows * shp.cols;

        // 옵티마이저 상태 버퍼(없으면 0)
        uintptr_t v_ptr_u = 0, m_ptr_u = 0, vel_ptr_u = 0;

        // 수정 (C++14 호환)
        auto it_v = v_ptrs.find(name);
        if (it_v != v_ptrs.end()) v_ptr_u = it_v->second;

        auto it_m = m_ptrs.find(name);
        if (it_m != m_ptrs.end()) m_ptr_u = it_m->second;

        auto it_vel = velocity_ptrs.find(name);
        if (it_vel != velocity_ptrs.end()) vel_ptr_u = it_vel->second;

        optimizer_update_cuda(
            param_ptr,
            reinterpret_cast<const float*>(grad_ptr),
            reinterpret_cast<float*>(vel_ptr_u),
            reinterpret_cast<float*>(m_ptr_u),
            reinterpret_cast<float*>(v_ptr_u),
            lr, beta1, beta2, eps, size, opt_type, timestep
        );
    }

    // 6) 임시 grad 메모리 정리(중복 free 방지)
    std::unordered_set<float*> freed;
    for (const auto& kv : gradients) {
        float* p = kv.second;
        if (!p) continue;
        if (freed.insert(p).second) {
            cudaFree(p);
        }
    }

    return loss;
}