본문 바로가기

dev_AI_framework

layer-flatten, 검증의 추가 필요

from dev.layers.layer import Layer
import numpy as np

class Flatten(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    # (batch_size, flattened_dim) 형태,
    # 전체 데이터에 대해 생각해야해, 데이터 하나에 대한 flatten 연산이 아님

    def call(self, inputs):
        # 데이터를 실제로 변환하는데 초점
        # 배치 크기가 어떻게 유지되는지 알 수 있는 부분
        return np.reshape(inputs, (inputs.shape[0], -1))

    def compute_output_shape(self, input_shape):
        # 입력 shape 를 기반으로 출력 shape 를 계산, 모델의 구조 정의
        return (input_shape[0], np.prod(input_shape[1:]))

compute_output_shape : 입력 shape 를 기준으로 출력 shape 를 계산하는, layer 쌓을 때 실행되는 부분, model.summary 등의 메서드에 해당 출력이 사용될 것

 

실제 연산이 수행되는 call, model.fit 에서 해당 call 이 수행됨