<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          手把手教你實現(xiàn)GAN半監(jiān)督學(xué)習(xí)

          共 11786字,需瀏覽 24分鐘

           ·

          2021-10-18 01:01

          點擊上方小白學(xué)視覺”,選擇加"星標"或“置頂

          重磅干貨,第一時間送達


          引言?


          本文主要介紹如何在tensorflow上僅使用200個帶標簽的mnist圖像,實現(xiàn)在一萬張測試圖片上99%的測試精度,原理在于使用GAN做半監(jiān)督學(xué)習(xí)。前文主要介紹一些原理部分,后文詳細介紹代碼及其實現(xiàn)原理。前文介紹比較簡單,有基礎(chǔ)的同學(xué)請掠過直接看第二部分,文章末尾給出了代碼GitHub鏈接。對GAN不了解的同學(xué)可以查看微信公眾號:機器學(xué)習(xí)算法全棧工程師的GAN入門文章。


          監(jiān)督,無監(jiān)督,半監(jiān)督學(xué)習(xí)介紹


          在正式介紹實現(xiàn)半監(jiān)督學(xué)習(xí)之前,我在這里首先介紹一下監(jiān)督學(xué)習(xí)(supervised learning),半監(jiān)督學(xué)習(xí)(semi-supervised learning)和無監(jiān)督學(xué)習(xí)(unsupervised learning)的區(qū)別。監(jiān)督學(xué)習(xí)是指在訓(xùn)練集中包含訓(xùn)練數(shù)據(jù)的標簽(label),比如類別標簽,位置標簽等等。最普遍使用標簽學(xué)習(xí)的是分類任務(wù),對于分類任務(wù),輸入給網(wǎng)絡(luò)訓(xùn)練樣本(samples)的一些特征(feature)以及此樣本對應(yīng)的標簽(label),通過神經(jīng)網(wǎng)絡(luò)擬合的方法,神經(jīng)網(wǎng)絡(luò)可以在特征和標簽之間找到一個合適的映射關(guān)系(mapping),這樣當訓(xùn)練完成后,輸入給網(wǎng)絡(luò)沒有l(wèi)abel的樣本,神經(jīng)網(wǎng)絡(luò)可以通過這一個映射關(guān)系猜出它屬于哪一類。典型機器學(xué)習(xí)的監(jiān)督學(xué)習(xí)的例子是KNN和SVM。目前機器視覺領(lǐng)域的急速發(fā)展離不開監(jiān)督學(xué)習(xí)。


          而無監(jiān)督學(xué)習(xí)的訓(xùn)練事先沒有訓(xùn)練標簽,直接輸入給算法一些數(shù)據(jù),算法會努力學(xué)習(xí)數(shù)據(jù)的共同點,尋找樣本之間的規(guī)律性。無監(jiān)督學(xué)習(xí)是很典型的學(xué)習(xí),人的學(xué)習(xí)有時候就是基于無監(jiān)督的,比如我并不懂音樂,但是我聽了上百首歌曲后,我可以根據(jù)我聽的結(jié)果將音樂分為搖滾樂(記為0類)、民謠(記為1類)、純音樂(記為2類)等等,事實上,我并不知道具體是哪一類,所以將它們記為0,1,2三類。典型的無監(jiān)督學(xué)習(xí)方法是聚類算法,比如k-means。


          東方快車電影里面大偵探有過一個臺詞,人們的話只有對與錯,沒有中間地帶,最后經(jīng)過一系列事件后他找到了對與錯之間的betweeness。在監(jiān)督學(xué)習(xí)和無監(jiān)督學(xué)習(xí)之間,同樣存在著中間地帶-半監(jiān)督學(xué)習(xí)。半監(jiān)督學(xué)習(xí)簡單來說就是將無監(jiān)督學(xué)習(xí)和監(jiān)督學(xué)習(xí)相結(jié)合,一部分包含了監(jiān)督學(xué)習(xí)一部分包含了無監(jiān)督學(xué)習(xí),比如給一個分類任務(wù),此分類任務(wù)的訓(xùn)練集中有精確標簽的數(shù)據(jù)非常少,但是包含了大量的沒有標注的數(shù)據(jù),如果直接用監(jiān)督學(xué)習(xí)的方法去做的話,效果不一定很好,有標注的訓(xùn)練數(shù)據(jù)太少很容易導(dǎo)致過擬合,而且大量的無標注的數(shù)據(jù)都沒有充分的利用,最常見的例子是在醫(yī)學(xué)圖像的分析檢測任務(wù)中,醫(yī)學(xué)圖像本身就不容易獲得,要獲得精標注的圖像就需要有經(jīng)驗的醫(yī)生去一個一個標注,顯然他們并沒有那么多的時間。這時候就是半監(jiān)督學(xué)習(xí)的用武之地了,半監(jiān)督學(xué)習(xí)很適合用在標簽數(shù)據(jù)少,訓(xùn)練數(shù)據(jù)又比較多的情況。


          常見的半監(jiān)督學(xué)習(xí)方法主要有:

          1.Self training

          2.Generative model

          3.S3VMs

          4.Graph-Based AIgorithems

          5.Multiview AIgorithems


          接下來我會結(jié)合Improved Techniques for Training GANs這篇論文詳細介紹如何使用目前最火的生成模型GAN去實現(xiàn)半監(jiān)督學(xué)習(xí),也即是半監(jiān)督學(xué)習(xí)的第二種方法,并給出詳細的代碼解釋,對理論不是很熟悉的同學(xué)可以直接看代碼。另外注明:我只復(fù)現(xiàn)了論文半監(jiān)督學(xué)習(xí)的部分,之前也有人復(fù)現(xiàn)了此部分,但是我感覺他對原文有很大的曲解,他使用了所有的標簽去幫助生成,并不在分類上,不太符合半監(jiān)督學(xué)習(xí)的本質(zhì),而且代碼很復(fù)雜,感興趣的可以看這個鏈接https://github.com/gitlimlab/SSGAN-Tensorflow。


          Improved Techniques for Training GANs


          GAN是無監(jiān)督學(xué)習(xí)的代表,它可以不斷學(xué)習(xí)模擬數(shù)據(jù)的分布進而生成和訓(xùn)練數(shù)據(jù)相似分布的樣本,在訓(xùn)練過程不需要標簽,GAN在無監(jiān)督學(xué)習(xí)領(lǐng)域,生成領(lǐng)域,半監(jiān)督學(xué)習(xí)領(lǐng)域以及強化學(xué)習(xí)領(lǐng)域都有廣泛的應(yīng)用。但是GAN存在很多的訓(xùn)練不穩(wěn)定等等的問題,作者good fellow在2016年放出了Improved Techniques for Training GANs,對GAN訓(xùn)練不穩(wěn)定的問題做了一些解釋和經(jīng)驗上的解決方案,并給出了和半監(jiān)督學(xué)習(xí)結(jié)合的方法。


          從平衡點角度解釋GAN的不穩(wěn)定性來說,GAN的納什均衡點是一個鞍點,并不是一個局部最小值點,基于梯度的方法主要是尋找高維空間中的極小值點,因此使用梯度訓(xùn)練的方法很難使GAN收斂到平衡點。為此,為了一部分緩解這個問題,goodfellow聯(lián)合提出了一些改進方案,

          主要有:

          Feature matching,

          Minibatch discrimination

          weight Historical averaging? (相當于一個正則化的方式)

          One-sided label smoothing

          Virtual batch normalization


          后來發(fā)現(xiàn)Feature matching在半監(jiān)督學(xué)習(xí)上表現(xiàn)良好,mini-batch discrimination表現(xiàn)很差。


          ?semi-supervised GAN


          對于一個普通的分類器來說,假設(shè)對MNIST分類,一共有10類數(shù)據(jù),分別是0-9,分類器模型以數(shù)據(jù)x作為輸入,輸出一個K=10維的向量,經(jīng)過soft max后計算出分類概率最大的那個類別。在監(jiān)督學(xué)習(xí)領(lǐng)域,往往是通過最小化類別標簽 y 和預(yù)測分布的交叉熵來實現(xiàn)最好的結(jié)果。


          但是將GAN用在半監(jiān)督學(xué)習(xí)領(lǐng)域的時候需要做一些改變,生成器不做改變,仍然負責從輸入噪聲數(shù)據(jù)中生成圖像,判別器D不在是一個簡單的真假分類(二分類)器,假設(shè)輸入數(shù)據(jù)有K類,D就是K+1的分類器,多出的那一類是判別輸入是否是生成器G生成的圖像。網(wǎng)絡(luò)的流程圖見圖一。




          圖一 網(wǎng)絡(luò)的流程圖


          網(wǎng)絡(luò)結(jié)構(gòu)確定了之后就是損失函數(shù)的設(shè)計部分,借助GAN我們就可以從無標簽數(shù)據(jù)中學(xué)習(xí),只要知道輸入數(shù)據(jù)是真實數(shù)據(jù),那就可以通過最大化來實現(xiàn),上述式子可解釋為不管輸入的是哪一類真的圖片(不是生成器G生成的假圖片),只要最大化輸出它是真圖像的概率就可以了,不需要具體分出是哪一類。由于GAN的生成器的參與,訓(xùn)練數(shù)據(jù)中有一半都是生成的假數(shù)據(jù)。


          下面給出判別器D的損失函數(shù)設(shè)計,D損失函數(shù)包括兩個部分,一個是監(jiān)督學(xué)習(xí)損失,一個是半監(jiān)督學(xué)習(xí)損失,具體公式如下:


          其中




          對于無監(jiān)督學(xué)習(xí)來說,只需要輸出真假就可以了,不需要確定是哪一類,因此我們令



          其中表示判別是假圖像的概率,那么D(x)就代表了輸出是真圖像的概率,那么無監(jiān)督學(xué)習(xí)的損失函數(shù)就可以表示為



          這不就是GAN的損失函數(shù)嘛!好了,到這里得出結(jié)論,在半監(jiān)督學(xué)習(xí)中,判別器的分類要多分一類,多出的這一類表示的是生成器生成的假圖像這一類,另外判別器的損失函數(shù)不僅包括了監(jiān)督損失而且還有無監(jiān)督的損失函數(shù),在訓(xùn)練過程中同時最小化這兩者。損失函數(shù)介紹完畢,接下來介紹代碼實現(xiàn)部分。


          代碼實現(xiàn)及解讀


          注:完整代碼的GitHub連接在文章底部。這里只截取關(guān)鍵部分做介紹。


          在代碼中,我使用feature matching,one side label smoothing方式,并沒有使用論文中介紹的Historical averaging,而是只對判別器D使用了簡單的l2正則化,防止過擬合,另外論文中介紹的Minibatch discrimination, Virtual batch normalization等等都沒有使用,主要是這兩者在半監(jiān)督學(xué)習(xí)中表現(xiàn)不是很好,但是如果想獲得好的生成結(jié)果還是很有用的。


          1網(wǎng)絡(luò)結(jié)構(gòu)

          首先介紹網(wǎng)絡(luò)結(jié)構(gòu)部分,因為是在mnist數(shù)據(jù)集比較簡單,所以隨便搭了一個判別器和生成器,具體如下:


          判別器的網(wǎng)絡(luò)結(jié)構(gòu)如下面代碼所示:



          def discriminator(self, name, inputs, reuse):
          ? ? ? ?l = tf.shape(inputs)[0]
          ? ? ? ?inputs = tf.reshape(inputs, (l,self.img_size,self.img_size,self.dim))
          ? ? ? ?with tf.variable_scope(name,reuse=reuse):
          ? ? ? ? ? ?out = []
          ? ? ? ? ? ?output = conv2d('d_con1',inputs,5, 64, stride=2, padding='SAME') #14*14
          ? ? ? ? ? ?output1 = lrelu(self.bn('d_bn1',output))
          ? ? ? ? ? ?out.append(output1)
          ? ? ? ? ? ?# output1 = tf.contrib.keras.layers.GaussianNoise
          ? ? ? ? ? ?output = conv2d('d_con2', output1, 3, 64*2, stride=2, padding='SAME')#7*7
          ? ? ? ? ? ?output2 = lrelu(self.bn('d_bn2', output))
          ? ? ? ? ? ?out.append(output2)
          ? ? ? ? ? ?output = conv2d('d_con3', output2, 3, 64*4, stride=1, padding='VALID')#5*5
          ? ? ? ? ? ?output3 = lrelu(self.bn('d_bn3', output))
          ? ? ? ? ? ?out.append(output3)
          ? ? ? ? ? ?output = conv2d('d_con4', output3, 3, 64*4, stride=2, padding='VALID')#2*2
          ? ? ? ? ? ?output4 = lrelu(self.bn('d_bn4', output))
          ? ? ? ? ? ?out.append(output4)
          ? ? ? ? ? ?output = tf.reshape(output4, [l, 2*2*64*4])# 2*2*64*4
          ? ? ? ? ? ?output = fc('d_fc', output, self.num_class)
          ? ? ? ? ? ?# output = tf.nn.softmax(output)
          ? ? ? ? ? ?return output, out


          其中conv2d()是卷積操作,參數(shù)依次是,層的名字,輸入tensor,卷積核大小,輸出通道數(shù),步長,padding。判別器中每一層都加了歸一化層,這里使用最簡單的歸一化,函數(shù)如下所示,另外每一層的激活函數(shù)使用leakrelu。判別器D最終返回兩個值,第一個是計算的logits,另外一個是一個列表,列表的每一個元素代表判別器每一層的輸出,為接下來實現(xiàn)feature matching做準備。


          def bn(self, name, input):
          ? ? ? ?val = tf.contrib.layers.batch_norm(input, decay=0.9,
          ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? updates_collections=None,
          ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? epsilon=1e-5,
          ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? scale=True,
          ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? is_training=True,
          ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? scope=name)
          ? ? ? ?return val


          def lrelu(x, leak=0.2):
          ? ?return tf.maximum(x, leak * x)


          生成器結(jié)構(gòu)如下面代碼所示:其最后一層激活函數(shù)使用tanh


          def generator(self,name, noise, reuse):
          ? ? ? ?with tf.variable_scope(name,reuse=reuse):
          ? ? ? ? ? ?l = self.batch_size
          ? ? ? ? ? ?output = fc('g_dc', noise, 2*2*64)
          ? ? ? ? ? ?output = tf.reshape(output, [-1, 2, 2, 64])
          ? ? ? ? ? ?output = tf.nn.relu(self.bn('g_bn1',output))
          ? ? ? ? ? ?output = deconv2d('g_dcon1',output,5,outshape=[l, 4, 4, 64*4])
          ? ? ? ? ? ?output = tf.nn.relu(self.bn('g_bn2',output))

          ? ? ? ? ? ?output = deconv2d('g_dcon2', output, 5, outshape=[l, 8, 8, 64 * 2])
          ? ? ? ? ? ?output = tf.nn.relu(self.bn('g_bn3', output))

          ? ? ? ? ? ?output = deconv2d('g_dcon3', output, 5, outshape=[l, 16, 16,64 * 1])
          ? ? ? ? ? ?output = tf.nn.relu(self.bn('g_bn4', output))

          ? ? ? ? ? ?output = deconv2d('g_dcon4', output, 5, outshape=[l, 32, 32, self.dim])
          ? ? ? ? ? ?output = tf.image.resize_images(output, (28, 28))
          ? ? ? ? ? ?# output = tf.nn.relu(self.bn('g_bn4', output))
          ? ? ? ? ? ?return tf.nn.tanh(output)


          網(wǎng)絡(luò)結(jié)構(gòu)是根據(jù)DCGAN的結(jié)構(gòu)改的,所以網(wǎng)絡(luò)簡要介紹到這里。


          2網(wǎng)絡(luò)初始化

          接下來介紹網(wǎng)絡(luò)初始化方面:

          首先在train.py里建立一個Train的類,并做一些初始化


          class Train(object):
          ? ?def __init__(self, sess, args):
          ? ? ? ?#sess=tf.Session()
          ? ? ? ?self.sess = sess
          ? ? ? ?self.img_size = 28 ? # the size of image
          ? ? ? ?self.trainable = True
          ? ? ? ?self.batch_size = 100 ?# must be even number
          ? ? ? ?self.lr = 0.0002
          ? ? ? ?self.mm = 0.5 ? ? ?# momentum term for adam
          ? ? ? ?self.z_dim = 128 ? # the dimension of noise z
          ? ? ? ?self.EPOCH = 50 ? ?# the number of max epoch
          ? ? ? ?self.LAMBDA = 0.1 ?# parameter of WGAN-GP
          ? ? ? ?self.model = args.model ?# 'DCGAN' or 'WGAN'
          ? ? ? ?self.dim = 1 ? ? ? # RGB is different with gray pic
          ? ? ? ?self.num_class = 11
          ? ? ? ?self.load_model = args.load_model
          ? ? ? ?self.build_model() ?# initializer


          args是傳進來的參數(shù),主要包括三個,一個是args.model,選擇DCGAN模式還是WGAN-GP模式,二者的不同主要在于損失函數(shù)不同和優(yōu)化器的學(xué)習(xí)率不同,其他都一樣。第二個參數(shù)是args.trainable,訓(xùn)練還是測試,訓(xùn)練時為True,測試是False。Loadmodel表示是否選擇加載訓(xùn)練好的權(quán)重。


          import argparse
          parser.add_argument('--model', type=str, default='DCGAN', help='DCGAN or WGAN-GP')
          parser.add_argument('--trainable', type=bool, default=False,help='True for train and False for test')
          parser.add_argument('--load_model', type=bool, default=True, help='True for load ckpt model and False for otherwise')
          parser.add_argument('--label_num', type=int, default=2, help='the num of labled images we use, 2*100=200,batchsize:100')


          3Build_model函數(shù)

          Build_model函數(shù)里面主要包括了網(wǎng)絡(luò)訓(xùn)練前的準備工作,主要包括損失函數(shù)的設(shè)計和優(yōu)化器的設(shè)計。以下代碼連在一起正好是build_model函數(shù)的全部內(nèi)容,下文將詳細做出介紹,尤其是損失函數(shù)部分。


          def build_model(self):
          ? ? ? ?# build ?placeholders
          ? ? ? ?self.x = tf.placeholder(tf.float32, shape=[self.batch_size, self.img_size*self.img_size*self.dim], name='real_img')
          ? ? ? ?self.z = tf.placeholder(tf.float32, shape=[self.batch_size, self.z_dim], name='noise')
          ? ? ? ?self.label = tf.placeholder(tf.float32, shape=[self.batch_size, self.num_class-1], name='label')
          ? ? ? ?self.flag = tf.placeholder(tf.float32, shape=[], name='flag')
          ? ? ? ?self.flag2 = tf.placeholder(tf.float32, shape=[], name='flag2')
          ? ? ? ?# define the network
          ? ? ? ?self.G_img = self.generator('gen', self.z, reuse=False)
          ? ? ? ?ximg = tf.reshape(self.x, (self.batch_size, self.img_size, self.img_size, self.dim))
          ? ? ? ?d_in = tf.concat([ximg, self.G_img], axis=0)

          ? ? ? ?self.D_logits_, self.D_out_ = self.discriminator('dis', d_in, reuse=False)

          ? ? ? ?self.D_logits, self.D_logits_f = tf.split(self.D_logits_, [self.batch_size, self.batch_size], axis=0)

          ? ? ? ?d_regular = tf.add_n(tf.get_collection('regularizer', 'dis'), 'loss')
          ? ? ? #caculate the supervised loss
          ? ? ? ?batch_gl = tf.zeros_like(self.label, dtype=tf.float32)
          ? ? ? ?batchl_ = tf.concat([self.label, tf.zeros([self.batch_size, 1])], axis=1)
          ? ? ? ?batch_gl = tf.concat([batch_gl, tf.ones([self.batch_size, 1])], axis=1)
          ? ? ? ?batchl = tf.concat([batchl_, batch_gl], axis=0)*0.9 ?# one side label smoothing
          ? ? ? ? s_l = tf.losses.softmax_cross_entropy(onehot_labels=batchl, logits=self.D_logits_, label_smoothing=None)
          ? ? ? ?s_logits_ = tf.nn.softmax(self.D_logits_)
          ? ? ? ?un_s = tf.reduce_sum(s_logits_[:self.batch_size, -1])/(tf.reduce_sum(s_logits_[:self.batch_size,:])) \
          ? ? ? ? ? ? ? ?+ tf.reduce_sum(s_logits_[self.batch_size:,:-1])/tf.reduce_sum(s_logits_[self.batch_size:,:])
          ? ? ? ?f_match = tf.constant(0., dtype=tf.float32)
          ? ? ? ?for i in range(4):
          ? ? ? ? ? ?d_layer, d_glayer = tf.split(self.D_out_[i], [self.batch_size, self.batch_size], axis=0)
          ? ? ? ? ? ?f_match += tf.reduce_mean(tf.multiply(tf.subtract(d_layer, d_glayer),tf.subtract(d_layer, d_glayer)))
          ? ? ? ?self.d_loss_real = -tf.log(tf.reduce_sum(s_logits_[:self.batch_size, :-1])/tf.reduce_sum(s_logits_[:self.batch_size, :]))
          ? ? ? ? ? ?self.d_loss_fake = -tf.log(tf.reduce_sum(s_logits_[self.batch_size:, -1])/tf.reduce_sum(s_logits_[self.batch_size:, :]))
          ? ? ? ? ? ?self.g_loss = self.d_loss_fake + f_match*0.01*self.flag2
          ? ? ? ? ? ?self.d_l_1, self.d_l_2, self.d_l_3 = self.d_loss_fake + self.d_loss_real, self.flag*s_l, (1-self.flag)*un_s
          ? ? ? ? ? ?self.d_loss = self.d_l_1 + self.d_l_2 + self.d_l_3


          首先,建立了五個placeholder,flag表示兩個標志位,只有0-1兩種情況,注意到我num_class是11,也就是做11分類,但是lable的placeholder中shape是(batchsize,10),因為傳進去訓(xùn)練之前會將label擴展到[batchsize, 11]。為了方便,我將生成器的生成結(jié)果和真實數(shù)據(jù)X級聯(lián)在一起作為判別器的輸入,輸出再把他它們結(jié)果split分開。


          d_regular 表示正則化,這里我將判別器中所有的weights做了l2正則。


          監(jiān)督學(xué)習(xí)的損失函數(shù)使用常見的交叉熵損失函數(shù),對生成器生成的圖像的label的one_hot型為:

          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]


          將原始的label擴展到(batchsize,11)后再和生成器生成的假數(shù)據(jù)的label再第一維度concat到一起得到batchl,另外乘以0.9,做單邊標簽平滑(one side smoothing),由此計算得到監(jiān)督學(xué)習(xí)的損失函數(shù)值s_l,。


          生成器G的損失函數(shù)


          生成器G的損失函數(shù)包括兩部分,一個是來自GAN訓(xùn)練的部分,另外一個是feature matching , 論文中提到的feature matching意思是特征匹配,主要思想是希望生成器生成的假數(shù)據(jù)輸入到判別器,經(jīng)過判別器每一層計算的結(jié)果和將真實數(shù)據(jù)X輸入到判別器,判別器每一層的結(jié)果盡可能的相似,公式如下:



          ???????????????????????????????????? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?

          其中f(x)是D的每一層的輸出。Feature matching 是指導(dǎo)G進行訓(xùn)練,所以我將他放在了G的損失函數(shù)里。


          分類器D的損失函數(shù)


          相比較G的損失函數(shù),D的損失函數(shù)就比較麻煩了。


          接下來介紹無監(jiān)督學(xué)習(xí)的損失函數(shù)實現(xiàn):

          在前面介紹的無監(jiān)督學(xué)習(xí)的損失函數(shù)中,有一部分和GAN的損失函數(shù)很相似,所以再代碼中我們使用了無監(jiān)督學(xué)習(xí)的時候沒有標簽的指導(dǎo),此時判別器或者稱為分類器D無法正確對輸入進行分類,此時只要求D能夠區(qū)分真假就可以了,由此我們得到了無監(jiān)督學(xué)習(xí)的損失un_s,直觀上也很好理解,假設(shè)輸入給判別器D真圖像,它結(jié)果經(jīng)過soft max后輸出類似下面表格的形式


































          其中前十個黃色區(qū)域表示對0-9的分類概率,最后一個灰色的表示對假圖像的分類概率,由于無監(jiān)督學(xué)習(xí)中判別器D并不知道具體是哪一類數(shù)據(jù),所以干脆D的損失函數(shù)最小化輸出假圖像的概率就可以了,當輸入為生成器生成的假圖像時,只要最小化D輸出為真圖像的概率,由此我們得到了un_s.。但是此時有一個問題,即是有監(jiān)督學(xué)習(xí)的時候不就沒有用了嗎,因為這時候應(yīng)該使用s_l.為了解決這個問題,我使用了一個標志位flag作為控制他們之間的使用,具體代碼:


          flag*s_l + ( 1 – flag)*un_s


          有標簽的時候flag是1,表示使用s_l,無監(jiān)督的時候flag是0,表示使用無監(jiān)督損失函數(shù)。此時已經(jīng)完成了判別器D損失函數(shù)的一部分設(shè)計,剩下的一部分和GAN中的D的損失一樣,在代碼中我給出了兩種損失函數(shù),一個是原始GAN的交叉熵損失函數(shù),和DCGAN使用的一樣,另外一個是improved wgan論文中使用的損失函數(shù),但是在做了對比之后,我強烈建議使用DCGAN來做,improved wgan的損失函數(shù)雖然在生成結(jié)果的優(yōu)化上有很大幫助,但是并不適合半監(jiān)督學(xué)習(xí)中。


          訓(xùn)練部分


          接下來就是訓(xùn)練部分:


          此時可能有一個疑問,我們是如何實現(xiàn)只使用200帶標簽的數(shù)據(jù)訓(xùn)練的,答案就在flag這個標志位里,在訓(xùn)練部分代碼中,當?shù)螖?shù)小于2的時候,flag=1, 此時表示使用s_l作為損失函數(shù)的一部分,當flag=0的時候,un_s起作用而s_l并沒有起作用,這時,即使我們feed了正確的標簽數(shù)據(jù),但是s_l不起作用,就相當于沒有使用標簽。



          for idx in range(iters):
          ? ? ? start_t = time.time()
          ? ? ? flag = 1 if idxelse
          0 # set we use 500 train data with label.


          flag2的作用本來是使用他控制feature matching是否工作的,這里暫時設(shè)置為1。


          (訓(xùn)練部分詳細代碼請移步文章下面github鏈接查看)

          測試


          def test(self):
          ? ? ? ?count = 0.
          ? ? ? ?print 'testing................'
          ? ? ? ?for i in range(10000//self.batch_size):
          ? ? ? ? ? ?testx, textl = mnist.test.next_batch(self.batch_size)
          ? ? ? ? ? ?prediction = self.sess.run(self.prediction, feed_dict={self.x:testx, self.label:textl})
          ? ? ? ? ? ?count += np.sum(prediction)
          ? ? ? ?return count/10000.


          測試部分代碼如上圖所示,沒訓(xùn)練完成一個epoch,就測試依次,測試的時候,使用了一個temp保存測試的最大精度,當測試結(jié)果比前幾次都要好是,temp會更新到最好的測試精度,并保存模型,否則不保存模型,這樣做的好處在于我保存的模型測試精度一定是最好的。


          測試精度結(jié)果變化圖



          下載1:OpenCV-Contrib擴展模塊中文版教程
          在「小白學(xué)視覺」公眾號后臺回復(fù):擴展模塊中文教程即可下載全網(wǎng)第一份OpenCV擴展模塊教程中文版,涵蓋擴展模塊安裝、SFM算法、立體視覺、目標跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。

          下載2:Python視覺實戰(zhàn)項目52講
          小白學(xué)視覺公眾號后臺回復(fù):Python視覺實戰(zhàn)項目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個視覺實戰(zhàn)項目,助力快速學(xué)校計算機視覺。

          下載3:OpenCV實戰(zhàn)項目20講
          小白學(xué)視覺公眾號后臺回復(fù):OpenCV實戰(zhàn)項目20講,即可下載含有20個基于OpenCV實現(xiàn)20個實戰(zhàn)項目,實現(xiàn)OpenCV學(xué)習(xí)進階。

          交流群


          歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學(xué)影像、GAN、算法競賽等微信群(以后會逐漸細分),請掃描下面微信號加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三?+?上海交大?+?視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~


          瀏覽 54
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  91爱福利 | 可以直接看av的网址 | 成人免费视频 国产免费观看 | 成人性生活影视av | 91无码在线成人视频 |