【TensorFlow】筆記:基礎(chǔ)知識-自定義層
點擊上方“公眾號”可訂閱哦!
在TensorFlow2.0中,任何一個自定義層都繼承自tf.keras.layers.Layer。
Layer層中需要自定義的函數(shù)有很多,但是在實際使用時一般只需要定義那些必須使用的函數(shù)即可,以下是對__init__、build和call三個主要函數(shù)的小結(jié)。
01
__init__函數(shù)
class Mylayer(tf.keras.layers.Layer): # 顯示繼承自Layer層def __init__(self, unit): # init中顯示地確定參數(shù)super().__init__() # 調(diào)用父層self.unit = unit # 把參數(shù)加載到類
02
build函數(shù)
build()?可自定義網(wǎng)絡(luò)的權(quán)重的維度,可以根據(jù)輸入來指定權(quán)重的維度。
def build(self, input_shape):self.weight = self.add_weight(shape=(input_shape[-1], self.unit),initializer=tf.keras.initializers.RandomNormal(),trainable=True)self.bias = self.add_weight(shape=(self.unit,),initializer=tf.keras.initializers.Zeros(),trainable=True)
在Layer()?類中有一個__call__()?魔法方法(上述兩個函數(shù)已經(jīng)被tf集成在該函數(shù)下面),會被自動調(diào)用,因此不用外部調(diào)用。
03
call函數(shù)
call函數(shù)是最重要的函數(shù),這部分代碼包含了主要層的實現(xiàn),即完成前向傳播。
init函數(shù),定義并聲明參數(shù),build函數(shù)聲明了權(quán)重可變參數(shù),而這只是定義了一些初始化的參數(shù)以及一些需要更新的參數(shù)變量,真正實現(xiàn)所定義類的功能是在call函數(shù)中。
def call(self, inputs):return tf.matmul(inputs, self.weight) + self.bias
call中的一系列操作是對init和build中變量的引用,所有的計算在call中完成。
輸入的參數(shù)在這里出現(xiàn),經(jīng)過計算后將計算值返回。
完整代碼:
import tensorflow as tfclass MyLayer(tf.keras.Model):def __init__(self, unit=32):super(MyLayer, self).__init__()self.unit = unitdef build(self, input_shape):self.weight = self.add_weight(shape=(input_shape[-1], self.unit),initializer=tf.keras.initializers.RandomNormal(),trainable=True)self.bias = self.add_weight(shape=(self.unit,),initializer=tf.keras.initializers.Zeros(),trainable=True)def call(self, inputs):return tf.matmul(inputs, self.weight) + self.biasmy_layer = MyLayer(3)x = tf.ones((3,5))out = my_layer(x)print(out)
輸出:
tf.Tensor([][][]], shape=(3, 3), dtype=float32)
?END

深度學(xué)習(xí)入門筆記
微信號:sdxx_rmbj
日常更新學(xué)習(xí)筆記、論文簡述
評論
圖片
表情
