def unpack_x_y_sample_weight(data):
if isinstance(data, list):
data = tuple(data)
if isinstance(data, list):
data = tuple(data)
if not isinstance(data, tuple):
return (data, None, None)
elif len(data) == 1:
return (data[0], None, None)
elif len(data) == 2:
return (data[0], data[1], None)
elif len(data) == 3:
return (data[0], data[1], data[2])
error_msg = (
"Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
f"or `(x, y, sample_weight)`, found: {data}"
)
raise ValueError(error_msg)
def pack_x_y_sample_weight(x, y=None, sample_weight=None):
if y is None:
if not isinstance(x, (tuple, list)):
return x
else:
return (x,)
elif sample_weight is None:
return (x, y)
else:
return (x, y, sample_weight)
def check_data_cardinality(data):
pass
trainer 에 필요한 메서드들의 구현,
data 로 부터 x, y, weight 를 분리하거나
data 로 패키징,
데이터 검증 등등
'dev_AI_framework' 카테고리의 다른 글
Sequential 클래스 내 call 메서드와 관련 함수, 속성의 정의 (0) | 2024.08.28 |
---|---|
model.compile, Trainer(class) - 모델 컴파일, 훈련, 평가, 예측 (self(x) 와 call 메서드의 구현 : 클래스 상속 관계... , __call__ 이란?) (0) | 2024.08.28 |
L2 regulrization 완전 정복!! (0) | 2024.08.28 |
optimizer, model.compile 구현을 위해.. (0) | 2024.08.27 |
layer-flatten, 검증의 추가 필요 (1) | 2024.08.27 |