На этом шаге мы рассмотрим, какие средства имеются в библиотеке для организации такого поиска.
Поскольку решетчатый поиск с перекрестной проверкой является весьма распространенным методом настройки параметров, библиотека scikit-learn предлагает класс GridSearchCV, в котором решетчатый поиск реализован в виде модели. Чтобы воспользоваться классом GridSearchCV, сначала необходимо указать искомые параметры с помощью словаря. GridSearchCV построит все необходимые модели. Ключами словаря являются имена настраиваемых параметров (в данном случае С и gamma), а значениями - тестируемые настройки параметров. Перебор значений 0.001, 0.01, 0.1, 1, 10 и 100 для C и gamma требует словаря следующего вида:
[In 23]: param_grid = {'C': [0.001, 0.001, 0.1, 1, 10, 100], 'gamma': [0.001, 0.001, 0.1, 1, 10, 100]} print("Сетка параметров:\n{}".format(param_grid)) Сетка параметров: {'C': [0.001, 0.01, 0.1, 1, 10, 100], 'gamma': [0.001, 0.01, 0.1, 1, 10, 100]}
Теперь мы можем создать экземпляр класса GridSearchCV, передав модель (SVC), сетку искомых параметров (param_grid), а также стратегию перекрестной проверки, которую мы хотим использовать (допустим, пятиблочную стратифицированную перекрестную проверку):
[In 24]: from sklearn.model_selection import GridSearchCV from sklearn.svm import SVC grid_search = GridSearchCV(SVC(), param_grid, cv=5)
Вместо разбиения на обучающий и проверочный набор, использованного нами ранее, GridSearchCV запустит перекрестную проверку. Однако нам по-прежнему нужно разделить данные на обучающий и тестовый наборы, чтобы избежать переобучения параметров:
[In 25]: X_train, X_test, y_train, y_test = train_test_split( iris.data, iris.target, random_state=0)
Созданный нами объект grid_search аналогичен классификатору, мы можем вызвать стандартные методы fit(), predict() и score() от его имени.
Однако, когда мы вызываем fit(), он запускает перекрестную проверку для каждой комбинации параметров, указанных в param_grid:
[In 26]:
grid_search.fit(X_train, y_train)
Процесс подгонки объекта GridSearchCV включает в себя не только поиск лучших параметров, но и автоматическое построение новой модели на всем обучающем наборе данных. Для ее построения используются параметры, которые дают наилучшее значение правильности перекрестной проверки. Поэтому процесс, запускаемый вызовом метода fit(), эквивалентен программному коду в [In 20], который мы видели на 144 шаге. Класс GridSearchCV предлагает очень удобный интерфейс для работы с моделью, используя методы predict() и score(). Чтобы оценить обобщающую способность найденных наилучших параметров, мы можем вызвать метод score():
[In 27]: print("Правильность на тестовом наборе: {:.2f}".format(grid_search.score(X_test, y_test))) Правильность на тестовом наборе: 0.97
Выбрав параметры с помощью перекрестной проверки, мы фактически нашли модель, которая достигает правильности 97% на тестовом наборе. Главный момент здесь в том, что мы не использовали тестовый набор для отбора параметров. Найденная комбинация параметров сохраняется в атрибуте best_params_, а наилучшее значение правильности перекрестной проверки (значение правильности, усредненное по всем разбиениям для данной комбинации параметров) - в атрибуте best_score_.
[In 28]: print("Наилучшие значения параметров: {}".format(grid_search.best_params_)) print("Наилучшее значение кросс-валидац. правильности: {:.2f}".format(grid_search.best_score_)) Наилучшие значения параметров: {'C': 10, 'gamma': 0.1} Наилучшее значение кросс-валидац. правильности: 0.97
В ряде случаев вам необходимо будет ознакомиться с полученной моделью, например, взглянуть на коэффициенты или важности признаков. Посмотреть наилучшую модель, построенную на всем обучающем наборе, вы можете с помощью атрибута best_estimator_:
[In 29]: print("Наилучшая модель:\n{}".format(grid_search.best_estimator_)) SVC(C=100, cache_size=200, class_weight=None, coef0=0.0, decision_function_shape=None, degree=3, gamma=0.01, kernel='rbf', max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001, verbose=False)
Поскольку grid_search уже сам по себе включает методы predict() и score(), использование best_estimator_ для получения прогнозов и оценки качества модели не требуется.
На следующем шаге мы проведем анализ результатов перекрестной проверки.