Шаг 199.
Глубокое обучение на Python. Работа с Keras: глубокое погружение. Разработка своего цикла обучения и оценки. Полный цикл обучения и оценки

    На этом шаге мы реализуем этот цикл.

    Давайте теперь объединим прямой проход, обратное распространение ошибки и отслеживание метрик в шаговую функцию (training step function), подобную fit(), которая принимает пакет данных и цели и возвращает сведения, которые будут отображаться в индикаторе выполнения fit().


Пример 7.21. Разработка своего цикла обучения: шаговая функция
model = get_mnist_model()

# Подготовка функции потерь
loss_fn = keras.losses.SparseCategoricalCrossentropy()
# Подготовка оптимизатора
optimizer = keras.optimizers.RMSprop()
# Подготовка списка метрик для отслеживания
metrics = [keras.metrics.SparseCategoricalAccuracy()]
# Подготовка метрики Mean для слежения за средним значением потерь
loss_tracking_metric = keras.metrics.Mean()

# Выполнение прямого прохода. Обратите внимание на аргумент training=True
def train_step(inputs, targets): 
  with tf.GradientTape() as tape: 
    predictions = model(inputs, training=True) 
    loss = loss_fn(targets, predictions) 
  # Обратное распространение ошибки. Обратите внимание, что здесь 
  # используется model.trainable_weights
  gradients = tape.gradient(loss, model.trainable_weights) 
  optimizer.apply_gradients(zip(gradients, model.trainable_weights)) 
  
  # Слежение за метриками
  logs = {} 
  for metric in metrics: 
    metric.update_state(targets, predictions) 
    logs[metric.name] = metric.result() 
    
  # Слежение за средним значением потерь
  loss_tracking_metric.update_state(loss) 
  logs["loss"] = loss_tracking_metric.result() 
  # Возврат текущих значений метрик и потерь
  return logs

    Важно не забыть сбросить состояние метрик в начале каждой эпохи и перед началом этапа оценки. Вот вспомогательная функция, которая сделает это.


Пример 7.22. Разработка своего цикла обучения: сброс метрик
def reset_metrics():
  for metric in metrics:
    metric.reset_state()
  loss_tracking_metric.reset_state()

    Теперь можно закончить реализацию цикла обучения. Обратите внимание, что здесь используется объект tf.data.Dataset, превращающий массив NumPy с данными в итератор, который выполняет итерации по данным пакетами размером 32.


Пример 7.23. Разработка своего цикла обучения: сам цикл
import tensorflow as tf

training_dataset = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels))
training_dataset = training_dataset.batch(32)
epochs = 3
for epoch in range(epochs):
  reset_metrics()
  for inputs_batch, targets_batch in training_dataset:
    logs = train_step(inputs_batch, targets_batch)
  print(f"Results at the end of epoch {epoch}")
  for key, value in logs.items():
    print(f"...{key}: {value:.4f}")

    Ниже приводится цикл оценки: простой цикл for, многократно вызывающий функцию test_step(), которая обрабатывает один пакет данных. Функция test_ step() - лишь подмножество логики train_step(). В ней отсутствует код, обновляющий веса модели, то есть все, что связано с GradientTape и оптимизатором.


Пример 7.24. Разработка своего цикла обучения: цикл оценки
def test_step(inputs, targets): 
  # Обратите внимание на аргумент training=False
  predictions = model(inputs, training=False) 
  loss = loss_fn(targets, predictions) 
  
  logs = {} 
  for metric in metrics: 
    metric.update_state(targets, predictions) 
    logs["val_" + metric.name] = metric.result() 
  
  loss_tracking_metric.update_state(loss) 
  logs["val_loss"] = loss_tracking_metric.result() 
  return logs
  
val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
val_dataset = val_dataset.batch(32)
reset_metrics()
for inputs_batch, targets_batch in val_dataset: 
  logs = test_step(inputs_batch, targets_batch)
print("Evaluation results:")
for key, value in logs.items(): 
  print(f"...{key}: {value:.4f}")

Блокнот с этим примером можно взять здесь.

    Поздравляем, вы только что реализовали свои полноценные версии функций fit() и evaluate()! Ну или почти полноценные: на самом деле fit() и evaluate() реализуют множество других возможностей, включая крупномасштабные распределенные вычисления, которые требуют немного больше работы. Они также включают некоторые важные оптимизации производительности.

    Давайте рассмотрим одну из таких оптимизаций: компиляцию функции TensorFlow.

    На следующем шаге мы рассмотрим ускорение вычислений с помощью tf.function.




Предыдущий шаг Содержание Следующий шаг