Attention Mechanisms in CV: The Simple yet Effective CBAM Module

1. What Is an Attention Mechanism?
Attention mechanisms in vision are commonly grouped into three categories:

Channel attention: generates a mask over the channels and scores each channel; representative work includes SENet and the Channel Attention Module.
Spatial attention: generates a mask over spatial positions and scores each location; representative work is the Spatial Attention Module.
Mixed-domain attention: scores channel attention and spatial attention jointly; representative work includes BAM and CBAM.
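Since SENet is named above as the representative channel-attention design, here is a minimal illustrative sketch of an SE-style block for comparison. It is not from the original post; the class name SEBlock and the reduction ratio of 16 are placeholder conventions, and a standard PyTorch environment is assumed.

import torch
import torch.nn as nn

class SEBlock(nn.Module):
    """Minimal SE-style channel attention: squeeze (global pooling) + excitation (two FC layers)."""
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)          # squeeze: N x C x H x W -> N x C x 1 x 1
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),  # excitation: bottleneck FC
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid())                                # per-channel weights in (0, 1)

    def forward(self, x):
        n, c, _, _ = x.size()
        w = self.avg_pool(x).view(n, c)
        w = self.fc(w).view(n, c, 1, 1)
        return x * w                                     # rescale each channel (broadcasting)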
2. Implementing the CBAM Module
2.1 Channel Attention

import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)   # N x C x H x W -> N x C x 1 x 1
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP (implemented with 1x1 convolutions), applied to both pooled descriptors
        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)      # per-channel weights, N x C x 1 x 1
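A quick shape check (the tensor sizes are arbitrary, chosen only for illustration): the module returns per-channel weights that are multiplied back onto the feature map by broadcasting.

x = torch.randn(2, 64, 32, 32)        # N x C x H x W
ca = ChannelAttention(in_planes=64)
w = ca(x)                             # -> (2, 64, 1, 1)
out = w * x                           # channel-wise reweighting via broadcasting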
2.2 Spatial Attention

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1
        # 2-channel input: the channel-wise average map and the channel-wise max map
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)    # N x 1 x H x W
        maxout, _ = torch.max(x, dim=1, keepdim=True)  # N x 1 x H x W
        x = torch.cat([avgout, maxout], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)                         # spatial weight map, N x 1 x H x W
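The corresponding shape check (again with arbitrary sizes): the result is a single spatial weight map that is applied to every channel.

x = torch.randn(2, 64, 32, 32)
sa = SpatialAttention(kernel_size=7)
w = sa(x)                             # -> (2, 1, 32, 32)
out = w * x                           # position-wise reweighting via broadcasting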
2.3 Convolutional Block Attention Module (CBAM)

The two CBAM modules are inserted into a ResNet BasicBlock after the second convolution, before the residual addition:

# Standard 3x3 convolution helper, as defined in torchvision's ResNet implementation
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.ca(out) * out  # broadcasting
        out = self.sa(out) * out  # broadcasting
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
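A quick sanity check (illustrative sizes, assuming the conv3x3 helper above): with stride 1 and no downsample branch, the block preserves the input shape.

block = BasicBlock(inplanes=64, planes=64)
y = block(torch.randn(2, 64, 56, 56))   # -> (2, 64, 56, 56)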
The two modules can also be packaged as a standalone wrapper that keeps the feature-map shape unchanged:

class cbam(nn.Module):
    def __init__(self, planes):
        super(cbam, self).__init__()
        self.ca = ChannelAttention(planes)  # planes is the number of channels of the feature map
        self.sa = SpatialAttention()

    def forward(self, x):
        x = self.ca(x) * x  # broadcasting
        x = self.sa(x) * x  # broadcasting
        return x
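Because the wrapper leaves the tensor shape untouched, it can be dropped in after any convolutional stage. The toy network below is only an illustrative sketch of such a placement, not from the original post.

model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    cbam(planes=64),                 # refine the 64-channel feature map in place
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10))

logits = model(torch.randn(2, 3, 32, 32))   # -> (2, 10)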
3. In What Situations Can It Be Used?

How can channel attention be computed more effectively?

How can spatial attention be computed more effectively?

How should the two parts be combined?
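In the CBAM paper's ablations, both attention types work best when average pooling and max pooling are combined, and for the third question the sequential, channel-first arrangement outperforms the alternatives. The sketch below (illustrative functions, not from the original post) contrasts that sequential ordering with a parallel combination.

# Sequential arrangement adopted by CBAM: channel attention first, then spatial attention
def cbam_sequential(x, ca, sa):
    x = ca(x) * x
    x = sa(x) * x
    return x

# An alternative arrangement, shown only for contrast: apply both branches in parallel and merge
def attention_parallel(x, ca, sa):
    return ca(x) * x + sa(x) * x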



4. References

Sanghyun Woo, Jongchan Park, Joon-Young Lee, In So Kweon. CBAM: Convolutional Block Attention Module. ECCV 2018.
