Шаг 194.
Глубокое обучение на Python. Работа с Keras: глубокое погружение. Встроенные циклы обучения и оценки. Разработка своего обратного вызова

    На этом шаге мы рассмотрим создание такого вызова.

    Если в ходе обучения потребуется выполнить какие-то особые действия, не предусмотренные ни одним из встроенных обратных вызовов, можно написать свой обратный вызов. Обратные вызовы реализуются путем создания подкласса класса keras.callbacks.Callback. Вы можете реализовать любые из следующих методов с говорящими именами, которые будут вызываться в соответствующие моменты в ходе обучения:

on_epoch_begin  # Вызывается в начале каждой эпохи
on_epoch_end    # Вызывается в конце каждой эпохи
on_batch_begin  # Вызывается перед началом обработки каждого пакета 
on_batch_end    # Вызывается сразу после окончания обработки каждого пакета 
on_train_begin  # Вызывается в начале обучения 
on_train_end    # Вызывается в конце обучения

    Все эти методы вызываются с аргументом logs - словарем, содержащим информацию о предыдущем пакете, эпохе или цикле обучения (метрики обучения и проверки и т. д.). Методам on_epoch_* и on_batch_* также передается индекс эпохи или пакета в первом аргументе (целое число).

    Вот простой пример обратного вызова, который сохраняет список значений потерь для каждого пакета во время обучения и график изменения потерь в конце каждой эпохи.


Пример 7.20. Создание своего обратного вызова наследованием класса Callback
from matplotlib import pyplot as plt

class LossHistory(keras.callbacks.Callback):

  def on_train_begin(self, logs):
    self.per_batch_losses = []

  def on_batch_end(self, batch, logs):
    self.per_batch_losses.append(logs.get("loss"))

  def on_epoch_end(self, epoch, logs):
    plt.clf()
    plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
             label="Потери на обучающих данных для каждого пакета")
    plt.xlabel(f"Пакеты (эпоха {epoch})")
    plt.ylabel("Потери")
    plt.legend()
    plt.savefig(f"plot_at_epoch_{epoch}")
    self.per_batch_losses = []

    Испытаем его:


model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(train_images, train_labels,
          epochs=10,
          callbacks=[LossHistory()],
          validation_data=(val_images, val_labels))

Блокнот с этим примером можно взять здесь.

    Сохраненный график можно увидеть на рисунке 1.


Рис.1. График, созданный нашим собственным обратным вызовом

    На следующем шаге мы рассмотрим мониторинг и визуализацию с помощью TensorBoard.




Предыдущий шаг Содержание Следующий шаг