<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          知識(shí)蒸餾綜述:代碼整理

          共 42100字,需瀏覽 85分鐘

           ·

          2022-11-02 00:57

          點(diǎn)擊上方小白學(xué)視覺(jué)”,選擇加"星標(biāo)"或“置頂

          重磅干貨,第一時(shí)間送達(dá)


          編者薦語(yǔ)

           

          收集自RepDistiller中的蒸餾方法,盡可能簡(jiǎn)單解釋蒸餾用到的策略,并提供了實(shí)現(xiàn)源碼。

          1. KD: Knowledge Distillation


          全稱:Distilling the Knowledge in a Neural Network

          鏈接:https://arxiv.org/pdf/1503.02531.pd3f

          發(fā)表:NIPS14

          最經(jīng)典的,也是明確提出知識(shí)蒸餾概念的工作,通過(guò)使用帶溫度的softmax函數(shù)來(lái)軟化教師網(wǎng)絡(luò)的邏輯層輸出作為學(xué)生網(wǎng)絡(luò)的監(jiān)督信息,

          使用KL divergence來(lái)衡量學(xué)生網(wǎng)絡(luò)與教師網(wǎng)絡(luò)的差異,具體流程如下圖所示(來(lái)自Knowledge Distillation A Survey)

          對(duì)學(xué)生網(wǎng)絡(luò)來(lái)說(shuō),一部分監(jiān)督信息來(lái)自hard label標(biāo)簽,另一部分來(lái)自教師網(wǎng)絡(luò)提供的soft label。

          代碼實(shí)現(xiàn):

          class DistillKL(nn.Module):
              """Distilling the Knowledge in a Neural Network"""
              def __init__(self, T):
                  super(DistillKL, self).__init__()
                  self.T = T

              def forward(self, y_s, y_t):
                  p_s = F.log_softmax(y_s/self.T, dim=1)
                  p_t = F.softmax(y_t/self.T, dim=1)
                  loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
                  return loss

          核心就是一個(gè)kl_div函數(shù),用于計(jì)算學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)的分布差異。

          2. FitNet: Hints for thin deep nets

          全稱:Fitnets: hints for thin deep nets

          鏈接:https://arxiv.org/pdf/1412.6550.pdf

          發(fā)表:ICLR 15 Poster

          對(duì)中間層進(jìn)行蒸餾的開山之作,通過(guò)將學(xué)生網(wǎng)絡(luò)的feature map擴(kuò)展到與教師網(wǎng)絡(luò)的feature map相同尺寸以后,使用均方誤差MSE Loss來(lái)衡量?jī)烧卟町悺?/p>

          實(shí)現(xiàn)如下:

          class HintLoss(nn.Module):
              """Fitnets: hints for thin deep nets, ICLR 2015"""
              def __init__(self):
                  super(HintLoss, self).__init__()
                  self.crit = nn.MSELoss()

              def forward(self, f_s, f_t):
                  loss = self.crit(f_s, f_t)
                  return loss

          實(shí)現(xiàn)核心就是MSELoss

          3. AT: Attention Transfer

          全稱:Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer

          鏈接:https://arxiv.org/pdf/1612.03928.pdf

          發(fā)表:ICLR16

          為了提升學(xué)生模型性能提出使用注意力作為知識(shí)載體進(jìn)行遷移,文中提到了兩種注意力,一種是activation-based attention transfer,另一種是gradient-based attention transfer。實(shí)驗(yàn)發(fā)現(xiàn)第一種方法既簡(jiǎn)單效果又好。


          實(shí)現(xiàn)如下:

          class Attention(nn.Module):
              """Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks
              via Attention Transfer
              code: https://github.com/szagoruyko/attention-transfer"""

              def __init__(self, p=2):
                  super(Attention, self).__init__()
                  self.p = p

              def forward(self, g_s, g_t):
                  return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

              def at_loss(self, f_s, f_t):
                  s_H, t_H = f_s.shape[2], f_t.shape[2]
                  if s_H > t_H:
                      f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
                  elif s_H < t_H:
                      f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
                  else:
                      pass
                  return (self.at(f_s) - self.at(f_t)).pow(2).mean()

              def at(self, f):
                  return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))

          首先使用avgpool將尺寸調(diào)整一致,然后使用MSE Loss來(lái)衡量?jī)烧卟罹唷?/p>

          4. SP: Similarity-Preserving

          全稱:Similarity-Preserving Knowledge Distillation

          鏈接:https://arxiv.org/pdf/1907.09682.pdf

          發(fā)表:ICCV19

          SP歸屬于基于關(guān)系的知識(shí)蒸餾方法。文章思想是提出相似性保留的知識(shí),使得教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)會(huì)對(duì)相同的樣本產(chǎn)生相似的激活。可以從下圖看出處理流程,教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)對(duì)應(yīng)feature map通過(guò)計(jì)算內(nèi)積,得到bsxbs的相似度矩陣,然后使用均方誤差來(lái)衡量?jī)蓚€(gè)相似度矩陣。

          最終Loss為:

          G代表的就是bsxbs的矩陣。

          實(shí)現(xiàn)如下:

          class Similarity(nn.Module):
              """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
              def __init__(self):
                  super(Similarity, self).__init__()

              def forward(self, g_s, g_t):
                  return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

              def similarity_loss(self, f_s, f_t):
                  bsz = f_s.shape[0]
                  f_s = f_s.view(bsz, -1)
                  f_t = f_t.view(bsz, -1)

                  G_s = torch.mm(f_s, torch.t(f_s))
                  # G_s = G_s / G_s.norm(2)
                  G_s = torch.nn.functional.normalize(G_s)
                  G_t = torch.mm(f_t, torch.t(f_t))
                  # G_t = G_t / G_t.norm(2)
                  G_t = torch.nn.functional.normalize(G_t)

                  G_diff = G_t - G_s
                  loss = (G_diff * G_diff).view(-11).sum(0) / (bsz * bsz)
                  return loss

          5. CC: Correlation Congruence

          全稱:Correlation Congruence for Knowledge Distillation

          鏈接:https://arxiv.org/pdf/1904.01802.pdf

          發(fā)表:ICCV19

          CC也歸屬于基于關(guān)系的知識(shí)蒸餾方法。不應(yīng)該僅僅引導(dǎo)教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)單個(gè)樣本向量之間的差異,還應(yīng)該學(xué)習(xí)兩個(gè)樣本之間的相關(guān)性,而這個(gè)相關(guān)性使用的是Correlation Congruence 教師網(wǎng)絡(luò)雨學(xué)生網(wǎng)絡(luò)相關(guān)性之間的歐氏距離。

          整體Loss如下:

          實(shí)現(xiàn)如下:

          class Correlation(nn.Module):
              """Similarity-preserving loss. My origianl own reimplementation 
              based on the paper before emailing the original authors."""

              def __init__(self):
                  super(Correlation, self).__init__()

              def forward(self, f_s, f_t):
                  return self.similarity_loss(f_s, f_t)

              def similarity_loss(self, f_s, f_t):
                  bsz = f_s.shape[0]
                  f_s = f_s.view(bsz, -1)
                  f_t = f_t.view(bsz, -1)

                  G_s = torch.mm(f_s, torch.t(f_s))
                  G_s = G_s / G_s.norm(2)
                  G_t = torch.mm(f_t, torch.t(f_t))
                  G_t = G_t / G_t.norm(2)

                  G_diff = G_t - G_s
                  loss = (G_diff * G_diff).view(-11).sum(0) / (bsz * bsz)
                  return loss

          6. VID: Variational Information Distillation

          全稱:Variational Information Distillation for Knowledge Transfer

          鏈接:https://arxiv.org/pdf/1904.05835.pdf

          發(fā)表:CVPR19

          利用互信息(Mutual Information)來(lái)衡量學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)差異?;バ畔⒖梢员硎境鰞蓚€(gè)變量的互相依賴程度,其值越大,表示變量之間的依賴程度越高。互信息計(jì)算如下:

          互信息是教師模型的熵減去在已知學(xué)生模型條件下教師模型的熵。目標(biāo)是最大化互信息,因?yàn)榛バ畔⒃酱笳f(shuō)明H(t|s)越小,即學(xué)生網(wǎng)絡(luò)確定的情況下,教師網(wǎng)絡(luò)的熵會(huì)變小,證明學(xué)生網(wǎng)絡(luò)已經(jīng)學(xué)習(xí)的比較充分。

          整體loss如下:

          由于p(t|s)很難計(jì)算,可以使用變分分布q(t|s)去接近真實(shí)分布。

          其中q(t|s)是使用方差可學(xué)習(xí)的高斯分布模擬(公式中的log_scale):

          實(shí)現(xiàn)如下:

          class VIDLoss(nn.Module):
              """Variational Information Distillation for Knowledge Transfer (CVPR 2019),
              code from author: https://github.com/ssahn0215/variational-information-distillation"""

              def __init__(self,
                           num_input_channels,
                           num_mid_channel,
                           num_target_channels,
                           init_pred_var=5.0,
                           eps=1e-5)
          :

                  super(VIDLoss, self).__init__()

                  def conv1x1(in_channels, out_channels, stride=1):
                      return nn.Conv2d(
                          in_channels, out_channels,
                          kernel_size=1, padding=0,
                          bias=False, stride=stride)

                  self.regressor = nn.Sequential(
                      conv1x1(num_input_channels, num_mid_channel),
                      nn.ReLU(),
                      conv1x1(num_mid_channel, num_mid_channel),
                      nn.ReLU(),
                      conv1x1(num_mid_channel, num_target_channels),
                  )
                  self.log_scale = torch.nn.Parameter(
                      np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels)
                      )
                  self.eps = eps

              def forward(self, input, target):
                  # pool for dimentsion match
                  s_H, t_H = input.shape[2], target.shape[2]
                  if s_H > t_H:
                      input = F.adaptive_avg_pool2d(input, (t_H, t_H))
                  elif s_H < t_H:
                      target = F.adaptive_avg_pool2d(target, (s_H, s_H))
                  else:
                      pass
                  pred_mean = self.regressor(input)
                  pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps
                  pred_var = pred_var.view(1-111)
                  neg_log_prob = 0.5*(
                      (pred_mean-target)**2/pred_var+torch.log(pred_var)
                      )
                  loss = torch.mean(neg_log_prob)
                  return loss

          7. RKD: Relation Knowledge Distillation

          全稱:Relational Knowledge Disitllation

          鏈接:http://arxiv.org/pdf/1904.05068

          發(fā)表:CVPR19

          RKD也是基于關(guān)系的知識(shí)蒸餾方法,RKD提出了兩種損失函數(shù),二階的距離損失和三階的角度損失。

          • Distance-wise Loss
          • Angle-wise Loss

          實(shí)現(xiàn)如下:

          class RKDLoss(nn.Module):
              """Relational Knowledge Disitllation, CVPR2019"""
              def __init__(self, w_d=25, w_a=50):
                  super(RKDLoss, self).__init__()
                  self.w_d = w_d
                  self.w_a = w_a

              def forward(self, f_s, f_t):
                  student = f_s.view(f_s.shape[0], -1)
                  teacher = f_t.view(f_t.shape[0], -1)

                  # RKD distance loss
                  with torch.no_grad():
                      t_d = self.pdist(teacher, squared=False)
                      mean_td = t_d[t_d > 0].mean()
                      t_d = t_d / mean_td

                  d = self.pdist(student, squared=False)
                  mean_d = d[d > 0].mean()
                  d = d / mean_d

                  loss_d = F.smooth_l1_loss(d, t_d)

                  # RKD Angle loss
                  with torch.no_grad():
                      td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
                      norm_td = F.normalize(td, p=2, dim=2)
                      t_angle = torch.bmm(norm_td, norm_td.transpose(12)).view(-1)

                  sd = (student.unsqueeze(0) - student.unsqueeze(1))
                  norm_sd = F.normalize(sd, p=2, dim=2)
                  s_angle = torch.bmm(norm_sd, norm_sd.transpose(12)).view(-1)

                  loss_a = F.smooth_l1_loss(s_angle, t_angle)

                  loss = self.w_d * loss_d + self.w_a * loss_a

                  return loss

              @staticmethod
              def pdist(e, squared=False, eps=1e-12):
                  e_square = e.pow(2).sum(dim=1)
                  prod = e @ e.t()
                  res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)

                  if not squared:
                      res = res.sqrt()

                  res = res.clone()
                  res[range(len(e)), range(len(e))] = 0
                  return res

          8. PKT:Probabilistic Knowledge Transfer

          全稱:Probabilistic Knowledge Transfer for deep representation learning

          鏈接:https://arxiv.org/abs/1803.10837

          發(fā)表:CoRR18

          提出一種概率知識(shí)轉(zhuǎn)移方法,引入了互信息來(lái)進(jìn)行建模。該方法具有可跨模態(tài)知識(shí)轉(zhuǎn)移、無(wú)需考慮任務(wù)類型、可將手工特征融入網(wǎng)絡(luò)等有點(diǎn)。

          實(shí)現(xiàn)如下:

          class PKT(nn.Module):
              """Probabilistic Knowledge Transfer for deep representation learning
              Code from author: https://github.com/passalis/probabilistic_kt"""

              def __init__(self):
                  super(PKT, self).__init__()

              def forward(self, f_s, f_t):
                  return self.cosine_similarity_loss(f_s, f_t)

              @staticmethod
              def cosine_similarity_loss(output_net, target_net, eps=0.0000001):
                  # Normalize each vector by its norm
                  output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))
                  output_net = output_net / (output_net_norm + eps)
                  output_net[output_net != output_net] = 0

                  target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))
                  target_net = target_net / (target_net_norm + eps)
                  target_net[target_net != target_net] = 0

                  # Calculate the cosine similarity
                  model_similarity = torch.mm(output_net, output_net.transpose(01))
                  target_similarity = torch.mm(target_net, target_net.transpose(01))

                  # Scale cosine similarity to 0..1
                  model_similarity = (model_similarity + 1.0) / 2.0
                  target_similarity = (target_similarity + 1.0) / 2.0

                  # Transform them into probabilities
                  model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)
                  target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)

                  # Calculate the KL-divergence
                  loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))

                  return loss

          9. AB: Activation Boundaries

          全稱:Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons

          鏈接:https://arxiv.org/pdf/1811.03233.pdf

          發(fā)表:AAAI18

          目標(biāo):讓教師網(wǎng)絡(luò)層的神經(jīng)元的激活邊界盡量和學(xué)生網(wǎng)絡(luò)的一樣。所謂的激活邊界指的是分離超平面(針對(duì)的是RELU這種激活函數(shù)),其決定了神經(jīng)元的激活與失活。AB提出的激活轉(zhuǎn)移損失,讓教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)之間的分離邊界盡可能一致。

          實(shí)現(xiàn)如下:

          class ABLoss(nn.Module):
              """Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
              code: https://github.com/bhheo/AB_distillation
              """

              def __init__(self, feat_num, margin=1.0):
                  super(ABLoss, self).__init__()
                  self.w = [2**(i-feat_num+1for i in range(feat_num)]
                  self.margin = margin

              def forward(self, g_s, g_t):
                  bsz = g_s[0].shape[0]
                  losses = [self.criterion_alternative_l2(s, t) for s, t in zip(g_s, g_t)]
                  losses = [w * l for w, l in zip(self.w, losses)] 
                  # loss = sum(losses) / bsz
                  # loss = loss / 1000 * 3
                  losses = [l / bsz for l in losses]
                  losses = [l / 1000 * 3 for l in losses]
                  return losses

              def criterion_alternative_l2(self, source, target):
                  loss = ((source + self.margin) ** 2 * ((source > -self.margin) & (target <= 0)).float() +
                          (source - self.margin) ** 2 * ((source <= self.margin) & (target > 0)).float())
                  return torch.abs(loss).sum()

          10. FT: Factor Transfer

          全稱:Paraphrasing Complex Network: Network Compression via Factor Transfer

          鏈接:https://arxiv.org/pdf/1802.04977.pdf

          發(fā)表:NIPS18

          提出的是factor transfer的方法。所謂的factor,其實(shí)是對(duì)模型最后的數(shù)據(jù)結(jié)果進(jìn)行一個(gè)編解碼的過(guò)程,提取出的一個(gè)factor矩陣,用教師網(wǎng)絡(luò)的factor來(lái)指導(dǎo)學(xué)生網(wǎng)絡(luò)的factor。

          FT計(jì)算公式為:

          實(shí)現(xiàn)如下:

          class FactorTransfer(nn.Module):
              """Paraphrasing Complex Network: Network Compression via Factor Transfer, NeurIPS 2018"""
              def __init__(self, p1=2, p2=1):
                  super(FactorTransfer, self).__init__()
                  self.p1 = p1
                  self.p2 = p2

              def forward(self, f_s, f_t):
                  return self.factor_loss(f_s, f_t)

              def factor_loss(self, f_s, f_t):
                  s_H, t_H = f_s.shape[2], f_t.shape[2]
                  if s_H > t_H:
                      f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
                  elif s_H < t_H:
                      f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
                  else:
                      pass
                  if self.p2 == 1:
                      return (self.factor(f_s) - self.factor(f_t)).abs().mean()
                  else:
                      return (self.factor(f_s) - self.factor(f_t)).pow(self.p2).mean()

              def factor(self, f):
                  return F.normalize(f.pow(self.p1).mean(1).view(f.size(0), -1))

          11. FSP: Flow of Solution Procedure

          全稱:A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning

          鏈接:https://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf

          發(fā)表:CVPR17

          FSP認(rèn)為教學(xué)生網(wǎng)絡(luò)不同層輸出的feature之間的關(guān)系比教學(xué)生網(wǎng)絡(luò)結(jié)果好

          定義了FSP矩陣來(lái)定義網(wǎng)絡(luò)內(nèi)部特征層之間的關(guān)系,是一個(gè)Gram矩陣反映老師教學(xué)生的過(guò)程。

          使用的是L2 Loss進(jìn)行約束FSP矩陣。

          實(shí)現(xiàn)如下:

          class FSP(nn.Module):
              """A Gift from Knowledge Distillation:
              Fast Optimization, Network Minimization and Transfer Learning"""

              def __init__(self, s_shapes, t_shapes):
                  super(FSP, self).__init__()
                  assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
                  s_c = [s[1for s in s_shapes]
                  t_c = [t[1for t in t_shapes]
                  if np.any(np.asarray(s_c) != np.asarray(t_c)):
                      raise ValueError('num of channels not equal (error in FSP)')

              def forward(self, g_s, g_t):
                  s_fsp = self.compute_fsp(g_s)
                  t_fsp = self.compute_fsp(g_t)
                  loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)]
                  return loss_group

              @staticmethod
              def compute_loss(s, t):
                  return (s - t).pow(2).mean()

              @staticmethod
              def compute_fsp(g):
                  fsp_list = []
                  for i in range(len(g) - 1):
                      bot, top = g[i], g[i + 1]
                      b_H, t_H = bot.shape[2], top.shape[2]
                      if b_H > t_H:
                          bot = F.adaptive_avg_pool2d(bot, (t_H, t_H))
                      elif b_H < t_H:
                          top = F.adaptive_avg_pool2d(top, (b_H, b_H))
                      else:
                          pass
                      bot = bot.unsqueeze(1)
                      top = top.unsqueeze(2)
                      bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1)
                      top = top.view(top.shape[0], top.shape[1], top.shape[2], -1)

                      fsp = (bot * top).mean(-1)
                      fsp_list.append(fsp)
                  return fsp_list

          12. NST: Neuron Selectivity Transfer

          全稱:Like what you like: knowledge distill via neuron selectivity transfer

          鏈接:https://arxiv.org/pdf/1707.01219.pdf

          發(fā)表:CoRR17

          使用新的損失函數(shù)最小化教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)之間的Maximum Mean Discrepancy(MMD), 文中選擇的是對(duì)其教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)之間神經(jīng)元選擇樣式的分布。

          使用核技巧(對(duì)應(yīng)下面poly kernel)并進(jìn)一步展開以后可得:

          實(shí)際上提供了Linear Kernel、Poly Kernel、Gaussian Kernel三種,這里實(shí)現(xiàn)只給了Poly這種,這是因?yàn)镻oly這種方法可以與KD進(jìn)行互補(bǔ),這樣整體效果會(huì)非常好。

          實(shí)現(xiàn)如下:

          class NSTLoss(nn.Module):
              """like what you like: knowledge distill via neuron selectivity transfer"""
              def __init__(self):
                  super(NSTLoss, self).__init__()
                  pass

              def forward(self, g_s, g_t):
                  return [self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

              def nst_loss(self, f_s, f_t):
                  s_H, t_H = f_s.shape[2], f_t.shape[2]
                  if s_H > t_H:
                      f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
                  elif s_H < t_H:
                      f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
                  else:
                      pass

                  f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)
                  f_s = F.normalize(f_s, dim=2)
                  f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
                  f_t = F.normalize(f_t, dim=2)

                  # set full_loss as False to avoid unnecessary computation
                  full_loss = True
                  if full_loss:
                      return (self.poly_kernel(f_t, f_t).mean().detach() + self.poly_kernel(f_s, f_s).mean()
                              - 2 * self.poly_kernel(f_s, f_t).mean())
                  else:
                      return self.poly_kernel(f_s, f_s).mean() - 2 * self.poly_kernel(f_s, f_t).mean()

              def poly_kernel(self, a, b):
                  a = a.unsqueeze(1)
                  b = b.unsqueeze(2)
                  res = (a * b).sum(-1).pow(2)
                  return res

          13. CRD: Contrastive Representation Distillation

          全稱:Contrastive Representation Distillation

          鏈接:https://arxiv.org/abs/1910.10699v2

          發(fā)表:ICLR20

          將對(duì)比學(xué)習(xí)引入知識(shí)蒸餾中,其目標(biāo)修正為:學(xué)習(xí)一個(gè)表征,讓正樣本對(duì)的教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)盡可能接近,負(fù)樣本對(duì)教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)盡可能遠(yuǎn)離。

          構(gòu)建的對(duì)比學(xué)習(xí)問(wèn)題表示如下:

          整體的蒸餾Loss表示如下:

          實(shí)現(xiàn)如下:https://github.com/HobbitLong/RepDistiller

          class ContrastLoss(nn.Module):
              """
              contrastive loss, corresponding to Eq (18)
              """

              def __init__(self, n_data):
                  super(ContrastLoss, self).__init__()
                  self.n_data = n_data

              def forward(self, x):
                  bsz = x.shape[0]
                  m = x.size(1) - 1

                  # noise distribution
                  Pn = 1 / float(self.n_data)

                  # loss for positive pair
                  P_pos = x.select(10)
                  log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()

                  # loss for K negative pair
                  P_neg = x.narrow(11, m)
                  log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()

                  loss = - (log_D1.sum(0) + log_D0.view(-11).sum(0)) / bsz

                  return loss
                  
          class CRDLoss(nn.Module):
              """CRD Loss function
              includes two symmetric parts:
              (a) using teacher as anchor, choose positive and negatives over the student side
              (b) using student as anchor, choose positive and negatives over the teacher side

              Args:
                  opt.s_dim: the dimension of student's feature
                  opt.t_dim: the dimension of teacher's feature
                  opt.feat_dim: the dimension of the projection space
                  opt.nce_k: number of negatives paired with each positive
                  opt.nce_t: the temperature
                  opt.nce_m: the momentum for updating the memory buffer
                  opt.n_data: the number of samples in the training set, therefor the memory buffer is: opt.n_data x opt.feat_dim
              """

              def __init__(self, opt):
                  super(CRDLoss, self).__init__()
                  self.embed_s = Embed(opt.s_dim, opt.feat_dim)
                  self.embed_t = Embed(opt.t_dim, opt.feat_dim)
                  self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m)
                  self.criterion_t = ContrastLoss(opt.n_data)
                  self.criterion_s = ContrastLoss(opt.n_data)

              def forward(self, f_s, f_t, idx, contrast_idx=None):
                  """
                  Args:
                      f_s: the feature of student network, size [batch_size, s_dim]
                      f_t: the feature of teacher network, size [batch_size, t_dim]
                      idx: the indices of these positive samples in the dataset, size [batch_size]
                      contrast_idx: the indices of negative samples, size [batch_size, nce_k]

                  Returns:
                      The contrastive loss
                  """

                  f_s = self.embed_s(f_s)
                  f_t = self.embed_t(f_t)
                  out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)
                  s_loss = self.criterion_s(out_s)
                  t_loss = self.criterion_t(out_t)
                  loss = s_loss + t_loss
                  return loss

          14. Overhaul

          全稱:A Comprehensive Overhaul of Feature Distillation

          鏈接:http://openaccess.thecvf.com/content_ICCV_2019/papers/

          發(fā)表:CVPR19

          • teacher transform中提出使用margin RELU激活函數(shù)。
          • student transform中提出使用1x1卷積。

          • distillation feature postion選擇Pre-ReLU。

          • distance function部分提出了Partial L2 損失函數(shù)。

          部分實(shí)現(xiàn)如下:

          class OFD(nn.Module):
            '''
            A Comprehensive Overhaul of Feature Distillation
            http://openaccess.thecvf.com/content_ICCV_2019/papers/
            Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
            '''

            def __init__(self, in_channels, out_channels):
              super(OFD, self).__init__()
              self.connector = nn.Sequential(*[
                  nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                  nn.BatchNorm2d(out_channels)
                ])

              for m in self.modules():
                if isinstance(m, nn.Conv2d):
                  nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                  if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                  nn.init.constant_(m.weight, 1)
                  nn.init.constant_(m.bias, 0)

            def forward(self, fm_s, fm_t):
              margin = self.get_margin(fm_t)
              fm_t = torch.max(fm_t, margin)
              fm_s = self.connector(fm_s)

              mask = 1.0 - ((fm_s <= fm_t) & (fm_t <= 0.0)).float()
              loss = torch.mean((fm_s - fm_t)**2 * mask)

              return loss

            def get_margin(self, fm, eps=1e-6):
              mask = (fm < 0.0).float()
              masked_fm = fm * mask

              margin = masked_fm.sum(dim=(0,2,3), keepdim=True) / (mask.sum(dim=(0,2,3), keepdim=True)+eps)

              return margin

          參考文獻(xiàn)

          https://blog.csdn.net/weixin_44579633/article/details/119350631

          https://blog.csdn.net/winycg/article/details/105297089

          https://blog.csdn.net/weixin_46239293/article/details/120289163

          https://blog.csdn.net/DD_PP_JJ/article/details/121578722

          https://blog.csdn.net/DD_PP_JJ/article/details/121714957

          https://zhuanlan.zhihu.com/p/344881975

          https://blog.csdn.net/weixin_44633882/article/details/108927033

          https://blog.csdn.net/weixin_46239293/article/details/120266111

          https://blog.csdn.net/weixin_43402775/article/details/109011296

          https://blog.csdn.net/m0_37665984/article/details/103288582

          https://blog.csdn.net/m0_37665984/article/details/103269740

          好消息!

          小白學(xué)視覺(jué)知識(shí)星球

          開始面向外開放啦??????



          下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程
          在「小白學(xué)視覺(jué)」公眾號(hào)后臺(tái)回復(fù):擴(kuò)展模塊中文教程即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺(jué)、目標(biāo)跟蹤、生物視覺(jué)、超分辨率處理等二十多章內(nèi)容。

          下載2:Python視覺(jué)實(shí)戰(zhàn)項(xiàng)目52講
          小白學(xué)視覺(jué)公眾號(hào)后臺(tái)回復(fù):Python視覺(jué)實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測(cè)、車道線檢測(cè)、車輛計(jì)數(shù)、添加眼線、車牌識(shí)別、字符識(shí)別、情緒檢測(cè)、文本內(nèi)容提取、面部識(shí)別等31個(gè)視覺(jué)實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺(jué)。

          下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講
          小白學(xué)視覺(jué)公眾號(hào)后臺(tái)回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。

          交流群


          歡迎加入公眾號(hào)讀者群一起和同行交流,目前有SLAM、三維視覺(jué)、傳感器自動(dòng)駕駛、計(jì)算攝影、檢測(cè)、分割、識(shí)別、醫(yī)學(xué)影像、GAN、算法競(jìng)賽等微信群(以后會(huì)逐漸細(xì)分),請(qǐng)掃描下面微信號(hào)加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺(jué)SLAM“。請(qǐng)按照格式備注,否則不予通過(guò)。添加成功后會(huì)根據(jù)研究方向邀請(qǐng)進(jìn)入相關(guān)微信群。請(qǐng)勿在群內(nèi)發(fā)送廣告,否則會(huì)請(qǐng)出群,謝謝理解~


          瀏覽 96
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  亚洲伦精品 | 四虎成人视频 | 欧美丰满老熟妇XXXXX性 精品人妻一区二区三区蜜桃 | 日日撸夜夜操 | 欧美午夜性爱视频 |