Data Ensemble: k-fold

1. K-Fold 교차 검증 개요

  • 목적: 데이터셋을 효율적으로 활용하여 모델의 성능을 안정적으로 평가하고, 과적합을 방지.
  • 과정:
    1. 전체 데이터셋을 K개의 동일한 크기의 폴드(Fold)로 분할.
    2. 각 폴드를 한 번씩 검증 세트로 사용하고, 나머지 K-1개의 폴드를 훈련 세트로 사용하여 K번 모델 학습 및 평가.
    3. K개의 평가 결과를 평균하여 최종 모델 성능을 도출.

      2. 훈련 세트 및 검증 세트 분할

      1. 데이터 분할: - 전체 데이터셋을 무작위로 섞은 후, K개의 폴드로 - 각 폴드는 가능한 한 균등한 크기와 대표성을 가지도록 분할
      2. 교차 검증 반복: - 반복 (i = 1 to K):
        • 검증 세트: i번째 폴드.
        • 훈련 세트: 나머지 K-1개의 폴드.
        • 모델을 훈련 세트로 학습.
        • 검증 세트로 모델 평가 및 성능 기록.

          3. 평가 지표 계산

      3. 평가 지표 선택: - 문제의 특성에 따라 적절한 평가 지표 선택 (예: 정확도, 정밀도, 재현율, F1 점수, AUC 등).
      4. K번의 평가 결과 집계: - 각 반복에서 계산된 평가 지표를 기록. - 예: Accuracy_i, Precision_i, …, for i = 1 to K.
      5. 평균 및 분산 계산: - 최종 성능은 K개의 평가 지표의 평균과 분산으로 표현. - 예: 평균 정확도 = (Accuracy_1 + Accuracy_2 + … + Accuracy_K) / K.

        4. 최종 테스트 세트에 대한 예측

        교차 검증을 통해 최적의 모델을 선정한 후, 최종 테스트 세트에 대한 예측을 수행할 때 앙상블(Ensemble) 방법을 사용

  • 주요 앙상블 방법으로는
    • 소프트 앙상블(Soft Ensemble)
    • 하드 앙상블(Hard Ensemble)

      4.1 소프트 앙상블 (Soft Ensemble)

  • 개념: 여러 모델의 예측 확률을 평균내어 최종 예측을 도출.
  • 과정:
    1. K개의 폴드에서 각각 훈련된 K개의 모델을 사용.
    2. 각 모델이 테스트 세트에 대해 출력하는 예측 확률을 수집.
    3. K개의 확률을 평균내어 최종 확률을 계산.
    4. 평균 확률을 기준으로 클래스 레이블을 결정.
  • 장점:
    • 클래스 간 불확실성을 반영할 수 있음.
    • 더 부드러운 예측을 제공하여 성능 향상 가능.
  • 예시:
      import numpy as np
      # K개의 모델 예측 확률
      predictions = [model.predict_proba(X_test) for model in models]
      # 소프트 앙상블: 평균 확률 계산
      avg_prob = np.mean(predictions, axis=0)
      # 최종 예측 레이블
      final_pred = np.argmax(avg_prob, axis=1)
    

    4.2 하드 앙상블 (Hard Ensemble)

  • 개념: 여러 모델의 최종 예측 레이블을 투표하여 최종 예측을 도출.
  • 과정:
    1. K개의 폴드에서 각각 훈련된 K개의 모델을 사용.
    2. 각 모델이 테스트 세트에 대해 예측한 클래스 레이블을 수집.
    3. K개의 예측 레이블 중 다수결로 최종 예측 레이블을 결정.
  • 장점:
    • 구현이 간단하고 직관적.
    • 다수결을 통해 안정적인 예측 가능.
  • 예시:
      from scipy.stats import mode
      # K개의 모델 예측 레이블
      predictions = [model.predict(X_test) for model in models]
      # 하드 앙상블: 다수결 투표
      final_pred, _ = mode(predictions, axis=0)
      final_pred = final_pred.ravel()