본문 바로가기

dev_AI_framework

_ops_common 의 구현을 통한 공통 shim 타입 / 규약의 단일 진입점의 역할 수행

구현한 CUDA ops 바인딩과 Python  glue 사이의 공통 언어를 제공하는 모듈

 

CUDA 커널이 이해하는 ai::cuda::shim 타입들을 Python 에서 직접 생성 관리할 수 있도록 하는 shim bridge

 

모든 CUDA ops 바인딩은 shim 기반 타입을 쓴다.

ActKind, GemmAttrs, TensorDesc, Tensor, make_tensor_2d, Device, DType, Layout

이 타입들은 공통 규칙에 따라 C++ 에서 정의됨

위 shim 타입들을 Python 레벨에도 그대로 가져와준다. 

 

실제 CUDA ops 바인딩은 최소 기능만 유지, 즉 직접 생산되는 결과물만 가지고 있고 공통 타입은 정의하지 않음

ops 바인딩은

각 op 의 API 에만 집중, 타입 / 레이아웃의 검증은 shim 을 통해 일관 유지

 

// src/bindings/ops_common_pybind.cpp
//
// Common shim-based types for CUDA ops:
//   - ActKind, GemmAttrs
//   - Device, DType, Layout
//   - TensorDesc, Tensor, make_tensor_2d
//
// 이 모듈에서 노출하는 타입들은 모두
// backends/cuda/ops/_common/shim/* 에 정의된 ai::cuda::shim 기준입니다.
// CUDA ops 바인딩(_ops_gemm, _ops_conv2d, ...)과 동일한 타입을 사용하게 해서
// Python → shim Tensor → 각 ops 바인딩 으로 그대로 전달됩니다.

#include <cstdint>
#include <vector>
#include <stdexcept>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "backends/cuda/ops/_common/shim/ai_shim.hpp"

namespace py   = pybind11;
namespace shim = ::ai::cuda::shim;

PYBIND11_MODULE(_ops_common, m) {
    m.attr("__package__") = "graph_executor_v2.ops";
    m.doc() = R"(Common shim-based types for CUDA ops:
- ActKind, GemmAttrs
- Device, DType, Layout
- TensorDesc, Tensor, make_tensor_2d)";

    // ------------------------------------------------------------------
    // ActKind, GemmAttrs (shim:: 기준)
    // ------------------------------------------------------------------
    py::enum_<shim::ActKind>(m, "ActKind", py::arithmetic(), py::module_local(false))
        .value("None",      shim::ActKind::None)
        .value("ReLU",      shim::ActKind::ReLU)
        .value("LeakyReLU", shim::ActKind::LeakyReLU)
        .value("GELU",      shim::ActKind::GELU)
        .value("Sigmoid",   shim::ActKind::Sigmoid)
        .value("Tanh",      shim::ActKind::Tanh)
        .export_values();

    py::class_<shim::GemmAttrs>(m, "GemmAttrs", py::module_local(false))
        .def(py::init<>())
        .def_readwrite("trans_a",     &shim::GemmAttrs::trans_a)
        .def_readwrite("trans_b",     &shim::GemmAttrs::trans_b)
        .def_readwrite("act",         &shim::GemmAttrs::act)
        .def_readwrite("with_bias",   &shim::GemmAttrs::with_bias)
        .def_readwrite("leaky_slope", &shim::GemmAttrs::leaky_slope)
        .def_readwrite("save_z",      &shim::GemmAttrs::save_z);

    // ------------------------------------------------------------------
    // Device / DType / Layout (shim enums)
    // ------------------------------------------------------------------
    py::enum_<shim::Device>(m, "Device", py::arithmetic(), py::module_local(false))
        .value("CPU",  shim::Device::CPU)
        .value("CUDA", shim::Device::CUDA)
        .export_values();

    py::enum_<shim::DType>(m, "DType", py::arithmetic(), py::module_local(false))
        .value("F32",  shim::DType::F32)
        .value("F16",  shim::DType::F16)
        .value("BF16", shim::DType::BF16)
        .value("I32",  shim::DType::I32)
        .value("I8",   shim::DType::I8)
        .export_values();

    py::enum_<shim::Layout>(m, "Layout", py::arithmetic(), py::module_local(false))
        .value("RowMajor", shim::Layout::RowMajor)
        .value("ColMajor", shim::Layout::ColMajor)
        .export_values();

    // ------------------------------------------------------------------
    // TensorDesc / Tensor (shim::ai_tensor.hpp)
    // ------------------------------------------------------------------
    py::class_<shim::TensorDesc>(m, "TensorDesc", py::module_local(false))
        .def(py::init<>())
        .def_readwrite("dtype",  &shim::TensorDesc::dtype)
        .def_readwrite("layout", &shim::TensorDesc::layout)
        .def_readwrite("shape",  &shim::TensorDesc::shape)
        .def_readwrite("stride", &shim::TensorDesc::stride)
        .def("dim", &shim::TensorDesc::dim);

    py::class_<shim::Tensor>(m, "Tensor", py::module_local(false))
        .def(py::init<>())
        // raw pointer를 u64로 노출 (Python 쪽에서 CuPy ptr와 연결)
        .def_property(
            "data_u64",
            [](const shim::Tensor& t) {
                return static_cast<std::uintptr_t>(reinterpret_cast<std::uintptr_t>(t.data));
            },
            [](shim::Tensor& t, std::uintptr_t p) {
                t.data = reinterpret_cast<void*>(p);
            }
        )
        .def_readwrite("desc",         &shim::Tensor::desc)
        .def_readwrite("device",       &shim::Tensor::device)
        .def_readwrite("device_index", &shim::Tensor::device_index)
        .def("is_cuda", &shim::Tensor::is_cuda)
        .def("is_contiguous_rowmajor_2d", &shim::Tensor::is_contiguous_rowmajor_2d);

    // ------------------------------------------------------------------
    // make_tensor_2d : ptr_u64 + shape → shim::Tensor (2D, RowMajor)
    // ------------------------------------------------------------------
    m.def(
        "make_tensor_2d",
        [](std::uintptr_t ptr_u64,
           const std::vector<std::int64_t>& shape,
           shim::DType dtype = shim::DType::F32,
           shim::Device device = shim::Device::CUDA,
           int device_index = 0) {
            if (shape.size() != 2) {
                throw std::invalid_argument("make_tensor_2d: shape must be 2D");
            }
            const std::int64_t rows = shape[0];
            const std::int64_t cols = shape[1];

            // ai_tensor.hpp 의 shim::make_tensor2d 헬퍼 사용
            shim::Tensor t = shim::make_tensor2d(
                reinterpret_cast<void*>(ptr_u64),
                rows,
                cols,
                dtype
            );
            t.device       = device;
            t.device_index = device_index;
            return t;
        },
        py::arg("ptr_u64"),
        py::arg("shape"),
        py::arg("dtype")        = shim::DType::F32,
        py::arg("device")       = shim::Device::CUDA,
        py::arg("device_index") = 0
    );

    // ------------------------------------------------------------------
    // __all__
    // ------------------------------------------------------------------
    m.attr("__all__") = py::make_tuple(
        "ActKind", "GemmAttrs",
        "Device", "DType", "Layout",
        "TensorDesc", "Tensor", "make_tensor_2d"
    );
}