본문 바로가기

SeSAC 금융데이터 분석가/머신러닝

10/11 화

728x90

1. 머신러닝 작업 순서 및 학습 방법 별 대표적인 알고리즘들

 

2. GridSearchCV

교차 검증과 최적 하이퍼 파라미터 튜닝을 한 번에

from sklearn.model_selection import GridSearchCV

X_train, X_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target, test_size = 0.2, random_state = 121)
dtree = DecisionTreeClassifier()

# max_depth = 결정 트리의 최대 깊이, min_samples_splits =  자식 규칙 노드를 분할해 만들기 위한 최소한의 샘플 데이터 개수
parameters = {'max_depth':[1,2,3], 'min_samples_split':[2, 3]}

grid_dtree = GridSearchCV(dtree, param_grid = parameters, cv = 3, refit = True) # refit = True가 default. 가장 좋은 파라미터 설정으로 재학습 시킴
grid_dtree.fit(X_train, y_train)

scores_df = pd.DataFrame(grid_dtree.cv_results_)
scores_df[['params', 'mean_test_score', 'rank_test_score', \
          'split0_test_score', 'split1_test_score', 'split2_test_score']]
          
print('GridSearchCV 최적 파라미터:', grid_dtree.best_params_)
print('GridSearchCV 최고 정확도: {0:.4f}'.format(grid_dtree.best_score_))
# GridSearchCV 최적 파라미터: {'max_depth': 3, 'min_samples_split': 2}
# GridSearchCV 최고 정확도: 0.9750

estimator = grid_dtree.best_estimator_
pred = estimator.predict(X_test)
print('테스트 데이터 세트 정확도: {0:.4f}'.format(accuracy_score(y_test, pred)))
# 테스트 데이터 세트 정확도: 0.9667

728x90

'SeSAC 금융데이터 분석가 > 머신러닝' 카테고리의 다른 글

10/17 월  (0) 2022.10.28
10/14 금  (0) 2022.10.14
10/12 수  (0) 2022.10.12