Шаг 201.
Глубокое обучение на Python. ... . Разработка своего цикла обучения и оценки. Использование fit() с нестандартным циклом обучения

    На этом шаге мы рассмотрим пример совместного использования fit() и нестандартного цикла обучения.

    Ранее мы с нуля написали полный цикл обучения. Этот подход дает максимальную гибкость, но не только требует написать много кода, но и лишает множества удобных возможностей fit(), таких как обратные вызовы или встроенная поддержка распределенного обучения.

    А получится ли применить свой алгоритм обучения и сохранить всю мощь встроенной логики обучения Keras? На самом деле существует золотая середина между использованием fit() и реализацией своего цикла обучения: можно написать свою функцию шага обучения, а все остальные задачи переложить на фреймворк.

    Для этого достаточно переопределить метод train_step() класса Model, который вызывается функцией fit() для обработки каждого пакета данных, и использовать fit() как обычно, а функция будет запускать ваш алгоритм обучения.

    Вот простой пример:


Пример 7.26. Реализация своего шага обучения для использования с fit()
# Данный объект метрики будет использоваться для слежения за 
# средним значением потерь на пакетах в ходе обучения и оценки
loss_fn = keras.losses.SparseCategoricalCrossentropy()
loss_tracker = keras.metrics.Mean(name="loss")

class CustomModel(keras.Model): 
  
  # Мы переопределяем метод train_step
  def train_step(self, data): 
    inputs, targets = data 
    with tf.GradientTape() as tape: 
      # Здесь вместо model(inputs, training=True) используется 
      # self(inputs, training=True), потому что моделью является 
      # сам экземпляр класса
      predictions = self(inputs, training=True) 
      loss = loss_fn(targets, predictions) 
      
    gradients = tape.gradient(loss, model.trainable_weights) 
    optimizer.apply_gradients(zip(gradients, model.trainable_weights)) 
    
    # Обновить метрику потерь, в которой хранится среднее значение потерь
    loss_tracker.update_state(loss) 
    # Вернуть среднее значение потерь, получившееся к данному моменту, 
    # обратившись к экземпляру метрики loss_tracker
    return {"loss": loss_tracker.result()}
    
@property
# Список всех метрик, которые должны сбрасываться в 
# исходное состояние в начале каждой эпохи
def metrics(self):
  return [loss_tracker]

    Теперь можно создать экземпляр модели, скомпилировать ее (в данном случае мы передаем только оптимизатор, потому что потери определены вне модели) и обучить, используя fit() как обычно:


inputs = keras.Input(shape=(28 * 28,))
features = layers.Dense(512, activation="relu")(inputs)
features = layers.Dropout(0.5)(features)
outputs = layers.Dense(10, activation="softmax")(features)
model = CustomModel(inputs, outputs)

model.compile(optimizer=keras.optimizers.RMSprop())
model.fit(train_images, train_labels, epochs=3)

    Отметим несколько важных моментов:

    А что насчет метрик и функции потерь, которые настраиваются с помощью compile()? После вызова compile() вы получаете доступ к:

    То есть мы можем написать такой класс:


class CustomModel(keras.Model): 
  
  def train_step(self, data): 
    inputs, targets = data 
    with tf.GradientTape() as tape: 
      predictions = self(inputs, training=True) 
      # Вычислить величину потерь вызовом self.compiled_loss
      loss = self.compiled_loss(targets, predictions) 
      
    gradients = tape.gradient(loss, model.trainable_weights) 
    optimizer.apply_gradients(zip(gradients, model.trainable_weights)) 
    # Обновить метрики модели с помощью обертки self.compiled_metrics
    self.compiled_metrics.update_state(targets, predictions) 
    # Вернуть словарь, отображающий имена метрик в их текущие значения
    return {m.name: m.result() for m in self.metrics}

    Давайте опробуем его:


inputs = keras.Input(shape=(28 * 28,))
features = layers.Dense(512, activation="relu")(inputs)
features = layers.Dropout(0.5)(features)
outputs = layers.Dense(10, activation="softmax")(features)
model = CustomModel(inputs, outputs)

model.compile(optimizer=keras.optimizers.RMSprop(),
              loss=keras.losses.SparseCategoricalCrossentropy(),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])
model.fit(train_images, train_labels, epochs=3)

    В этих шагах было представлено много новой информации, зато теперь вы знаете практически все, что нужно, чтобы использовать Keras для создания почти любых моделей.

    На следующем шаге мы подведем итоги по изученному материалу.




Предыдущий шаг Содержание Следующий шаг