На этом шаге мы приведем реализацию этого этапа.
Этап обучения - самая сложная часть процесса. Нам требуется скорректировать веса модели после обучения на одном пакете данных. Для этого нужно сделать следующее.
Для вычисления градиента используем объект GradientTape из библиотеки TensorFlow, который был представлен на 54 шаге:
def one_training_step(model, images_batch, labels_batch): # Выполнить "прямой проход" (вычислить прогноз модели в контексте GradientTape) with tf.GradientTape() as tape: predictions = model(images_batch) per_sample_losses = tf.keras.losses.sparse_categorical_crossentropy( labels_batch, predictions) average_loss = tf.reduce_mean(per_sample_losses) # Вычислить градиент потерь с учетом весов. Результат gradients — это список, # каждый элемент которого соответствует весу в списке model.weights gradients = tape.gradient(average_loss, model.weights) # Скорректировать веса с учетом градиентов (эту функцию мы определим ниже) update_weights(gradients, model.weights) return average_loss
Как вы уже знаете, цель шага "корректировки весов" (представленного в предыдущем листинге функцией update_weights) состоит в том, чтобы "чуть-чуть" скорректировать веса в направлении, которое уменьшит потери в этом пакете. Величина корректировки определяется "скоростью обучения", обычно небольшой. Самый простой способ реализовать функцию update_weights - вычесть gradient * learning_rate из каждого веса:
learning_rate = 1e-3 def update_weights(gradients, weights): for g, w in zip(gradients, weights): w.assign_sub(g * learning_rate) # assign_sub - это эквивалент оператора -= # для переменных TensorFlow
На практике вам редко придется задумываться о реализации корректировки вручную, потому что обычно для этого используется экземпляр оптимизатора из Keras, например:
from tensorflow.keras import optimizers optimizer = optimizers.SGD(learning_rate=1e-3) def update_weights(gradients, weights): optimizer.apply_gradients(zip(gradients, weights))
Теперь, покончив с этапом обучения на одном пакете, перейдем к реализации целой эпохи обучения.
На следующем шаге мы рассмотрим полный цикл обучения.