dev_AI_framework

다른 방법으로 계산 그래프의 구성 요소를 전부 N * N sparse matrix 로 표현하기 - 스칼라 단위까지 세분화하기

명징직조지훈 2025. 5. 13. 22:22

2025.05.13 - [dev_AI_framework] - 레이어 자체가 자기 행렬 조각을 반환( 1번째 방법, 인덱스로 연결 관계 표현하기 )

Conn (N, N) 노드 i → 노드 j 연결이면 Conn[i][j] = 1
Value (N,) or (N, D) 각 노드의 출력값 (벡터 or 행렬)
Grad (N,) or (N, D)  노드의 역전파 기울기
OpType (N,) 각 노드의 연산자 종류 (정수 인코딩)
import numpy as np
import pandas as pd

# Define op type codes
OP_TYPES = {
    "const": 0,
    "multiply": 1,
    "add": 2,
}

# A (4x4) @ B (4x4) = C (4x4)
# We express this with (N x N) matrices only: Conn, OpType, ParamIndex

# Total nodes: 32 const (A,B) + 64 mul + 48 add = 144
N = 144
Conn = np.zeros((N, N), dtype=np.int8)
OpType = np.zeros((N,), dtype=np.int32)
ParamIndex = np.full((N,), -1, dtype=np.int32)  # Only const nodes will get param index

# Map of parameters
param_values = []

# 1. A and B constants: node 0~31
for i in range(32):
    OpType[i] = OP_TYPES["const"]
    ParamIndex[i] = len(param_values)
    param_values.append(np.random.rand(1).item())  # scalar value

# 2. Multiply nodes (32~95): A[i][k] * B[k][j]
# Each C[i][j] has 4 multiplications, so total 16x4 = 64
idx = 32
for i in range(4):
    for j in range(4):
        for k in range(4):
            a_idx = i * 4 + k       # A[i][k] → node 0~15
            b_idx = 16 + k * 4 + j  # B[k][j] → node 16~31
            Conn[a_idx, idx] = 1
            Conn[b_idx, idx] = 1
            OpType[idx] = OP_TYPES["multiply"]
            idx += 1

# 3. Add nodes (96~143): sum 4 muls per output
# Each sum takes 3 additions: (((a + b) + c) + d)
for block in range(16):  # for each C[i][j]
    base = 32 + block * 4  # 4 mul outputs per C[i][j]
    a1 = idx
    Conn[base, a1] = 1
    Conn[base + 1, a1] = 1
    OpType[a1] = OP_TYPES["add"]
    idx += 1

    a2 = idx
    Conn[a1, a2] = 1
    Conn[base + 2, a2] = 1
    OpType[a2] = OP_TYPES["add"]
    idx += 1

    a3 = idx
    Conn[a2, a3] = 1
    Conn[base + 3, a3] = 1
    OpType[a3] = OP_TYPES["add"]
    idx += 1

# Final output nodes = 128~143
output_node_ids = list(range(128, 144))

df_preview = pd.DataFrame({
    "OpType": OpType[:30],
    "ParamIndex": ParamIndex[:30]
}).T

print("=== OpType Vector & ParamIndex Preview ===")
print(df_preview)

print("\n=== Conn Matrix (partial 30x30) ===")
print(pd.DataFrame(Conn[:50, :50]))

수정된 dense mat

import numpy as np

OP_TYPES = {
    "const": 0,
    "multiply": 1,
    "add": 2,
}

class DenseMat:
    def __init__(self, units, input_dim=None, initializer='he'):
        self.units = units
        self.input_dim = input_dim
        self.output_dim = units
        self.initializer = initializer
        self.weights = None
        self.bias = None

    def build(self, input_dim):
        self.input_dim = input_dim
        self.output_dim = self.units

        if self.initializer == 'he':
            stddev = np.sqrt(2. / input_dim)
            self.weights = (np.random.randn(input_dim, self.output_dim) * stddev).astype(np.float32)
        elif self.initializer == 'xavier':
            stddev = np.sqrt(1. / input_dim)
            self.weights = (np.random.randn(input_dim, self.output_dim) * stddev).astype(np.float32)
        else:
            self.weights = np.zeros((input_dim, self.output_dim), dtype=np.float32)

        self.bias = np.zeros((1, self.output_dim), dtype=np.float32)

    def generate_sparse_matrix_block(self, input_ids, node_offset):
        input_dim = self.input_dim
        output_dim = self.output_dim

        # 총 필요한 노드 수 계산
        weight_const_count = input_dim * output_dim
        bias_const_count = output_dim
        mul_count = input_dim * output_dim
        add_count = (input_dim - 1 + 1) * output_dim  # 3 mul-add + 1 bias-add per unit
        total_nodes = weight_const_count + bias_const_count + mul_count + add_count
        N = node_offset + total_nodes

        Conn = np.zeros((N, N), dtype=np.int8)
        OpType = np.zeros((N,), dtype=np.int32)
        ParamIndex = np.full((N,), -1, dtype=np.int32)
        ParamValues = []

        nid = node_offset

        # Weight constants
        weight_ids = np.zeros((input_dim, output_dim), dtype=np.int32)
        for i in range(input_dim):
            for j in range(output_dim):
                OpType[nid] = OP_TYPES["const"]
                ParamIndex[nid] = len(ParamValues)
                ParamValues.append(self.weights[i, j])
                weight_ids[i, j] = nid
                nid += 1

        # Bias constants
        bias_ids = []
        for j in range(output_dim):
            OpType[nid] = OP_TYPES["const"]
            ParamIndex[nid] = len(ParamValues)
            ParamValues.append(self.bias[0, j])
            bias_ids.append(nid)
            nid += 1

        # Multiply nodes
        mul_ids = np.zeros((input_dim, output_dim), dtype=np.int32)
        for j in range(output_dim):
            for i in range(input_dim):
                OpType[nid] = OP_TYPES["multiply"]
                Conn[input_ids[i], nid] = 1
                Conn[weight_ids[i, j], nid] = 1
                mul_ids[i, j] = nid
                nid += 1

        # Add chains for each output unit
        output_ids = []
        for j in range(output_dim):
            # first add: mul[0] + mul[1]
            add_prev = nid
            OpType[nid] = OP_TYPES["add"]
            Conn[mul_ids[0, j], nid] = 1
            Conn[mul_ids[1, j], nid] = 1
            nid += 1

            # second add: prev + mul[2]
            OpType[nid] = OP_TYPES["add"]
            Conn[add_prev, nid] = 1
            Conn[mul_ids[2, j], nid] = 1
            add_prev = nid
            nid += 1

            # third add: prev + mul[3]
            OpType[nid] = OP_TYPES["add"]
            Conn[add_prev, nid] = 1
            Conn[mul_ids[3, j], nid] = 1
            add_prev = nid
            nid += 1

            # final add with bias
            OpType[nid] = OP_TYPES["add"]
            Conn[add_prev, nid] = 1
            Conn[bias_ids[j], nid] = 1
            output_ids.append(nid)
            nid += 1

        return {
            "Conn": Conn,
            "OpType": OpType,
            "ParamIndex": ParamIndex,
            "ParamValues": ParamValues,
            "input_ids": input_ids,
            "output_ids": output_ids,
            "next_node_offset": nid
        }

 

계산 그래프 구조, 기존의 노드 연결 표현이 sparse matrix 로 표현되고 그에 기반한 여러 다른 행렬들의 생성, 해당 행렬들의 cuda 전송으로 연산 가능토록 할 예정