Official | Keras Distributed Training Tutorial

Overview
The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to allow users to enable distributed training using existing models and training code, with minimal changes.
This tutorial uses tf.distribute.MirroredStrategy, which does in-graph replication with synchronous training on many GPUs on one machine. Essentially, it copies all of the model's variables to each processor. Then, it uses all-reduce to combine the gradients from all processors, and applies the combined value to all copies of the model.
MirroredStrategy is one of several distribution strategies available in TensorFlow core. You can read about more strategies in the distribution strategy guide.
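To make the replicate-and-all-reduce idea concrete, here is a small sketch that is not part of the original tutorial. It assumes a TensorFlow release where Strategy.run is available (on the 2.1 nightly pinned later in this tutorial, the equivalent method was named experimental_run_v2):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  # Variables created inside the scope are mirrored: one copy per device.
  v = tf.Variable(0.0)

@tf.function
def step():
  def replica_fn():
    # Each replica reads its own copy of the mirrored variable and
    # computes a local value (a stand-in for a local gradient)...
    return v + 1.0
  per_replica = strategy.run(replica_fn)
  # ...and all-reduce combines the per-replica values into a single
  # result that every replica sees.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)

print(step())

With one replica this prints 1.0; with N replicas, each contributes its local value and the SUM all-reduce returns N.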
Keras API
This example uses the tf.keras API to build the model and training loop. For custom training loops, see the tf.distribute.Strategy with training loops tutorial.
Import dependencies
from __future__ import absolute_import, division, print_function, unicode_literals

# Import TensorFlow and TensorFlow Datasets
try:
  !pip install -q tf-nightly
except Exception:
  pass
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

import os
print(tf.__version__)

2.1.0-dev20191004

Download the dataset
Download the MNIST dataset and load it from TensorFlow Datasets. This returns a dataset in tf.data format.
Setting with_info to True includes the metadata for the entire dataset, which is being saved here to info. Among other things, this metadata object includes the number of train and test examples.
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist (11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/1.0.0...
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/1.0.0. Subsequent calls will reuse this data.
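Because with_info=True was passed, the info object can be inspected directly. A quick illustrative check, not part of the original tutorial (the values shown are for MNIST):

print(info.splits['train'].num_examples)  # 60000
print(info.splits['test'].num_examples)   # 10000
print(info.features['image'].shape)       # (28, 28, 1)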
Define distribution strategy
Create a MirroredStrategy object. This will handle distribution, and provides a context manager (tf.distribute.MirroredStrategy.scope) to build your model inside.
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 1

Setup input pipeline
When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory, and tune the learning rate accordingly.
# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
Pixel values, which are 0-255, have to be normalized to the 0-1 range. Define this scale in a function.
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255
  return image, label
Apply this function to the training and test data, shuffle the training data, and batch it for training. Notice we are also keeping an in-memory cache of the training data to improve performance.
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
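As a quick sanity check (again, not in the original tutorial), you can pull a single batch from the pipeline and confirm the shapes; with one replica the global batch size is 64:

for images, labels in train_dataset.take(1):
  print(images.shape, labels.shape)  # (64, 28, 28, 1) (64,)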
Create the model
Create and compile the Keras model in the context of strategy.scope.
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])

  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Define the callbacks
The callbacks used here are:
TensorBoard: This callback writes a log for TensorBoard which allows you to visualize the graphs.
Model Checkpoint: This callback saves the model after every epoch.
Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change after every epoch/batch.
For illustrative purposes, add a print callback to display the learning rate in the notebook.
# Define the checkpoint directory to store the checkpoints
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      model.optimizer.lr.numpy()))
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
Train and evaluate
Now, train the model in the usual way, calling fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.
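A minimal sketch of this step (the epoch count of 12 follows the original tutorial):

model.fit(train_dataset, epochs=12, callbacks=callbacks)

# Evaluate on the held-out test data; evaluate() returns the loss and
# any compiled metrics (here, accuracy).
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))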