<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          【TensorFlow】筆記:基礎(chǔ)知識-自定義層

          共 2398字,需瀏覽 5分鐘

           ·

          2021-02-11 02:51

          點擊上方“公眾號”可訂閱哦!



          在TensorFlow2.0中,任何一個自定義層都繼承自tf.keras.layers.Layer。

          Layer層中需要自定義的函數(shù)有很多,但是在實際使用時一般只需要定義那些必須使用的函數(shù)即可,以下是對__init__、build和call三個主要函數(shù)的小結(jié)。



          01

          __init__函數(shù)

          __init__ 函數(shù)首先是一些必要參數(shù)的初始化,這些參數(shù)的初始化寫在 def __init__(self,) 中,然后是一些參數(shù)的初始化。
          class Mylayer(tf.keras.layers.Layer):     # 顯示繼承自Layer層    def __init__(self, unit):             # init中顯示地確定參數(shù)        super().__init__()                # 調(diào)用父層        self.unit = unit                  # 把參數(shù)加載到類
          ?
          init函數(shù)最重要的就是顯式的確定需要的一些參數(shù)。對于輸入的init中的參數(shù),輸入Tensor是不會在這里進行標注的,init初始化的是模型參數(shù)。

          ?



          02

          build函數(shù)


          build()?可自定義網(wǎng)絡(luò)的權(quán)重的維度,可以根據(jù)輸入來指定權(quán)重的維度。


              def build(self, input_shape):        self.weight = self.add_weight(shape=(input_shape[-1], self.unit),                                     initializer=tf.keras.initializers.RandomNormal(),                                     trainable=True)        self.bias = self.add_weight(shape=(self.unit,),                                   initializer=tf.keras.initializers.Zeros(),                                   trainable=True)
          ?


          在Layer()?類中有一個__call__()?魔法方法(上述兩個函數(shù)已經(jīng)被tf集成在該函數(shù)下面),會被自動調(diào)用,因此不用外部調(diào)用。




          03

          call函數(shù)


          call函數(shù)是最重要的函數(shù),這部分代碼包含了主要層的實現(xiàn),即完成前向傳播。

          init函數(shù),定義并聲明參數(shù),build函數(shù)聲明了權(quán)重可變參數(shù),而這只是定義了一些初始化的參數(shù)以及一些需要更新的參數(shù)變量,真正實現(xiàn)所定義類的功能是在call函數(shù)中。


              def call(self, inputs):        return tf.matmul(inputs, self.weight) + self.bias


          call中的一系列操作是對init和build中變量的引用,所有的計算在call中完成。


          輸入的參數(shù)在這里出現(xiàn),經(jīng)過計算后將計算值返回。



          完整代碼:

          import tensorflow as tf
          class MyLayer(tf.keras.Model): def __init__(self, unit=32): super(MyLayer, self).__init__() self.unit = unit
          def build(self, input_shape): self.weight = self.add_weight(shape=(input_shape[-1], self.unit), initializer=tf.keras.initializers.RandomNormal(), trainable=True) self.bias = self.add_weight(shape=(self.unit,), initializer=tf.keras.initializers.Zeros(), trainable=True) def call(self, inputs): return tf.matmul(inputs, self.weight) + self.bias
          my_layer = MyLayer(3)x = tf.ones((3,5))out = my_layer(x)print(out)

          輸出:

          tf.Tensor([[ 0.16174725 -0.03372785 -0.01657906] [ 0.16174725 -0.03372785 -0.01657906] [ 0.16174725 -0.03372785 -0.01657906]], shape=(3, 3), dtype=float32)





          ?END

          深度學(xué)習(xí)入門筆記

          微信號:sdxx_rmbj

          日常更新學(xué)習(xí)筆記、論文簡述

          瀏覽 111
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  苍井空视频免费一区二区三区 | 欧美69成人视频在线 | 九哥操逼网站 | 五月婷婷久久怎么了呀 | 国产精品V亚洲精品V日韩精品 |