Шаг 145.
Введение в машинное обучение с использованием Python. Оценка и улучшение качества модели. ... . Решетчатый поиск с перекрестной проверкой (окончание)

    На этом шаге мы рассмотрим, какие средства имеются в библиотеке для организации такого поиска.

    Поскольку решетчатый поиск с перекрестной проверкой является весьма распространенным методом настройки параметров, библиотека scikit-learn предлагает класс GridSearchCV, в котором решетчатый поиск реализован в виде модели. Чтобы воспользоваться классом GridSearchCV, сначала необходимо указать искомые параметры с помощью словаря. GridSearchCV построит все необходимые модели. Ключами словаря являются имена настраиваемых параметров (в данном случае С и gamma), а значениями - тестируемые настройки параметров. Перебор значений 0.001, 0.01, 0.1, 1, 10 и 100 для C и gamma требует словаря следующего вида:

[In 23]:
param_grid = {'C': [0.001, 0.001, 0.1, 1, 10, 100],
              'gamma': [0.001, 0.001, 0.1, 1, 10, 100]}
print("Сетка параметров:\n{}".format(param_grid))

Сетка параметров:
{'C': [0.001, 0.01, 0.1, 1, 10, 100], 'gamma': [0.001, 0.01, 0.1, 1, 10, 100]}

    Теперь мы можем создать экземпляр класса GridSearchCV, передав модель (SVC), сетку искомых параметров (param_grid), а также стратегию перекрестной проверки, которую мы хотим использовать (допустим, пятиблочную стратифицированную перекрестную проверку):

[In 24]:
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
grid_search = GridSearchCV(SVC(), param_grid, cv=5)

    Вместо разбиения на обучающий и проверочный набор, использованного нами ранее, GridSearchCV запустит перекрестную проверку. Однако нам по-прежнему нужно разделить данные на обучающий и тестовый наборы, чтобы избежать переобучения параметров:

[In 25]:
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0)

    Созданный нами объект grid_search аналогичен классификатору, мы можем вызвать стандартные методы fit(), predict() и score() от его имени.


Модель scikit-learn, которая создается с помощью другой модели называется метамоделью (meta-estimator). GridSearchCV является наиболее часто используемой метамоделью, но об этом мы поговорим позже.

    Однако, когда мы вызываем fit(), он запускает перекрестную проверку для каждой комбинации параметров, указанных в param_grid:

[In 26]:
grid_search.fit(X_train, y_train)

    Процесс подгонки объекта GridSearchCV включает в себя не только поиск лучших параметров, но и автоматическое построение новой модели на всем обучающем наборе данных. Для ее построения используются параметры, которые дают наилучшее значение правильности перекрестной проверки. Поэтому процесс, запускаемый вызовом метода fit(), эквивалентен программному коду в [In 20], который мы видели на 144 шаге. Класс GridSearchCV предлагает очень удобный интерфейс для работы с моделью, используя методы predict() и score(). Чтобы оценить обобщающую способность найденных наилучших параметров, мы можем вызвать метод score():

[In 27]:
print("Правильность на тестовом наборе: {:.2f}".format(grid_search.score(X_test, y_test)))

Правильность на тестовом наборе: 0.97

    Выбрав параметры с помощью перекрестной проверки, мы фактически нашли модель, которая достигает правильности 97% на тестовом наборе. Главный момент здесь в том, что мы не использовали тестовый набор для отбора параметров. Найденная комбинация параметров сохраняется в атрибуте best_params_, а наилучшее значение правильности перекрестной проверки (значение правильности, усредненное по всем разбиениям для данной комбинации параметров) - в атрибуте best_score_.

[In 28]:
print("Наилучшие значения параметров: {}".format(grid_search.best_params_))
print("Наилучшее значение кросс-валидац. правильности: {:.2f}".format(grid_search.best_score_))

Наилучшие значения параметров: {'C': 10, 'gamma': 0.1}
Наилучшее значение кросс-валидац. правильности: 0.97


Опять же, будьте осторожны, чтобы не перепутать best_score_ со значением обобщающей способности модели, которое вычисляется на тестовом наборе с помощью метода score(). Метод score() (оценивающий качество результатов, полученных с помощью метода predict()) использует модель, построенную на всем обучающем наборе данных. В атрибуте best_score_ записывается средняя правильность перекрестной проверки. Для ее вычисления используется модель, построенная на обучающем наборе перекрестной проверки.

    В ряде случаев вам необходимо будет ознакомиться с полученной моделью, например, взглянуть на коэффициенты или важности признаков. Посмотреть наилучшую модель, построенную на всем обучающем наборе, вы можете с помощью атрибута best_estimator_:

[In 29]:
print("Наилучшая модель:\n{}".format(grid_search.best_estimator_))

SVC(C=100, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma=0.01, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

    Поскольку grid_search уже сам по себе включает методы predict() и score(), использование best_estimator_ для получения прогнозов и оценки качества модели не требуется.

    На следующем шаге мы проведем анализ результатов перекрестной проверки.




Предыдущий шаг Содержание Следующий шаг