<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          CycleGAN 生成對(duì)抗網(wǎng)絡(luò)圖像處理工具

          共 18125字,需瀏覽 37分鐘

           ·

          2021-03-12 10:04

          1. GAN簡(jiǎn)介

          "干飯人,干飯魂,干飯都是人上人"。
          此GAN飯人非彼干飯人。本文要講的GAN是Goodfellow2014提出的生成產(chǎn)生對(duì)抗模型,即Generative Adversarial Nets。那么GAN到底有什么神奇的地方?
          常規(guī)的深度學(xué)習(xí)任務(wù)如圖像分類,目標(biāo)檢測(cè)以及語(yǔ)義分割或者實(shí)例分割,這些任務(wù)的結(jié)果都可以歸結(jié)為預(yù)測(cè)。圖像分類是預(yù)測(cè)單一的類別,目標(biāo)檢測(cè)是預(yù)測(cè)bbox和類別,語(yǔ)義分割或者實(shí)例分割是預(yù)測(cè)每個(gè)像素的類別。而GAN是生成一個(gè)新的東西如一個(gè)圖片。
          GAN的原理用一句話來(lái)說(shuō)明:
          • 通過(guò)對(duì)抗的方式,去學(xué)習(xí)數(shù)據(jù)分布的生成式模型。GAN是無(wú)監(jiān)督的過(guò)程,能夠捕捉數(shù)據(jù)集的分布,以便于可以從隨機(jī)噪聲中生成同樣分布的數(shù)據(jù)
          GAN的組成:判別式模型和生成式模型的左右手博弈
          • D判別式模型:學(xué)習(xí)真假邊界,判斷數(shù)據(jù)是真的還是假的
          • G生成式模型:學(xué)習(xí)數(shù)據(jù)分布并生成數(shù)據(jù)

          GAN經(jīng)典的loss如下(minmax體現(xiàn)的就是對(duì)抗)

          2. 實(shí)戰(zhàn)cycleGAN 風(fēng)格轉(zhuǎn)換

          了解了GAN的作用,來(lái)體驗(yàn)的GAN的神奇效果。這里以cycleGAN為例子來(lái)實(shí)現(xiàn)圖像的風(fēng)格轉(zhuǎn)換。所謂的風(fēng)格轉(zhuǎn)換就是改變?cè)紙D片的風(fēng)格,如下圖左邊是原圖,中間是風(fēng)格圖(梵高畫),生成后是右邊的具有梵高風(fēng)格的原圖,可以看到總體上生成后的圖保留大部分原圖的內(nèi)容。

          2.1 cycleGAN簡(jiǎn)介

          cycleGAN本質(zhì)上和GAN是一樣的,是學(xué)習(xí)數(shù)據(jù)集中潛在的數(shù)據(jù)分布。GAN是從隨機(jī)噪聲生成同分布的圖片,cycleGAN是在有意義的圖上加上學(xué)習(xí)到的分布從而生成另一個(gè)領(lǐng)域的圖。cycleGAN假設(shè)image-to-image的兩個(gè)領(lǐng)域存在的潛在的聯(lián)系。
          眾所周知,GAN的映射函數(shù)很難保證生成圖片的有效性。cycleGAN利用cycle consistency來(lái)保證生成的圖片與輸入圖片的結(jié)構(gòu)上一致性。我們看下cycleGAN的結(jié)構(gòu):

          特點(diǎn)總結(jié)如下:
          • 兩路GAN:兩個(gè)生成器[ G:X->Y , F:Y->X ]  和兩個(gè)判別器[Dx, Dy], G和Dy目的是生成的對(duì)象,Dy(正類是Y領(lǐng)域)無(wú)法判別。同理F和Dx也是一樣的。
          • cycle consistency:G是生成Y的生成器, F是生成X的生成器,cycle consistency是為了約束G和F生成的對(duì)象的范圍,  是的G生成的對(duì)象通過(guò)F生成器能夠回到原始的領(lǐng)域如:x->G(x)->F(G(x))=x
          對(duì)抗loss如下:

          2.2 實(shí)現(xiàn)cycleGAN

          2.2.1 生成器

          從上面簡(jiǎn)介中生成器有兩個(gè)生成器,一個(gè)是正向,一個(gè)是反向的。結(jié)構(gòu)是參考論文Perceptual Losses for Real-Time Style Transfer and Super-Resolution: Supplementary Material。大致可以分為:下采樣 + residual 殘差block + 上采樣,如下圖(摘自論文):

          實(shí)現(xiàn)上下采樣是stride=2的卷積, 上采樣用nn.Upsample:
          # 殘差block
          class ResidualBlock(nn.Module):

              def __init__(self, in_features):
                  super(ResidualBlock, self).__init__()

                  self.block = nn.Sequential(
                      nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      nn.InstanceNorm2d(in_features),
                      nn.ReLU(inplace=True),
                      nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      nn.InstanceNorm2d(in_features),
                  )

              def forward(self, x):
                  return x + self.block(x)

          class GeneratorResNet(nn.Module):
              def __init__(self, input_shape, num_residual_blocks):
                  super(GeneratorResNet, self).__init__()

                  channels = input_shape[0]

                  # Initial convolution block
                  out_features = 64
                  model = [
                      nn.ReflectionPad2d(channels),
                      nn.Conv2d(channels, out_features, 7),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True),
                  ]
                  in_features = out_features

                  # Downsampling
                  for _ in range(2):
                      out_features *= 2
                      model += [
                          nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                          nn.InstanceNorm2d(out_features),
                          nn.ReLU(inplace=True),
                      ]
                      in_features = out_features

                  # Residual blocks
                  for _ in range(num_residual_blocks):
                      model += [ResidualBlock(out_features)]

                  # Upsampling
                  for _ in range(2):
                      out_features //= 2
                      model += [
                          nn.Upsample(scale_factor=2),
                          nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                          nn.InstanceNorm2d(out_features),
                          nn.ReLU(inplace=True),
                      ]
                      in_features = out_features

                  # Output layer
                  model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]

                  self.model = nn.Sequential(*model)

              def forward(self, x):
                  return self.model(x)

          2.2.2 判別器

          傳統(tǒng)的GAN 判別器輸出的是一個(gè)值,判斷真假的程度。而patchGAN輸出是N*N值,每一個(gè)值代表著原始圖像上的一定大小的感受野,直觀上就是對(duì)原圖上crop下可重復(fù)的一部分區(qū)域進(jìn)行判斷真假,可以認(rèn)為是一個(gè)全卷積網(wǎng)絡(luò),最早是在pix2pix提出(Image-to-Image Translation with Conditional Adversarial Networks)。好處是參數(shù)少,另外一個(gè)從局部可以更好的抓取高頻信息。
          class Discriminator(nn.Module):
              def __init__(self, input_shape):
                  super(Discriminator, self).__init__()

                  channels, height, width = input_shape

                  # Calculate output shape of image discriminator (PatchGAN)
                  self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

                  def discriminator_block(in_filters, out_filters, normalize=True):
                      """Returns downsampling layers of each discriminator block"""
                      layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
                      if normalize:
                          layers.append(nn.InstanceNorm2d(out_filters))
                      layers.append(nn.LeakyReLU(0.2, inplace=True))
                      return layers

                  self.model = nn.Sequential(
                      *discriminator_block(channels, 64, normalize=False),
                      *discriminator_block(64128),
                      *discriminator_block(128256),
                      *discriminator_block(256512),
                      nn.ZeroPad2d((1010)),
                      nn.Conv2d(51214, padding=1)
                  )

              def forward(self, img):
                  return self.model(img)

          2.2.3 訓(xùn)練

          loss和模型初始化
          # Losses
          criterion_GAN = torch.nn.MSELoss()
          criterion_cycle = torch.nn.L1Loss()
          criterion_identity = torch.nn.L1Loss()

          cuda = torch.cuda.is_available()
          input_shape = (opt.channels, opt.img_height, opt.img_width)

          # Initialize generator and discriminator
          G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
          G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
          D_A = Discriminator(input_shape)
          D_B = Discriminator(input_shape)
          優(yōu)化器和訓(xùn)練策略
          # Optimizers
          optimizer_G = torch.optim.Adam(
              itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
          )
          optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
          optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

          # Learning rate update schedulers
          lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
              optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
          )
          lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
              optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
          )
          lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
              optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
          )
          訓(xùn)練迭代
          • 訓(xùn)練數(shù)據(jù)是成對(duì)的數(shù)據(jù),但是是非配對(duì)的數(shù)據(jù),即A和B是沒有直接的聯(lián)系的。A是原圖,B是風(fēng)格圖
          • 生成器訓(xùn)練
            • GAN loss:判別器判別A和B生成的兩個(gè)圖fake_A、fake_B與GT的loss
            • Cycle loss:反過(guò)來(lái)fake_A和fake_B 生成的圖與A和B像素上差異
          • 判別器訓(xùn)練:
            • loss_real: 判別A/B和GT的MSELoss
            • loss_fake:判別生成的fake_A/fake_B與GT的MSELoss
          for epoch in range(opt.epoch, opt.n_epochs):
              for i, batch in enumerate(dataloader):

                  # 數(shù)據(jù)是成對(duì)的數(shù)據(jù),但是是非配對(duì)的數(shù)據(jù),即A和B是沒有直接的聯(lián)系的
                  real_A = Variable(batch["A"].type(Tensor))
                  real_B = Variable(batch["B"].type(Tensor))

                  # Adversarial ground truths
                  valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
                  fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

                  # ------------------
                  #  Train Generators
                  # ------------------

                  G_AB.train()
                  G_BA.train()

                  optimizer_G.zero_grad()

                  # Identity loss
                  loss_id_A = criterion_identity(G_BA(real_A), real_A)
                  loss_id_B = criterion_identity(G_AB(real_B), real_B)

                  loss_identity = (loss_id_A + loss_id_B) / 2

                  # GAN loss
                  fake_B = G_AB(real_A)
                  loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
                  fake_A = G_BA(real_B)
                  loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

                  loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

                  # Cycle loss
                  recov_A = G_BA(fake_B)
                  loss_cycle_A = criterion_cycle(recov_A, real_A)
                  recov_B = G_AB(fake_A)
                  loss_cycle_B = criterion_cycle(recov_B, real_B)

                  loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

                  # Total loss
                  loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

                  loss_G.backward()
                  optimizer_G.step()

                  # -----------------------
                  #  Train Discriminator A
                  # -----------------------

                  optimizer_D_A.zero_grad()

                  # Real loss
                  loss_real = criterion_GAN(D_A(real_A), valid)
                  # Fake loss (on batch of previously generated samples)
                  # fake_A_ = fake_A_buffer.push_and_pop(fake_A)
                  loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
                  # Total loss
                  loss_D_A = (loss_real + loss_fake) / 2

                  loss_D_A.backward()
                  optimizer_D_A.step()

                  # -----------------------
                  #  Train Discriminator B
                  # -----------------------

                  optimizer_D_B.zero_grad()

                  # Real loss
                  loss_real = criterion_GAN(D_B(real_B), valid)
                  # Fake loss (on batch of previously generated samples)
                  # fake_B_ = fake_B_buffer.push_and_pop(fake_B)
                  loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
                  # Total loss
                  loss_D_B = (loss_real + loss_fake) / 2

                  loss_D_B.backward()
                  optimizer_D_B.step()

                  loss_D = (loss_D_A + loss_D_B) / 2

                  # --------------
                  #  Log Progress
                  # --------------

                  # Determine approximate time left
                  batches_done = epoch * len(dataloader) + i
                  batches_left = opt.n_epochs * len(dataloader) - batches_done
                  time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
                  prev_time = time.time()

              # Update learning rates
              lr_scheduler_G.step()
              lr_scheduler_D_A.step()
              lr_scheduler_D_B.step()

          2.2.4 結(jié)果展示

          本文訓(xùn)練的是莫奈風(fēng)格的轉(zhuǎn)變,如下圖:第一二行是莫奈風(fēng)格畫轉(zhuǎn)換為普通照片,第三四行為普通照片轉(zhuǎn)換為莫奈風(fēng)格畫

          再來(lái)看實(shí)際手機(jī)拍攝圖片:

          2.2.5 cycleGAN其他用途

          3. 總結(jié)

          本文詳細(xì)介紹了GAN的其中一種應(yīng)用cycleGAN,并將它應(yīng)用到圖像風(fēng)格的轉(zhuǎn)換??偨Y(jié)如下:
          • GAN是學(xué)習(xí)數(shù)據(jù)中分布,并生成同樣分布但全新的數(shù)據(jù)
          • CycleGAN是兩路GAN:兩個(gè)生成器和兩個(gè)判別器;為了保證生成器的生成的圖片與輸入圖存在一定的關(guān)系,不是隨機(jī)生產(chǎn)的圖片, 引入cycle consistency,判定A->fake_B->recove_A和A的差異
          • 生成器:下采樣 + residual 殘差block + 上采樣
          • 判別器: 不是一個(gè)圖生成一個(gè)判定值,而是patchGAN方式,生成很N*N個(gè)值,而后取均值


          作者簡(jiǎn)介:wedo實(shí)驗(yàn)君, 數(shù)據(jù)分析師;熱愛生活,熱愛寫作


          贊 賞 作 者


          更多閱讀



          谷歌 AI 團(tuán)隊(duì)用 GAN 模型合成異形生物體


          英偉達(dá)研究出用較少數(shù)據(jù)集訓(xùn)練GAN的方法


          Python 中圖像標(biāo)題生成的注意力機(jī)制實(shí)戰(zhàn)

          特別推薦




          點(diǎn)擊下方閱讀原文加入社區(qū)會(huì)員

          瀏覽 101
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  999无码| 91操屄视频 | 青娱乐国产精品视频 | 韩国精品三级 | 色婷婷在线视频网站 |