dev_AI_framework

C++ 백엔드 수정, 리스트 입력이 아닌 numpy array 입력

명징직조지훈 2024. 9. 4. 09:11

기존에 작성한 코드는 C++ 에서 리스트 형태의 입력을 기대하고 있음, numpy array 로 전달하도록 해보자,

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <cmath>
#include <vector>

namespace py = pybind11;

py::array_t<double> relu(py::array_t<double> inputs) {
    py::buffer_info buf = inputs.request();
    double* ptr = static_cast<double*>(buf.ptr);

    py::array_t<double> result(buf.size);
    py::buffer_info buf_result = result.request();
    double* ptr_result = static_cast<double*>(buf_result.ptr);

    for (size_t i = 0; i < buf.size; ++i) {
        ptr_result[i] = ptr[i] > 0 ? ptr[i] : 0;
    }

    return result;
}

py::array_t<double> sigmoid(py::array_t<double> inputs) {
    py::buffer_info buf = inputs.request();
    double* ptr = static_cast<double*>(buf.ptr);

    py::array_t<double> result(buf.size);
    py::buffer_info buf_result = result.request();
    double* ptr_result = static_cast<double*>(buf_result.ptr);

    for (size_t i = 0; i < buf.size; ++i) {
        ptr_result[i] = 1.0 / (1.0 + std::exp(-ptr[i]));
    }

    return result;
}

py::array_t<double> tanh_activation(py::array_t<double> inputs) {
    py::buffer_info buf = inputs.request();
    double* ptr = static_cast<double*>(buf.ptr);

    py::array_t<double> result(buf.size);
    py::buffer_info buf_result = result.request();
    double* ptr_result = static_cast<double*>(buf_result.ptr);

    for (size_t i = 0; i < buf.size; ++i) {
        ptr_result[i] = std::tanh(ptr[i]);
    }

    return result;
}

py::array_t<double> leaky_relu(py::array_t<double> inputs, double alpha = 0.01) {
    py::buffer_info buf = inputs.request();
    double* ptr = static_cast<double*>(buf.ptr);

    py::array_t<double> result(buf.size);
    py::buffer_info buf_result = result.request();
    double* ptr_result = static_cast<double*>(buf_result.ptr);

    for (size_t i = 0; i < buf.size; ++i) {
        ptr_result[i] = ptr[i] > 0 ? ptr[i] : alpha * ptr[i];
    }

    return result;
}

py::array_t<double> softmax(py::array_t<double> inputs) {
    py::buffer_info buf = inputs.request();
    if (buf.ndim != 1)
        throw std::runtime_error("Input should be a 1-D array");

    size_t size = buf.size;
    double* ptr = static_cast<double*>(buf.ptr);

    py::array_t<double> result(size);
    py::buffer_info buf_result = result.request();
    double* ptr_result = static_cast<double*>(buf_result.ptr);

    double max_val = *std::max_element(ptr, ptr + size);
    double sum = 0.0;
    for (size_t i = 0; i < size; ++i) {
        ptr_result[i] = std::exp(ptr[i] - max_val);
        sum += ptr_result[i];
    }
    for (size_t i = 0; i < size; ++i) {
        ptr_result[i] /= sum;
    }

    return result;
}

입력과 출력의 형태..

py::array_t<double> : 데이터 형식이네

py::buffer_info : Pybind11에서 Numpy 배열과 같은 버퍼 기반 객체에 대한 메타데이터를 제공하는 구조체

inputs : Numpy 배열의 버퍼 정보를 가져온다.

buf.ptr 은 버퍼의 첫 번째 요소를 가리키는 포인터

py::array_t<double> result(buf.size) : 새로운 Numpy 배열 생성, 

py:buffer_info buf_result = result.request() : 생성된 result 배열의 버퍼 정보를 가져온다. 

double* ptr_result = static_cast<double*>(buf_result.ptr) : buf_result.ptr 을 double* 타입으로 캐스팅하여 결과 배열의 데이터를 조작할 수 있는 포인터를 얻는다.