Шаг 191.
Глубокое обучение на Python. Работа с Keras: глубокое погружение. Встроенные циклы обучения и оценки. Использование собственных метрик

    На этом шаге мы рассмотрим создание собственных метрик.

    Метрики являются ключом к оценке качества модели - в частности, они позволяют измерить разницу качества модели на обучающих и контрольных данных. Метрики, обычно используемые для классификации и регрессии, уже включены в стандартный модуль keras.metrics, и в большинстве случаев вы будете брать именно их. Но иногда, особенно при решении необычных задач, вам может понадобиться умение писать свои метрики. Это просто!

    Метрики в Keras являются подклассом класса keras.metrics.Metric. Подобно слоям, метрики имеют внутреннее состояние, хранящееся в переменных TensorFlow. Но, в отличие от слоев, эти переменные не обновляются на этапе обратного распространения, поэтому вам придется написать свою логику их обновления в методе update_state().

    Вот пример простой метрики, измеряющей среднеквадратичную ошибку (Root Mean Squared Error, RMSE) .


Пример 7.18. Реализация метрики путем создания класса, производного от класса Metric
import tensorflow as tf

# Подкласс класса Metric
class RootMeanSquaredError(keras.metrics.Metric): 
  
  def __init__(self, name="rmse", **kwargs): 
    # Определение в конструкторе переменных для хранения состояния. 
    # По аналогии со слоями есть возможность использовать метод add_weight()
    super().__init__(name=name, **kwargs) 
    self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros") 
    self.total_samples = self.add_weight( 
        name="total_samples", initializer="zeros", dtype="int32") 
    
  # update_state() реализует логику обновления состояния. 
  # Аргумент y_true - это цели (или метки) для одного пакета, 
  # а y_pred представляет соответствующие прогнозы модели. 
  # Аргумент sample_weight можно игнорировать - здесь он не используется
  def update_state(self, y_true, y_pred, sample_weight=None): 
    y_true = tf.one_hot(y_true, depth=tf.shape(y_pred)[1]) 
    mse = tf.reduce_sum(tf.square(y_true - y_pred)) 
    self.mse_sum.assign_add(mse) 
    num_samples = tf.shape(y_pred)[0] 
    self.total_samples.assign_add(num_samples)

    Для получения текущего значения метрики нужно реализовать метод result():


  def result(self):
    return tf.sqrt(self.mse_sum / tf.cast(self.total_samples, tf.float32))

    Также следует предоставить возможность сбросить метрику в исходное состояние без необходимости создавать ее повторно. Это позволит использовать одни и те же объекты метрик в разные эпохи обучения или на этапах и обучения, и оценки. Для этого достаточно реализовать метод reset_state():


  def reset_state(self):
    self.mse_sum.assign(0.)
    self.total_samples.assign(0)

    Нестандартные метрики используются точно так же, как встроенные. Давайте протестируем нашу метрику:


model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=[accuracy", RootMeanSquaredError()])
model.fit(train_images, train_labels,
          epochs=3,
          validation_data=(val_images, val_labels))
test_metrics = model.evaluate(test_images, test_labels)

    Теперь индикатор выполнения fit() будет отображать среднеквадратичную ошибку модели.

    На следующем шаге мы рассмотрим использование обратных вызовов.




Предыдущий шаг Содержание Следующий шаг