知識(shí)蒸餾綜述:代碼整理
點(diǎn)擊上方“小白學(xué)視覺(jué)”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時(shí)間送達(dá)
編者薦語(yǔ)
收集自RepDistiller中的蒸餾方法,盡可能簡(jiǎn)單解釋蒸餾用到的策略,并提供了實(shí)現(xiàn)源碼。
全稱:Distilling the Knowledge in a Neural Network
鏈接:https://arxiv.org/pdf/1503.02531.pd3f
發(fā)表:NIPS14
最經(jīng)典的,也是明確提出知識(shí)蒸餾概念的工作,通過(guò)使用帶溫度的softmax函數(shù)來(lái)軟化教師網(wǎng)絡(luò)的邏輯層輸出作為學(xué)生網(wǎng)絡(luò)的監(jiān)督信息,
使用KL divergence來(lái)衡量學(xué)生網(wǎng)絡(luò)與教師網(wǎng)絡(luò)的差異,具體流程如下圖所示(來(lái)自Knowledge Distillation A Survey)

對(duì)學(xué)生網(wǎng)絡(luò)來(lái)說(shuō),一部分監(jiān)督信息來(lái)自hard label標(biāo)簽,另一部分來(lái)自教師網(wǎng)絡(luò)提供的soft label。
代碼實(shí)現(xiàn):
class DistillKL(nn.Module):
"""Distilling the Knowledge in a Neural Network"""
def __init__(self, T):
super(DistillKL, self).__init__()
self.T = T
def forward(self, y_s, y_t):
p_s = F.log_softmax(y_s/self.T, dim=1)
p_t = F.softmax(y_t/self.T, dim=1)
loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
return loss
核心就是一個(gè)kl_div函數(shù),用于計(jì)算學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)的分布差異。
2. FitNet: Hints for thin deep nets
全稱:Fitnets: hints for thin deep nets
鏈接:https://arxiv.org/pdf/1412.6550.pdf
發(fā)表:ICLR 15 Poster
對(duì)中間層進(jìn)行蒸餾的開山之作,通過(guò)將學(xué)生網(wǎng)絡(luò)的feature map擴(kuò)展到與教師網(wǎng)絡(luò)的feature map相同尺寸以后,使用均方誤差MSE Loss來(lái)衡量?jī)烧卟町悺?/p>
實(shí)現(xiàn)如下:
class HintLoss(nn.Module):
"""Fitnets: hints for thin deep nets, ICLR 2015"""
def __init__(self):
super(HintLoss, self).__init__()
self.crit = nn.MSELoss()
def forward(self, f_s, f_t):
loss = self.crit(f_s, f_t)
return loss
實(shí)現(xiàn)核心就是MSELoss
3. AT: Attention Transfer
全稱:Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer
鏈接:https://arxiv.org/pdf/1612.03928.pdf
發(fā)表:ICLR16
為了提升學(xué)生模型性能提出使用注意力作為知識(shí)載體進(jìn)行遷移,文中提到了兩種注意力,一種是activation-based attention transfer,另一種是gradient-based attention transfer。實(shí)驗(yàn)發(fā)現(xiàn)第一種方法既簡(jiǎn)單效果又好。

實(shí)現(xiàn)如下:
class Attention(nn.Module):
"""Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks
via Attention Transfer
code: https://github.com/szagoruyko/attention-transfer"""
def __init__(self, p=2):
super(Attention, self).__init__()
self.p = p
def forward(self, g_s, g_t):
return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
def at_loss(self, f_s, f_t):
s_H, t_H = f_s.shape[2], f_t.shape[2]
if s_H > t_H:
f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
elif s_H < t_H:
f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
else:
pass
return (self.at(f_s) - self.at(f_t)).pow(2).mean()
def at(self, f):
return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))
首先使用avgpool將尺寸調(diào)整一致,然后使用MSE Loss來(lái)衡量?jī)烧卟罹唷?/p>
4. SP: Similarity-Preserving
全稱:Similarity-Preserving Knowledge Distillation
鏈接:https://arxiv.org/pdf/1907.09682.pdf
發(fā)表:ICCV19
SP歸屬于基于關(guān)系的知識(shí)蒸餾方法。文章思想是提出相似性保留的知識(shí),使得教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)會(huì)對(duì)相同的樣本產(chǎn)生相似的激活。可以從下圖看出處理流程,教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)對(duì)應(yīng)feature map通過(guò)計(jì)算內(nèi)積,得到bsxbs的相似度矩陣,然后使用均方誤差來(lái)衡量?jī)蓚€(gè)相似度矩陣。

