На этом шаге мы напомним кое-что из изученного материала.
Принцип постепенного раскрытия сложности - доступ к спектру рабочих процессов шаг за шагом, от предельно простых до бесконечно гибких - также применим для обучения моделей. Библиотека Keras предлагает разные подходы к обучению моделей, от несложных, таких как вызов fit() с обучающими данными, до продвинутых, связанных с разработкой новых алгоритмов обучения с нуля.
Вы уже знакомы с последовательностью вызовов compile(), fit(), evaluate(), predict(). Чтобы вспомнить ее, взгляните на следующий пример.
from tensorflow import keras from tensorflow.keras import layers from tensorflow.keras.datasets import mnist # Создание модели (вынесено в отдельную функцию, # чтобы иметь возможность использовать ее позже) def get_mnist_model(): inputs = keras.Input(shape=(28 * 28,)) features = layers.Dense(512, activation="relu")(inputs) features = layers.Dropout(0.5)(features) outputs = layers.Dense(10, activation="softmax")(features) model = keras.Model(inputs, outputs) return model # Загрузка данных для обучения и проверки (images, labels), (test_images, test_labels) = mnist.load_data() images = images.reshape((60000, 28 * 28)).astype("float32") / 255 test_images = test_images.reshape((10000, 28 * 28)).astype("float32") / 255 train_images, val_images = images[10000:], images[:10000] train_labels, val_labels = labels[10000:], labels[:10000] model = get_mnist_model() # Компиляция модели с указанными оптимизатором, функцией потерь, # которая должна минимизироваться, и метриками для оценки model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]) # Вызов fit() для обучения модели; при необходимости можно передать # проверочные данные для мониторинга качества модели на данных, # отличных от обучающих model.fit(train_images, train_labels, epochs=3, validation_data=(val_images, val_labels)) # Вызов evaluate() для вычисления потерь и метрик на # контрольных данных test_metrics = model.evaluate(test_images, test_labels) # Вызов predict() с контрольными данными для вычисления вероятностей # принадлежности к тому или иному классу predictions = model.predict(test_images)
Данный простой процесс можно скорректировать:
Посмотрим, как это сделать.
На следующем шаге мы рассмотрим использование собственных метрик.