На этом шаге мы рассмотрим создание такого вызова.
Если в ходе обучения потребуется выполнить какие-то особые действия, не предусмотренные ни одним из встроенных обратных вызовов, можно написать свой обратный вызов. Обратные вызовы реализуются путем создания подкласса класса keras.callbacks.Callback. Вы можете реализовать любые из следующих методов с говорящими именами, которые будут вызываться в соответствующие моменты в ходе обучения:
on_epoch_begin # Вызывается в начале каждой эпохи on_epoch_end # Вызывается в конце каждой эпохи on_batch_begin # Вызывается перед началом обработки каждого пакета on_batch_end # Вызывается сразу после окончания обработки каждого пакета on_train_begin # Вызывается в начале обучения on_train_end # Вызывается в конце обучения
Все эти методы вызываются с аргументом logs - словарем, содержащим информацию о предыдущем пакете, эпохе или цикле обучения (метрики обучения и проверки и т. д.). Методам on_epoch_* и on_batch_* также передается индекс эпохи или пакета в первом аргументе (целое число).
Вот простой пример обратного вызова, который сохраняет список значений потерь для каждого пакета во время обучения и график изменения потерь в конце каждой эпохи.
from matplotlib import pyplot as plt class LossHistory(keras.callbacks.Callback): def on_train_begin(self, logs): self.per_batch_losses = [] def on_batch_end(self, batch, logs): self.per_batch_losses.append(logs.get("loss")) def on_epoch_end(self, epoch, logs): plt.clf() plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses, label="Потери на обучающих данных для каждого пакета") plt.xlabel(f"Пакеты (эпоха {epoch})") plt.ylabel("Потери") plt.legend() plt.savefig(f"plot_at_epoch_{epoch}") self.per_batch_losses = []
Испытаем его:
model = get_mnist_model() model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]) model.fit(train_images, train_labels, epochs=10, callbacks=[LossHistory()], validation_data=(val_images, val_labels))
Сохраненный график можно увидеть на рисунке 1.
Рис.1. График, созданный нашим собственным обратным вызовом
На следующем шаге мы рассмотрим мониторинг и визуализацию с помощью TensorBoard.