最終Loss為:
G代表的就是bsxbs的矩陣。
實(shí)現(xiàn)如下:
class Similarity(nn.Module):
"""Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
def __init__(self):
super(Similarity, self).__init__()
def forward(self, g_s, g_t):
return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
def similarity_loss(self, f_s, f_t):
bsz = f_s.shape[0]
f_s = f_s.view(bsz, -1)
f_t = f_t.view(bsz, -1)
G_s = torch.mm(f_s, torch.t(f_s))
# G_s = G_s / G_s.norm(2)
G_s = torch.nn.functional.normalize(G_s)
G_t = torch.mm(f_t, torch.t(f_t))
# G_t = G_t / G_t.norm(2)
G_t = torch.nn.functional.normalize(G_t)
G_diff = G_t - G_s
loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
return loss
5. CC: Correlation Congruence
全稱:Correlation Congruence for Knowledge Distillation
鏈接:https://arxiv.org/pdf/1904.01802.pdf
發(fā)表:ICCV19
CC也歸屬于基于關(guān)系的知識(shí)蒸餾方法。不應(yīng)該僅僅引導(dǎo)教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)單個(gè)樣本向量之間的差異,還應(yīng)該學(xué)習(xí)兩個(gè)樣本之間的相關(guān)性,而這個(gè)相關(guān)性使用的是Correlation Congruence 教師網(wǎng)絡(luò)雨學(xué)生網(wǎng)絡(luò)相關(guān)性之間的歐氏距離。
整體Loss如下:
實(shí)現(xiàn)如下:
class Correlation(nn.Module):
"""Similarity-preserving loss. My origianl own reimplementation
based on the paper before emailing the original authors."""
def __init__(self):
super(Correlation, self).__init__()
def forward(self, f_s, f_t):
return self.similarity_loss(f_s, f_t)
def similarity_loss(self, f_s, f_t):
bsz = f_s.shape[0]
f_s = f_s.view(bsz, -1)
f_t = f_t.view(bsz, -1)
G_s = torch.mm(f_s, torch.t(f_s))
G_s = G_s / G_s.norm(2)
G_t = torch.mm(f_t, torch.t(f_t))
G_t = G_t / G_t.norm(2)
G_diff = G_t - G_s
loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
return loss
6. VID: Variational Information Distillation
全稱:Variational Information Distillation for Knowledge Transfer
鏈接:https://arxiv.org/pdf/1904.05835.pdf
發(fā)表:CVPR19

利用互信息(Mutual Information)來(lái)衡量學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)差異?;バ畔⒖梢员硎境鰞蓚€(gè)變量的互相依賴程度,其值越大,表示變量之間的依賴程度越高。互信息計(jì)算如下:
互信息是教師模型的熵減去在已知學(xué)生模型條件下教師模型的熵。目標(biāo)是最大化互信息,因?yàn)榛バ畔⒃酱笳f(shuō)明H(t|s)越小,即學(xué)生網(wǎng)絡(luò)確定的情況下,教師網(wǎng)絡(luò)的熵會(huì)變小,證明學(xué)生網(wǎng)絡(luò)已經(jīng)學(xué)習(xí)的比較充分。
整體loss如下:
由于p(t|s)很難計(jì)算,可以使用變分分布q(t|s)去接近真實(shí)分布。
其中q(t|s)是使用方差可學(xué)習(xí)的高斯分布模擬(公式中的log_scale):
實(shí)現(xiàn)如下:
class VIDLoss(nn.Module):
"""Variational Information Distillation for Knowledge Transfer (CVPR 2019),
code from author: https://github.com/ssahn0215/variational-information-distillation"""
def __init__(self,
num_input_channels,
num_mid_channel,
num_target_channels,
init_pred_var=5.0,
eps=1e-5):
super(VIDLoss, self).__init__()
def conv1x1(in_channels, out_channels, stride=1):
return nn.Conv2d(
in_channels, out_channels,
kernel_size=1, padding=0,
bias=False, stride=stride)
self.regressor = nn.Sequential(
conv1x1(num_input_channels, num_mid_channel),
nn.ReLU(),
conv1x1(num_mid_channel, num_mid_channel),
nn.ReLU(),
conv1x1(num_mid_channel, num_target_channels),
)
self.log_scale = torch.nn.Parameter(
np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels)
)
self.eps = eps
def forward(self, input, target):
# pool for dimentsion match
s_H, t_H = input.shape[2], target.shape[2]
if s_H > t_H:
input = F.adaptive_avg_pool2d(input, (t_H, t_H))
elif s_H < t_H:
target = F.adaptive_avg_pool2d(target, (s_H, s_H))
else:
pass
pred_mean = self.regressor(input)
pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps
pred_var = pred_var.view(1, -1, 1, 1)
neg_log_prob = 0.5*(
(pred_mean-target)**2/pred_var+torch.log(pred_var)
)
loss = torch.mean(neg_log_prob)
return loss
7. RKD: Relation Knowledge Distillation
全稱:Relational Knowledge Disitllation
鏈接:http://arxiv.org/pdf/1904.05068
發(fā)表:CVPR19
RKD也是基于關(guān)系的知識(shí)蒸餾方法,RKD提出了兩種損失函數(shù),二階的距離損失和三階的角度損失。
Distance-wise Loss
Angle-wise Loss
實(shí)現(xiàn)如下:
class RKDLoss(nn.Module):
"""Relational Knowledge Disitllation, CVPR2019"""
def __init__(self, w_d=25, w_a=50):
super(RKDLoss, self).__init__()
self.w_d = w_d
self.w_a = w_a
def forward(self, f_s, f_t):
student = f_s.view(f_s.shape[0], -1)
teacher = f_t.view(f_t.shape[0], -1)
# RKD distance loss
with torch.no_grad():
t_d = self.pdist(teacher, squared=False)
mean_td = t_d[t_d > 0].mean()
t_d = t_d / mean_td
d = self.pdist(student, squared=False)
mean_d = d[d > 0].mean()
d = d / mean_d
loss_d = F.smooth_l1_loss(d, t_d)
# RKD Angle loss
with torch.no_grad():
td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
norm_td = F.normalize(td, p=2, dim=2)
t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
sd = (student.unsqueeze(0) - student.unsqueeze(1))
norm_sd = F.normalize(sd, p=2, dim=2)
s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
loss_a = F.smooth_l1_loss(s_angle, t_angle)
loss = self.w_d * loss_d + self.w_a * loss_a
return loss
@staticmethod
def pdist(e, squared=False, eps=1e-12):
e_square = e.pow(2).sum(dim=1)
prod = e @ e.t()
res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
if not squared:
res = res.sqrt()
res = res.clone()
res[range(len(e)), range(len(e))] = 0
return res
8. PKT:Probabilistic Knowledge Transfer
全稱:Probabilistic Knowledge Transfer for deep representation learning
鏈接:https://arxiv.org/abs/1803.10837
發(fā)表:CoRR18
提出一種概率知識(shí)轉(zhuǎn)移方法,引入了互信息來(lái)進(jìn)行建模。該方法具有可跨模態(tài)知識(shí)轉(zhuǎn)移、無(wú)需考慮任務(wù)類型、可將手工特征融入網(wǎng)絡(luò)等有點(diǎn)。

