
          Official | Keras Distributed Training Tutorial


          2021-08-20 22:58



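          Overview

          The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to let users enable distributed training with their existing models and training code, with minimal changes.


          This tutorial uses tf.distribute.MirroredStrategy, which performs in-graph replication with synchronous training on multiple GPUs on one machine. Essentially, it copies all of the model's variables to each processor. It then uses all-reduce to combine the gradients from all processors and applies the combined value to all copies of the model.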
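The replicate, all-reduce, apply cycle described above can be illustrated with a toy simulation. This is only a sketch in plain Python, not how TensorFlow implements it: the helper names (`local_gradient`, `all_reduce_mean`, `train_step`), the two-replica setup, and the choice of averaging the gradients are all assumptions made for illustration.

```python
# Toy simulation of synchronous data-parallel training with an all-reduce step.
# Pure Python, no TensorFlow: each "replica" holds its own copy of one weight.

def local_gradient(w, shard):
    """Gradient of the mean squared error of y = w * x over one replica's shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    """Combine one value per replica and hand every replica the same result."""
    mean = sum(values) / len(values)
    return [mean] * len(values)

def train_step(weights, shards, lr=0.01):
    """One synchronous step: local gradients, all-reduce, identical update."""
    grads = [local_gradient(w, s) for w, s in zip(weights, shards)]
    reduced = all_reduce_mean(grads)
    return [w - lr * g for w, g in zip(weights, reduced)]

# Data drawn from y = 3x, split across two hypothetical replicas.
shards = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0), (4.0, 12.0)]]
weights = [0.0, 0.0]  # identical initial copies (the "mirrored" variable)

for _ in range(200):
    weights = train_step(weights, shards)

print(weights)  # both copies stay equal and approach 3.0
```

Because every replica applies the same reduced gradient to the same starting value, the copies never drift apart, which is the property the mirrored variables rely on.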


          MirroredStrategy is one of several distribution strategies available in TensorFlow core. You can read about more strategies in the distribution strategy guide.


          Keras API

          This example uses the tf.keras API to build the model and training loop. For custom training loops, see the tf.distribute.Strategy with training loops tutorial.

          Import dependencies

          from __future__ import absolute_import, division, print_function, unicode_literals

          # Import TensorFlow and TensorFlow Datasets
          try:
            !pip install -q tf-nightly
          except Exception:
            pass

          import tensorflow_datasets as tfds
          import tensorflow as tf
          tfds.disable_progress_bar()

          import os

          print(tf.__version__)
          2.1.0-dev20191004

          Download the dataset

          Download the MNIST dataset and load it from TensorFlow Datasets. This returns a dataset in tf.data format.

          Setting with_info to True includes the metadata for the entire dataset, which is being saved here to info. Among other things, this metadata object includes the number of train and test examples.

          datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
          mnist_train, mnist_test = datasets['train'], datasets['test']
          Downloading and preparing dataset mnist (11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/1.0.0...
          /usr/lib/python3/dist-packages/urllib3/connectionpool.py:860: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
            InsecureRequestWarning)
          WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.6/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
          Instructions for updating:
          Use eager execution and:
          `tf.data.TFRecordDataset(path)`
          Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/1.0.0. Subsequent calls will reuse this data.

          Define distribution strategy


          Create a MirroredStrategy object. It handles distribution and provides a context manager (tf.distribute.MirroredStrategy.scope) in which to build your model.

          strategy = tf.distribute.MirroredStrategy()
          INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
          print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
          Number of devices: 1

          Setup input pipeline

          When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory, and tune the learning rate accordingly.

          # You can also do info.splits.total_num_examples to get the total
          # number of examples in the dataset.
          num_train_examples = info.splits['train'].num_examples
          num_test_examples = info.splits['test'].num_examples

          BUFFER_SIZE = 10000

          BATCH_SIZE_PER_REPLICA = 64
          BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
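To make the batch-size arithmetic concrete, here is a small sketch of how the global batch size and the number of steps per epoch change with the number of replicas. The replica counts and the helper names (`global_batch_size`, `steps_per_epoch`) are hypothetical; the 60000 figure is the size of the MNIST training split, and the ceiling assumes the final partial batch is kept, which is the default for `batch`.

```python
import math

NUM_TRAIN_EXAMPLES = 60000   # size of the MNIST training split
BATCH_SIZE_PER_REPLICA = 64  # value used in this tutorial

def global_batch_size(per_replica, num_replicas):
    """Each replica processes `per_replica` examples per step."""
    return per_replica * num_replicas

def steps_per_epoch(num_examples, batch_size):
    """The final partial batch is kept, hence the ceiling."""
    return math.ceil(num_examples / batch_size)

for num_replicas in (1, 2, 4, 8):  # hypothetical device counts
    batch = global_batch_size(BATCH_SIZE_PER_REPLICA, num_replicas)
    steps = steps_per_epoch(NUM_TRAIN_EXAMPLES, batch)
    print('{} replica(s): global batch {}, {} steps per epoch'.format(
        num_replicas, batch, steps))
```

Fewer, larger steps per epoch mean fewer weight updates for the same data, which is one reason the learning rate usually needs retuning when the batch size grows.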


          Pixel values, which are 0-255, have to be normalized to the 0-1 range. Define this scale in a function.

          def scale(image, label):
            image = tf.cast(image, tf.float32)
            image /= 255

            return image, label


          Apply this function to the training and test data, shuffle the training data, and batch it for training. Notice we are also keeping an in-memory cache of the training data to improve performance.

          train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
          eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
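The role of BUFFER_SIZE in the shuffle step can be illustrated with a simplified model of a bounded shuffle buffer: fill a buffer, emit a random element from it, and replace that element with the next input item. This is a plain-Python sketch under that assumption, not the tf.data implementation, and `buffered_shuffle` is a hypothetical name.

```python
import random

def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in randomized order using a buffer of at most buffer_size."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain the remaining elements in random order.
    rng.shuffle(buffer)
    for item in buffer:
        yield item

data = list(range(20))
buffer_size = 5
out = list(buffered_shuffle(data, buffer_size, seed=0))
print(out)
```

In this model, a buffer smaller than the dataset only shuffles locally: an element can never appear more than `buffer_size - 1` positions earlier than its original position, whereas a buffer at least as large as the dataset gives a full shuffle.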

          Create the model

          Create and compile the Keras model in the context of strategy.scope.

          with strategy.scope():
            model = tf.keras.Sequential([
                tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
                tf.keras.layers.MaxPooling2D(),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(64, activation='relu'),
                tf.keras.layers.Dense(10, activation='softmax')
            ])

            model.compile(loss='sparse_categorical_crossentropy',
                          optimizer=tf.keras.optimizers.Adam(),
                          metrics=['accuracy'])
          INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

          Define the callbacks

          The callbacks used here are:

          • TensorBoard: This callback writes a log for TensorBoard which allows you to visualize the graphs.

          • Model Checkpoint: This callback saves the model after every epoch.

          • Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change after every epoch/batch.

          For illustrative purposes, add a print callback to display the learning rate in the notebook.

          # Define the checkpoint directory to store the checkpoints
          checkpoint_dir = './training_checkpoints'
          # Name of the checkpoint files
          checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

          # Function for decaying the learning rate.
          # You can define any decay function you need.
          def decay(epoch):
            if epoch < 3:
              return 1e-3
            elif epoch >= 3 and epoch < 7:
              return 1e-4
            else:
              return 1e-5

          # Callback for printing the LR at the end of each epoch.
          class PrintLR(tf.keras.callbacks.Callback):
            def on_epoch_end(self, epoch, logs=None):
              print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                                model.optimizer.lr.numpy()))

          callbacks = [
              tf.keras.callbacks.TensorBoard(log_dir='./logs'),
              tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                                 save_weights_only=True),
              tf.keras.callbacks.LearningRateScheduler(decay),
              PrintLR()
          ]
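To see which learning rate the step schedule yields for each epoch, the decay function can be evaluated on its own, without TensorFlow. This sketch assumes a hypothetical run of 12 epochs and relies on the scheduler passing a zero-based epoch index, while the print callback displays the epoch as one-based.

```python
def decay(epoch):
    """Step schedule from the tutorial; `epoch` is zero-based."""
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    return 1e-5

EPOCHS = 12  # hypothetical number of epochs, chosen for illustration

schedule = [decay(epoch) for epoch in range(EPOCHS)]
for epoch, lr in enumerate(schedule):
    # epoch is the zero-based index; epoch + 1 is what PrintLR would display.
    print('epoch index {:2d} (displayed as {:2d}): lr = {:g}'.format(
        epoch, epoch + 1, lr))
```

Under these assumptions, the first three epochs train at 1e-3, the next four at 1e-4, and all remaining epochs at 1e-5.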


          Train and evaluate

          Now, train the model in the usual way, calling fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.

