На этом шаге мы рассмотрим особенности использования этого декоратора.
Возможно, вы заметили, что реализованные вами циклы работают значительно медленнее, чем встроенные функции fit() и evaluate(), несмотря на то что фактически реализуют ту же логику. Причина в том, что по умолчанию код TensorFlow выполняется построчно и немедленно, подобно коду NumPy или обычному коду Python. Немедленное выполнение упрощает отладку, но с точки зрения производительности далеко не оптимально.
Более полезным для производительности будет скомпилировать код TensorFlow в граф вычислений, который можно оптимизировать глобально, что не получится сделать при построчной интерпретации кода. Синтаксис применения такой оптимизации прост: добавьте @tf.function к любой функции, которую нужно скомпилировать перед выполнением, как показано в следующем примере.
@tf.function # Единственная новая строка def test_step(inputs, targets): # Обратите внимание на аргумент training=False predictions = model(inputs, training=False) loss = loss_fn(targets, predictions) logs = {} for metric in metrics: metric.update_state(targets, predictions) logs["val_" + metric.name] = metric.result() loss_tracking_metric.update_state(loss) logs["val_loss"] = loss_tracking_metric.result() return logs val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels)) val_dataset = val_dataset.batch(32) reset_metrics() for inputs_batch, targets_batch in val_dataset: logs = test_step(inputs_batch, targets_batch) print("Evaluation results:") for key, value in logs.items(): print(f"...{key}: {value:.4f}")
В Colab время выполнения цикла оценки уменьшилось с 1,8 до 0,8 секунды. Теперь он выполняется намного быстрее!
Помните, что в процессе отладки код лучше запускать без декоратора @tf.function. Так проще находить и устранять ошибки. Закончив отладку, код можно ускорить, добавив декоратор @tf.function перед функциями, реализующими шаг обучения и шаг оценки, или любыми другими функциями, для которых важна высокая производительность.
На следующем шаге мы рассмотрим использование fit() с нестандартным циклом обучения.