實(shí)現(xiàn)如下:
class PKT(nn.Module):
"""Probabilistic Knowledge Transfer for deep representation learning
Code from author: https://github.com/passalis/probabilistic_kt"""
def __init__(self):
super(PKT, self).__init__()
def forward(self, f_s, f_t):
return self.cosine_similarity_loss(f_s, f_t)
@staticmethod
def cosine_similarity_loss(output_net, target_net, eps=0.0000001):
# Normalize each vector by its norm
output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))
output_net = output_net / (output_net_norm + eps)
output_net[output_net != output_net] = 0
target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))
target_net = target_net / (target_net_norm + eps)
target_net[target_net != target_net] = 0
# Calculate the cosine similarity
model_similarity = torch.mm(output_net, output_net.transpose(0, 1))
target_similarity = torch.mm(target_net, target_net.transpose(0, 1))
# Scale cosine similarity to 0..1
model_similarity = (model_similarity + 1.0) / 2.0
target_similarity = (target_similarity + 1.0) / 2.0
# Transform them into probabilities
model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)
target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)
# Calculate the KL-divergence
loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))
return loss
9. AB: Activation Boundaries
全稱:Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
鏈接:https://arxiv.org/pdf/1811.03233.pdf
發(fā)表:AAAI18
目標(biāo):讓教師網(wǎng)絡(luò)層的神經(jīng)元的激活邊界盡量和學(xué)生網(wǎng)絡(luò)的一樣。所謂的激活邊界指的是分離超平面(針對(duì)的是RELU這種激活函數(shù)),其決定了神經(jīng)元的激活與失活。AB提出的激活轉(zhuǎn)移損失,讓教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)之間的分離邊界盡可能一致。

