<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          人物屬性模型移動(dòng)端實(shí)驗(yàn)記錄

          共 17514字,需瀏覽 36分鐘

           ·

          2021-03-18 22:47

          【GiantPandaCV導(dǎo)語(yǔ)】最近項(xiàng)目有需求,需要把人物屬性用在移動(dòng)端上,需要輸出性別,顏值和年齡三個(gè)維度的標(biāo)簽, 用來(lái)做數(shù)據(jù)分析收集使用,對(duì)速度和精度有一定的需求,做了一些實(shí)驗(yàn),記錄如下。

          一、模型

          模型結(jié)構(gòu),這里考慮了兩種形式,一種是多頭的,一種是單頭的,具體如下:

          • SingleHead
            1. backbone+avgpool后面 接一個(gè)卷積,卷積核為(inp, (gender_class+beauty_class+age_class), 3, 3)
            2. backbone+avgpool后面 接入一個(gè)channel shuff層, 再接入一個(gè)卷積,和第一種一樣。
          • MutilHead
            1. backbone+avgpool后面,接入三個(gè)FC,每個(gè)FC對(duì)應(yīng)一個(gè)維度的任務(wù)。
            2. backbone+avgpool后面,先接入一個(gè)SE模塊后,接三個(gè)FC,每個(gè)FC對(duì)應(yīng)一個(gè)維度的任務(wù)。
            3. backbone+avgpool后面,接入一個(gè)512維度的FC,后接入三個(gè)FC,每個(gè)FC對(duì)應(yīng)一個(gè)維度的任務(wù)。
            4. backbone+avgpool后面,接入三個(gè)512維度的FC來(lái)做embeeding,后接入三個(gè)FC,每個(gè)FC對(duì)應(yīng)一個(gè)維度的任務(wù)。

          如下圖所示:

             
             

          圖1-不同模型結(jié)構(gòu)

          訓(xùn)練, 訓(xùn)練數(shù)據(jù)總計(jì)35w,每張圖片都帶有三個(gè)維度的標(biāo)簽,使用Horovod分布式框架進(jìn)行訓(xùn)練,采用SGD優(yōu)化器,warmup5個(gè)epoch,使用cosine進(jìn)行衰減學(xué)習(xí)率,總計(jì)訓(xùn)練60個(gè)epoch,訓(xùn)練代碼可以參考https://github.com/FlyEgle/cub_baseline。

          實(shí)驗(yàn)對(duì)比,對(duì)于SingleHead模型,MutilHead的1,2模型,采用的是mobilenetv2作為backbone,對(duì)于MutilHead的3,4模型,采用的是mobilenetv2x0.5作為backbone。這里對(duì)比的baseline為resnest50的結(jié)果,結(jié)果如下:

             
             

          圖2-結(jié)果對(duì)比

          結(jié)論,出于性能和速度的考慮,確定了以mbv2x0.5作為backbone,模型結(jié)構(gòu)為mutilhead-4的模型。

          模型SIZEFLOPsPARAMsgender_accbeauty_accage_acc
          baseline(rs50)2565.7G31M0.9709821430.8973214290.790178571
          mbv2x0.5(mutil_head)256127M2.66M0.9040178570.8348214290.725446429

          二、蒸餾

          mobilenetv2與resnest50在imagenet上的baseline大概相差8個(gè)點(diǎn)左右,所以我們自身的實(shí)驗(yàn)跑出來(lái)的結(jié)果也是在合理的范圍內(nèi)。為了進(jìn)一步提升小模型的精度,選擇用resnest50的模型來(lái)蒸餾mbv2x0.5的模型(ps:這里嘗試過(guò)訓(xùn)練一個(gè)mbv2x2的模型,不過(guò)沒(méi)有訓(xùn)的比resnest50高,所以還是使用resnest50)。蒸餾,采用的是傳統(tǒng)的蒸餾方法,KL散度來(lái)作為損失,由于head相同,所以只需要考慮對(duì)logits蒸餾即可,KL散度代碼如下:

          class KLSoftLoss(nn.Module):
              r"""Apply softtarget for kl loss

              Arguments:
                  reduction (str): "
          batchmean" for the mean loss with the p(x)*(log(p(x)) - log(q(x)))
              "
          ""
              def __init__(self, temperature=1, reduction="batchmean"):
                  super(KLSoftLoss, self).__init__()
                  self.reduction = reduction
                  self.eps = 1e-7
                  self.temperature = temperature
                  self.klloss = nn.KLDivLoss(reduction=self.reduction)

              def forward(self, s_logits, t_logits):
                  s_prob = F.log_softmax(s_logits / self.temperature, 1)
                  t_prob = F.softmax(t_logits / self.temperature, 1)
                  loss = self.klloss(s_prob, t_prob) * self.temperature * self.temperature
                  return loss

          訓(xùn)練, 對(duì)于分類的問(wèn)題,一般情況只是蒸餾輸出的logits即可,由于多任務(wù)有多個(gè)head,所以會(huì)有多個(gè)logits,分別蒸餾即可,整體框架如下:

             
             

          圖4-蒸餾訓(xùn)練框架

          蒸餾訓(xùn)練代碼如下,由于學(xué)生和教師的網(wǎng)絡(luò)差異性較大同時(shí)精度相差甚遠(yuǎn),所以采用1:1的比例來(lái)進(jìn)行訓(xùn)練,蒸餾的溫度為25(T=5):

             
             

          圖5-蒸餾訓(xùn)練代碼

          結(jié)論,采用了3中不同的分辨率進(jìn)行蒸餾實(shí)驗(yàn),其中訓(xùn)練的size為224,推理為256的時(shí)候效果最好。

          模型sizeteachergender_accbeauty_accage_acc
          mbv2x0.5224->256resnest500.9665178570.895089290.75446429
          mbv2x0.5192->224resnest500.9508928570.897321430.765625
          mbv2x0.5160->224resnest500.9598214290.8906250.734375

          三、剪枝

          Slimming Prune,實(shí)驗(yàn)采用的剪枝方法是來(lái)自于Learning Efficient Convolutional Networks through Network Slimming,通過(guò)對(duì)BN的channel進(jìn)行稀疏化來(lái)達(dá)到剪枝的效果(個(gè)人喜歡用比較簡(jiǎn)單穩(wěn)定的方法,便于debug和修改)。

             
             

          圖5-Slimming

          訓(xùn)練和剪枝

          • 訓(xùn)練,訓(xùn)練代碼很簡(jiǎn)單,只需要再更新權(quán)重之前進(jìn)行稀疏化處理即可,sr是超參,一般設(shè)置為1e-4,代碼如下:

              optimizer.zero_grad()
              loss.backward()

              # use the slimming prune for training
              if args.prune and args.use_sr:
                  for m in model.modules():
                      if isinstance(m, nn.BatchNorm2d):
                          m.weight.grad.data.add_(args.sr * torch.sign(m.weight.data))

              optimizer.step()
          • 剪枝, 由于模型結(jié)構(gòu)是mobilenetv2的結(jié)構(gòu),有DW存在,所以,在剪枝的時(shí)候需要注意groups的數(shù)量和channel需要保持一致,同時(shí),為了方便移動(dòng)端優(yōu)化加速,要保證channel是8的倍數(shù),剪枝代碼邏輯如下:

            1. 先設(shè)置一定的剪枝比例p,如0.1,0.2,0.3...,按BN的channel總數(shù)從小到大來(lái)進(jìn)行過(guò)濾。
            2. 保留最大比例的最小閾值,防止prune過(guò)大,導(dǎo)致模型崩潰。
            3. 對(duì)于不滿足8的倍數(shù)的channel,按8的倍數(shù)補(bǔ)齊,補(bǔ)齊的方法是對(duì)prune過(guò)的channel排序,從大到小按差值補(bǔ)齊。
            4. 保存除了第一個(gè)InvertedResidual模塊以外的所有模塊剪枝后的channel數(shù)量,重構(gòu)模型。
            5. 測(cè)試結(jié)果,考慮是否進(jìn)行finetune訓(xùn)練。

          剪枝部分代碼如下:

          def prune_only_res_hidden(percent, model, keep_channel=True, channel_ratio=8, cuda=True):
              """only prune the inverResidual module first bn layer
              "
          ""
              total = 0
              highest_thre = []
              for m in model.modules():
                  if isinstance(m, InvertedResidual):
                      # only prune the 3 conv layer
                      if len(m.conv) > 5:
                          for i in range(len(m.conv)):
                              if i == 1:
                                  if isinstance(m.conv[i], nn.BatchNorm2d):
                                      total += m.conv[i].weight.data.shape[0]
                                      highest_thre.append(m.conv[i].weight.data.abs().max().item())
                                      total += m.conv[i+3].weight.data.shape[0]
                                      highest_thre.append(m.conv[i+3].weight.data.abs().max().item())



              bn = torch.zeros(total)
              index = 0
              for m in model.modules():
                  if isinstance(m, InvertedResidual):
                      # only prune the 3 conv layer
                      if len(m.conv) > 5:
                          for i in range(len(m.conv)):
                              if i != len(m.conv) - 1:
                                  if isinstance(m.conv[i], nn.BatchNorm2d):
                                      size = m.conv[i].weight.data.shape[0]
                                      bn[index:(index+size)] = m.conv[i].weight.data.abs().clone()
                                      index += size

              print(bn.size())
              y, i = torch.sort(bn)
              thre_index = int(total * percent)
              thre = y[thre_index]
              highest_thre = min(highest_thre)

              # 判斷閾值
              if thre > highest_thre:
                  thre = highest_thre

              print("the min thre is {}, the max thre is {}!!!!".format(thre, highest_thre))
              pruned = 0
              c = {}
              cfg_mask = []
              idx = 0

              for m in model.modules():
                  if isinstance(m, InvertedResidual):
                      # only prune the 3 conv layer
                      if len(m.conv) > 5:
                          for i in range(len(m.conv)):
                              if i == 1:
                                  if isinstance(m.conv[i], nn.BatchNorm2d):
                                      weight_copy = m.conv[i].weight.data.clone()
                                      if cuda:
                                          mask = weight_copy.abs().gt(thre).float().cuda()
                                      else:
                                          mask = weight_copy.abs().gt(thre).float()

                                      if keep_channel:
                                          keep_channel_number = get_min_number(torch.sum(mask), channel_ratio)
                                          if torch.sum(mask) < keep_channel_number:
                                              n = int(keep_channel_number - torch.sum(mask))
                                              mask_index = torch.where(mask==0)[0]
                                              new_weight = weight_copy.abs()[mask_index]
                                              _, weight_index = torch.sort(new_weight)
                                              w_index = mask_index[weight_index[-n: ]]
                                              mask[w_index] = 1.0

                                      pruned = pruned + mask.shape[0] - torch.sum(mask)
                                      # first conv + bn
                                      m.conv[i].weight.data.mul_(mask)
                                      m.conv[i].bias.data.mul_(mask)
                                      # second conv + bn
                                      m.conv[i+3].weight.data.mul_(mask)
                                      m.conv[i+3].bias.data.mul_(mask)
                                      c[idx] = int(torch.sum(mask))
                                      cfg_mask.append(mask.clone())

                                      print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(idx, mask.shape[0], int(torch.sum(mask))))
                                      idx += 1
              print(c)
              print(len(c))
              print(len(cfg_mask))
              # pruned_ratio = pruned / total
              print('Pre-processing Successful!!!')
              return model, cfg_mask, c

          直接保存模型后測(cè)試,對(duì)比結(jié)果如下:

          模型ratioFLOPsParamsgenderbeautyage
          mobilenetv2x0.50.24111.95M2.54M0.9575892860.8928571430.741071429
          mobilenetv2x0.50.3107.51M2.51M0.9598214290.8928571430.741071429
          mobilenetv2x0.50.479.57M2.46M0.6093750.5334821430.098214286
          mobilenetv2x0.50.579.56M2.46M0.6093750.5334821430.098214286

          再次使用resnest50進(jìn)行蒸餾后,對(duì)比結(jié)果如下:

          模型ratioFLOPsParamsgenderbeautyage
          mobilenetv2x0.50.24111.95M2.54M0.968750.9017857140.75
          mobilenetv2x0.50.3107.51M2.51M0.9575892860.8839285710.738839286
          mobilenetv2x0.50.479.57M2.46M0.9575892860.8816964290.741071429

          添加2w標(biāo)注的業(yè)務(wù)數(shù)據(jù),總計(jì)訓(xùn)練數(shù)據(jù)37w,蒸餾后的結(jié)果如下:

          模型ratioFLOPsParamsgenderbeautyage
          mobilenetv2x0.50.24111.95M2.54M0.9754464290.9017857140.752232143
          mobilenetv2x0.50.3107.51M2.51M0.968750.8906250.761160714
          mobilenetv2x0.50.479.57M2.46M0.9665178570.8928571430.723214286

          針對(duì)性能的需求,考慮用0.3的版本,如果速度要求更快的話,考慮0.4的版本。

          四、TODO

          1. 訓(xùn)練一個(gè)基于BYOL的pretrain模型。
          2. 把沒(méi)有標(biāo)注的數(shù)據(jù),用模型打上偽標(biāo)簽后參與訓(xùn)練。
          3. 訓(xùn)練一個(gè)更大的teacher模型。
          4. 使用百度的JSDivLoss作為蒸餾損失。

          五、結(jié)論

          • 對(duì)于移動(dòng)端的任務(wù)來(lái)說(shuō),蒸餾和剪枝是必不可少的,尤其是要去訓(xùn)練一個(gè)比較好的teacher,這里的teacher可以同結(jié)構(gòu)也可以異結(jié)構(gòu),只要最后logits一致即可。
          • 由于移動(dòng)端會(huì)根據(jù)X8或者X4的倍數(shù)優(yōu)化,所以剪枝的時(shí)候盡量保持channel的倍數(shù),建議常備一種便于修改的剪枝代碼。
          • 小模型具備成長(zhǎng)為大模型的潛質(zhì),只要訓(xùn)練方法適當(dāng)。

          結(jié)束語(yǔ)

          本人才疏學(xué)淺,以上都是自己在做項(xiàng)目中的一些方法和實(shí)驗(yàn),以及一些粗淺的思考,并不一定完全正確,只是個(gè)人的理解,歡迎大家指正,留言評(píng)論。

          參考文獻(xiàn)

          • mobilenetv2 https://export.arxiv.org/pdf/1801.04381
          • resnest https://export.arxiv.org/pdf/2004.08955
          • Slimming prune https://arxiv.org/pdf/1708.06519.pdf

          歡迎關(guān)注GiantPandaCV, 在這里你將看到獨(dú)家的深度學(xué)習(xí)分享,堅(jiān)持原創(chuàng),每天分享我們學(xué)習(xí)到的新鮮知識(shí)。( ? ?ω?? )?

          有對(duì)文章相關(guān)的問(wèn)題,或者想要加入交流群,歡迎添加BBuf微信:

          二維碼

          為了方便讀者獲取資料以及我們公眾號(hào)的作者發(fā)布一些Github工程的更新,我們成立了一個(gè)QQ群,二維碼如下,感興趣可以加入。

          公眾號(hào)QQ交流群


          瀏覽 92
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  狠狠se | 三级无码在线观看 | 678五月丁香亚洲 | 免费色黄视频 | 91精品国产综合久久蜜芽解析速度 |