개념 정리/implement

implement_crossValidation

명징직조지훈 2022. 12. 31. 00:06

교차 검증은 데이터 세트 D 를 k 개의 서로소 집합으로 나누는 것, 

임의의 데이터 생성

import numpy as np

a = np.zeros((700))
b = np.ones((300))

data = np.concatenate((a,b))

D 로부터 stratified 를 통해 나눈다. 

k 개의 부분 집합은 데이터 분포를 반영해야 한다.

def cross_validation(data, ratio, k):
  a, b = np.unique(data, return_counts=True)
  label_data = []
  train_data = []

  # 각 레이블에 대한 반복
  for i in range(0, a.shape[0]):
    label_data.append(data[data == a[i]])

    #각 레이블 별 데이터 세트 개수
    data_size = len(label_data[i])
    
    validation_data_count = int(data_size / k)

    label_data_np = np.array(label_data[i])

    for l in range(0, k):
      train_data_row = []
      for j in range(0, int(validation_data_count)):
        rand_int = np.random.randint(0,data_size - j)
        train_data_row.append(label_data_np[rand_int])
        label_data_np = np.delete(label_data_np, rand_int)      

      data_size = data_size - validation_data_count
      train_data.append(train_data_row)
  
  vaildation_data = []

  for i in range(0, k):
    validation = []
    validation.append(train_data[i])
    validation.append(train_data[i+k])
    vaildation_data.append(validation)

  return validation_data