實(shí)現(xiàn)如下:
class ABLoss(nn.Module):
"""Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
code: https://github.com/bhheo/AB_distillation
"""
def __init__(self, feat_num, margin=1.0):
super(ABLoss, self).__init__()
self.w = [2**(i-feat_num+1) for i in range(feat_num)]
self.margin = margin
def forward(self, g_s, g_t):
bsz = g_s[0].shape[0]
losses = [self.criterion_alternative_l2(s, t) for s, t in zip(g_s, g_t)]
losses = [w * l for w, l in zip(self.w, losses)]
# loss = sum(losses) / bsz
# loss = loss / 1000 * 3
losses = [l / bsz for l in losses]
losses = [l / 1000 * 3 for l in losses]
return losses
def criterion_alternative_l2(self, source, target):
loss = ((source + self.margin) ** 2 * ((source > -self.margin) & (target <= 0)).float() +
(source - self.margin) ** 2 * ((source <= self.margin) & (target > 0)).float())
return torch.abs(loss).sum()
10. FT: Factor Transfer
全稱:Paraphrasing Complex Network: Network Compression via Factor Transfer
鏈接:https://arxiv.org/pdf/1802.04977.pdf
發(fā)表:NIPS18
提出的是factor transfer的方法。所謂的factor,其實(shí)是對(duì)模型最后的數(shù)據(jù)結(jié)果進(jìn)行一個(gè)編解碼的過(guò)程,提取出的一個(gè)factor矩陣,用教師網(wǎng)絡(luò)的factor來(lái)指導(dǎo)學(xué)生網(wǎng)絡(luò)的factor。

FT計(jì)算公式為:
實(shí)現(xiàn)如下:
class FactorTransfer(nn.Module):
"""Paraphrasing Complex Network: Network Compression via Factor Transfer, NeurIPS 2018"""
def __init__(self, p1=2, p2=1):
super(FactorTransfer, self).__init__()
self.p1 = p1
self.p2 = p2
def forward(self, f_s, f_t):
return self.factor_loss(f_s, f_t)
def factor_loss(self, f_s, f_t):
s_H, t_H = f_s.shape[2], f_t.shape[2]
if s_H > t_H:
f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
elif s_H < t_H:
f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
else:
pass
if self.p2 == 1:
return (self.factor(f_s) - self.factor(f_t)).abs().mean()
else:
return (self.factor(f_s) - self.factor(f_t)).pow(self.p2).mean()
def factor(self, f):
return F.normalize(f.pow(self.p1).mean(1).view(f.size(0), -1))
11. FSP: Flow of Solution Procedure
全稱:A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning
鏈接:https://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf
發(fā)表:CVPR17
FSP認(rèn)為教學(xué)生網(wǎng)絡(luò)不同層輸出的feature之間的關(guān)系比教學(xué)生網(wǎng)絡(luò)結(jié)果好

定義了FSP矩陣來(lái)定義網(wǎng)絡(luò)內(nèi)部特征層之間的關(guān)系,是一個(gè)Gram矩陣反映老師教學(xué)生的過(guò)程。

使用的是L2 Loss進(jìn)行約束FSP矩陣。
實(shí)現(xiàn)如下:
class FSP(nn.Module):
"""A Gift from Knowledge Distillation:
Fast Optimization, Network Minimization and Transfer Learning"""
def __init__(self, s_shapes, t_shapes):
super(FSP, self).__init__()
assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
s_c = [s[1] for s in s_shapes]
t_c = [t[1] for t in t_shapes]
if np.any(np.asarray(s_c) != np.asarray(t_c)):
raise ValueError('num of channels not equal (error in FSP)')
def forward(self, g_s, g_t):
s_fsp = self.compute_fsp(g_s)
t_fsp = self.compute_fsp(g_t)
loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)]
return loss_group
@staticmethod
def compute_loss(s, t):
return (s - t).pow(2).mean()
@staticmethod
def compute_fsp(g):
fsp_list = []
for i in range(len(g) - 1):
bot, top = g[i], g[i + 1]
b_H, t_H = bot.shape[2], top.shape[2]
if b_H > t_H:
bot = F.adaptive_avg_pool2d(bot, (t_H, t_H))
elif b_H < t_H:
top = F.adaptive_avg_pool2d(top, (b_H, b_H))
else:
pass
bot = bot.unsqueeze(1)
top = top.unsqueeze(2)
bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1)
top = top.view(top.shape[0], top.shape[1], top.shape[2], -1)
fsp = (bot * top).mean(-1)
fsp_list.append(fsp)
return fsp_list
12. NST: Neuron Selectivity Transfer
全稱:Like what you like: knowledge distill via neuron selectivity transfer
鏈接:https://arxiv.org/pdf/1707.01219.pdf
發(fā)表:CoRR17
使用新的損失函數(shù)最小化教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)之間的Maximum Mean Discrepancy(MMD), 文中選擇的是對(duì)其教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)之間神經(jīng)元選擇樣式的分布。

