На этом шаге мы рассмотрим создание собственных метрик.
Метрики являются ключом к оценке качества модели - в частности, они позволяют измерить разницу качества модели на обучающих и контрольных данных. Метрики, обычно используемые для классификации и регрессии, уже включены в стандартный модуль keras.metrics, и в большинстве случаев вы будете брать именно их. Но иногда, особенно при решении необычных задач, вам может понадобиться умение писать свои метрики. Это просто!
Метрики в Keras являются подклассом класса keras.metrics.Metric. Подобно слоям, метрики имеют внутреннее состояние, хранящееся в переменных TensorFlow. Но, в отличие от слоев, эти переменные не обновляются на этапе обратного распространения, поэтому вам придется написать свою логику их обновления в методе update_state().
Вот пример простой метрики, измеряющей среднеквадратичную ошибку (Root Mean Squared Error, RMSE) .
import tensorflow as tf # Подкласс класса Metric class RootMeanSquaredError(keras.metrics.Metric): def __init__(self, name="rmse", **kwargs): # Определение в конструкторе переменных для хранения состояния. # По аналогии со слоями есть возможность использовать метод add_weight() super().__init__(name=name, **kwargs) self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros") self.total_samples = self.add_weight( name="total_samples", initializer="zeros", dtype="int32") # update_state() реализует логику обновления состояния. # Аргумент y_true - это цели (или метки) для одного пакета, # а y_pred представляет соответствующие прогнозы модели. # Аргумент sample_weight можно игнорировать - здесь он не используется def update_state(self, y_true, y_pred, sample_weight=None): y_true = tf.one_hot(y_true, depth=tf.shape(y_pred)[1]) mse = tf.reduce_sum(tf.square(y_true - y_pred)) self.mse_sum.assign_add(mse) num_samples = tf.shape(y_pred)[0] self.total_samples.assign_add(num_samples)
Для получения текущего значения метрики нужно реализовать метод result():
def result(self): return tf.sqrt(self.mse_sum / tf.cast(self.total_samples, tf.float32))
Также следует предоставить возможность сбросить метрику в исходное состояние без необходимости создавать ее повторно. Это позволит использовать одни и те же объекты метрик в разные эпохи обучения или на этапах и обучения, и оценки. Для этого достаточно реализовать метод reset_state():
def reset_state(self): self.mse_sum.assign(0.) self.total_samples.assign(0)
Нестандартные метрики используются точно так же, как встроенные. Давайте протестируем нашу метрику:
model = get_mnist_model() model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=[accuracy", RootMeanSquaredError()]) model.fit(train_images, train_labels, epochs=3, validation_data=(val_images, val_labels)) test_metrics = model.evaluate(test_images, test_labels)
Теперь индикатор выполнения fit() будет отображать среднеквадратичную ошибку модели.
На следующем шаге мы рассмотрим использование обратных вызовов.