На этом шаге мы реализуем этот цикл.
Давайте теперь объединим прямой проход, обратное распространение ошибки и отслеживание метрик в шаговую функцию (training step function), подобную fit(), которая принимает пакет данных и цели и возвращает сведения, которые будут отображаться в индикаторе выполнения fit().
model = get_mnist_model() # Подготовка функции потерь loss_fn = keras.losses.SparseCategoricalCrossentropy() # Подготовка оптимизатора optimizer = keras.optimizers.RMSprop() # Подготовка списка метрик для отслеживания metrics = [keras.metrics.SparseCategoricalAccuracy()] # Подготовка метрики Mean для слежения за средним значением потерь loss_tracking_metric = keras.metrics.Mean() # Выполнение прямого прохода. Обратите внимание на аргумент training=True def train_step(inputs, targets): with tf.GradientTape() as tape: predictions = model(inputs, training=True) loss = loss_fn(targets, predictions) # Обратное распространение ошибки. Обратите внимание, что здесь # используется model.trainable_weights gradients = tape.gradient(loss, model.trainable_weights) optimizer.apply_gradients(zip(gradients, model.trainable_weights)) # Слежение за метриками logs = {} for metric in metrics: metric.update_state(targets, predictions) logs[metric.name] = metric.result() # Слежение за средним значением потерь loss_tracking_metric.update_state(loss) logs["loss"] = loss_tracking_metric.result() # Возврат текущих значений метрик и потерь return logs
Важно не забыть сбросить состояние метрик в начале каждой эпохи и перед началом этапа оценки. Вот вспомогательная функция, которая сделает это.
def reset_metrics(): for metric in metrics: metric.reset_state() loss_tracking_metric.reset_state()
Теперь можно закончить реализацию цикла обучения. Обратите внимание, что здесь используется объект tf.data.Dataset, превращающий массив NumPy с данными в итератор, который выполняет итерации по данным пакетами размером 32.
import tensorflow as tf training_dataset = tf.data.Dataset.from_tensor_slices( (train_images, train_labels)) training_dataset = training_dataset.batch(32) epochs = 3 for epoch in range(epochs): reset_metrics() for inputs_batch, targets_batch in training_dataset: logs = train_step(inputs_batch, targets_batch) print(f"Results at the end of epoch {epoch}") for key, value in logs.items(): print(f"...{key}: {value:.4f}")
Ниже приводится цикл оценки: простой цикл for, многократно вызывающий функцию test_step(), которая обрабатывает один пакет данных. Функция test_ step() - лишь подмножество логики train_step(). В ней отсутствует код, обновляющий веса модели, то есть все, что связано с GradientTape и оптимизатором.
def test_step(inputs, targets): # Обратите внимание на аргумент training=False predictions = model(inputs, training=False) loss = loss_fn(targets, predictions) logs = {} for metric in metrics: metric.update_state(targets, predictions) logs["val_" + metric.name] = metric.result() loss_tracking_metric.update_state(loss) logs["val_loss"] = loss_tracking_metric.result() return logs val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels)) val_dataset = val_dataset.batch(32) reset_metrics() for inputs_batch, targets_batch in val_dataset: logs = test_step(inputs_batch, targets_batch) print("Evaluation results:") for key, value in logs.items(): print(f"...{key}: {value:.4f}")
Поздравляем, вы только что реализовали свои полноценные версии функций fit() и evaluate()! Ну или почти полноценные: на самом деле fit() и evaluate() реализуют множество других возможностей, включая крупномасштабные распределенные вычисления, которые требуют немного больше работы. Они также включают некоторые важные оптимизации производительности.
Давайте рассмотрим одну из таких оптимизаций: компиляцию функции TensorFlow.
На следующем шаге мы рассмотрим ускорение вычислений с помощью tf.function.