使用核技巧(對(duì)應(yīng)下面poly kernel)并進(jìn)一步展開以后可得:
實(shí)際上提供了Linear Kernel、Poly Kernel、Gaussian Kernel三種,這里實(shí)現(xiàn)只給了Poly這種,這是因?yàn)镻oly這種方法可以與KD進(jìn)行互補(bǔ),這樣整體效果會(huì)非常好。
實(shí)現(xiàn)如下:
class NSTLoss(nn.Module):
"""like what you like: knowledge distill via neuron selectivity transfer"""
def __init__(self):
super(NSTLoss, self).__init__()
pass
def forward(self, g_s, g_t):
return [self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
def nst_loss(self, f_s, f_t):
s_H, t_H = f_s.shape[2], f_t.shape[2]
if s_H > t_H:
f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
elif s_H < t_H:
f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
else:
pass
f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)
f_s = F.normalize(f_s, dim=2)
f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
f_t = F.normalize(f_t, dim=2)
# set full_loss as False to avoid unnecessary computation
full_loss = True
if full_loss:
return (self.poly_kernel(f_t, f_t).mean().detach() + self.poly_kernel(f_s, f_s).mean()
- 2 * self.poly_kernel(f_s, f_t).mean())
else:
return self.poly_kernel(f_s, f_s).mean() - 2 * self.poly_kernel(f_s, f_t).mean()
def poly_kernel(self, a, b):
a = a.unsqueeze(1)
b = b.unsqueeze(2)
res = (a * b).sum(-1).pow(2)
return res
13. CRD: Contrastive Representation Distillation
全稱:Contrastive Representation Distillation
鏈接:https://arxiv.org/abs/1910.10699v2
發(fā)表:ICLR20
將對(duì)比學(xué)習(xí)引入知識(shí)蒸餾中,其目標(biāo)修正為:學(xué)習(xí)一個(gè)表征,讓正樣本對(duì)的教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)盡可能接近,負(fù)樣本對(duì)教師網(wǎng)絡(luò)與學(xué)生網(wǎng)絡(luò)盡可能遠(yuǎn)離。
構(gòu)建的對(duì)比學(xué)習(xí)問(wèn)題表示如下:
整體的蒸餾Loss表示如下:
實(shí)現(xiàn)如下:https://github.com/HobbitLong/RepDistiller
class ContrastLoss(nn.Module):
"""
contrastive loss, corresponding to Eq (18)
"""
def __init__(self, n_data):
super(ContrastLoss, self).__init__()
self.n_data = n_data
def forward(self, x):
bsz = x.shape[0]
m = x.size(1) - 1
# noise distribution
Pn = 1 / float(self.n_data)
# loss for positive pair
P_pos = x.select(1, 0)
log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()
# loss for K negative pair
P_neg = x.narrow(1, 1, m)
log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()
loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz
return loss
class CRDLoss(nn.Module):
"""CRD Loss function
includes two symmetric parts:
(a) using teacher as anchor, choose positive and negatives over the student side
(b) using student as anchor, choose positive and negatives over the teacher side
Args:
opt.s_dim: the dimension of student's feature
opt.t_dim: the dimension of teacher's feature
opt.feat_dim: the dimension of the projection space
opt.nce_k: number of negatives paired with each positive
opt.nce_t: the temperature
opt.nce_m: the momentum for updating the memory buffer
opt.n_data: the number of samples in the training set, therefor the memory buffer is: opt.n_data x opt.feat_dim
"""
def __init__(self, opt):
super(CRDLoss, self).__init__()
self.embed_s = Embed(opt.s_dim, opt.feat_dim)
self.embed_t = Embed(opt.t_dim, opt.feat_dim)
self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m)
self.criterion_t = ContrastLoss(opt.n_data)
self.criterion_s = ContrastLoss(opt.n_data)
def forward(self, f_s, f_t, idx, contrast_idx=None):
"""
Args:
f_s: the feature of student network, size [batch_size, s_dim]
f_t: the feature of teacher network, size [batch_size, t_dim]
idx: the indices of these positive samples in the dataset, size [batch_size]
contrast_idx: the indices of negative samples, size [batch_size, nce_k]
Returns:
The contrastive loss
"""
f_s = self.embed_s(f_s)
f_t = self.embed_t(f_t)
out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)
s_loss = self.criterion_s(out_s)
t_loss = self.criterion_t(out_t)
loss = s_loss + t_loss
return loss
14. Overhaul
全稱:A Comprehensive Overhaul of Feature Distillation
鏈接:http://openaccess.thecvf.com/content_ICCV_2019/papers/
發(fā)表:CVPR19
teacher transform中提出使用margin RELU激活函數(shù)。

