На этом шаге мы рассмотрим использование этих обратных вызовов.
Многие аспекты обучения модели нельзя предсказать заранее - например, количество эпох, обеспечивающее оптимальное значение потерь на проверочном наборе. В примерах, приводившихся до сих пор, использовалась стратегия обучения с достаточно большим количеством эпох. Таким способом достигался эффект переобучения: когда сначала выполнялся первый прогон, чтобы выяснить необходимое количество эпох обучения, а затем второй - новый - с этим количеством. Конечно, данная стратегия довольно расточительная. Гораздо лучше было бы остановить обучение, как только выяснится, что оценка потерь на проверочном наборе перестала улучшаться. Это можно реализовать с использованием обратного вызова EarlyStopping.
Обратный вызов EarlyStopping прерывает процесс обучения, если находящаяся под наблюдением целевая метрика не улучшалась на протяжении заданного количества эпох. Он позволит остановить обучение после наступления эффекта переобучения и тем самым избежать повторного обучения модели для меньшего количества эпох. Данный обратный вызов обычно используется в комбинации с обратным вызовом ModelCheckpoint, который может сохранять состояние модели в ходе обучения (и при необходимости сохранять только лучшую модель: версию, достигшую лучшего качества к концу эпохи):
# Обратные вызовы передаются в модель через параметр callbacks # метода fit() в виде списка. Вы можете передать любое # количество обратных вызовов callbacks_list = [ # Прерывает обучение, когда качество модели перестает улучшаться keras.callbacks.EarlyStopping( # Следит за изменением точности модели на проверочных данных monitor="val_accuracy", # Прерывает обучение, если точность не улучшается в течение двух эпох patience=2, ), # Сохраняет текущие веса после каждой эпохи keras.callbacks.ModelCheckpoint( # Путь к файлу модели filepath="checkpoint_path.keras", # Эти два аргумента требуют, чтобы файл модели не перезаписывался, # если значение val_loss не улучшилось, что позволяет # сохранить только лучшую модель monitor="val_loss", save_best_only=True, ) ] model = get_mnist_model() model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", # Мы следим за точностью, поэтому она должна # быть частью набора метрик модели metrics=["accuracy"]) # Обратите внимание: поскольку обратный вызов следит за потерями # и точностью на проверочных данных, мы должны передать # validation_data в вызов fit() model.fit(train_images, train_labels, epochs=10, callbacks=callbacks_list, validation_data=(val_images, val_labels))
Помните, что модель всегда можно сохранить вручную после обучения: нужно лишь вызвать метод model.save('путь_к_файлу'). Чтобы загрузить сохраненную модель, просто примените:
model = keras.models.load_model("checkpoint_path.keras")
На следующем шаге мы рассмотрим разработку своего обратного вызова.