student transform中提出使用1x1卷積。
distillation feature postion選擇Pre-ReLU。

distance function部分提出了Partial L2 損失函數(shù)。

部分實(shí)現(xiàn)如下:
class OFD(nn.Module):
'''
A Comprehensive Overhaul of Feature Distillation
http://openaccess.thecvf.com/content_ICCV_2019/papers/
Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
'''
def __init__(self, in_channels, out_channels):
super(OFD, self).__init__()
self.connector = nn.Sequential(*[
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(out_channels)
])
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, fm_s, fm_t):
margin = self.get_margin(fm_t)
fm_t = torch.max(fm_t, margin)
fm_s = self.connector(fm_s)
mask = 1.0 - ((fm_s <= fm_t) & (fm_t <= 0.0)).float()
loss = torch.mean((fm_s - fm_t)**2 * mask)
return loss
def get_margin(self, fm, eps=1e-6):
mask = (fm < 0.0).float()
masked_fm = fm * mask
margin = masked_fm.sum(dim=(0,2,3), keepdim=True) / (mask.sum(dim=(0,2,3), keepdim=True)+eps)
return margin
參考文獻(xiàn)
https://blog.csdn.net/weixin_44579633/article/details/119350631
https://blog.csdn.net/winycg/article/details/105297089
https://blog.csdn.net/weixin_46239293/article/details/120289163
https://blog.csdn.net/DD_PP_JJ/article/details/121578722
https://blog.csdn.net/DD_PP_JJ/article/details/121714957
https://zhuanlan.zhihu.com/p/344881975
https://blog.csdn.net/weixin_44633882/article/details/108927033
https://blog.csdn.net/weixin_46239293/article/details/120266111
https://blog.csdn.net/weixin_43402775/article/details/109011296
https://blog.csdn.net/m0_37665984/article/details/103288582
https://blog.csdn.net/m0_37665984/article/details/103269740
好消息!
小白學(xué)視覺(jué)知識(shí)星球
開始面向外開放啦??????
下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程 在「小白學(xué)視覺(jué)」公眾號(hào)后臺(tái)回復(fù):擴(kuò)展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺(jué)、目標(biāo)跟蹤、生物視覺(jué)、超分辨率處理等二十多章內(nèi)容。 下載2:Python視覺(jué)實(shí)戰(zhàn)項(xiàng)目52講 在「小白學(xué)視覺(jué)」公眾號(hào)后臺(tái)回復(fù):Python視覺(jué)實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測(cè)、車道線檢測(cè)、車輛計(jì)數(shù)、添加眼線、車牌識(shí)別、字符識(shí)別、情緒檢測(cè)、文本內(nèi)容提取、面部識(shí)別等31個(gè)視覺(jué)實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺(jué)。 下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講 在「小白學(xué)視覺(jué)」公眾號(hào)后臺(tái)回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。 交流群
歡迎加入公眾號(hào)讀者群一起和同行交流,目前有SLAM、三維視覺(jué)、傳感器、自動(dòng)駕駛、計(jì)算攝影、檢測(cè)、分割、識(shí)別、醫(yī)學(xué)影像、GAN、算法競(jìng)賽等微信群(以后會(huì)逐漸細(xì)分),請(qǐng)掃描下面微信號(hào)加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺(jué)SLAM“。請(qǐng)按照格式備注,否則不予通過(guò)。添加成功后會(huì)根據(jù)研究方向邀請(qǐng)進(jìn)入相關(guān)微信群。請(qǐng)勿在群內(nèi)發(fā)送廣告,否則會(huì)請(qǐng)出群,謝謝理解~

