UNet and UNet++: A Comparison of Classic Medical Image Segmentation Networks

This article presents a comparison of classic segmentation networks for medical imaging.

According to the paper, UNet++ appears to outperform the original UNet. Just as with UNet, a variety of encoders (backbones) can be plugged in to extract strong features from the input image.
Which encoder should I use?
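With segmentation_models_pytorch, any registered encoder can serve as the Unet++ backbone, so the choice is easy to experiment with. A minimal sketch of how to browse the options (the specific encoder name is simply the one this article settles on below):

import segmentation_models_pytorch as smp

# List every encoder (backbone) name the library ships with.
print(smp.encoders.get_encoder_names())

# Any of them can back a Unet++; only encoder_name changes.
model = smp.UnetPlusPlus(
    encoder_name='timm-regnety_004',  # the encoder used later in this article
    in_channels=1,                    # grayscale chest X-rays
    classes=1,                        # a single binary mask
)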

Define the dataset and augmentations. We resize the images to 256×256 and apply some heavy augmentations to the training dataset.
import albumentations as A
import numpy as np
import torch
from skimage.io import imread
from torch.utils.data import Dataset, DataLoader

class ChestXRayDataset(Dataset):
    def __init__(self, images, masks, transforms):
        self.images = images
        self.masks = masks
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load an image/mask pair, scale both to [0, 1], apply the
        augmentation pipeline and return channels-first float tensors."""
        sample_image = imread(self.images[idx])
        if len(sample_image.shape) == 3:  # keep a single channel if RGB
            sample_image = sample_image[..., 0]
        sample_image = np.expand_dims(sample_image, 2) / 255
        sample_mask = imread(self.masks[idx]) / 255
        if len(sample_mask.shape) == 3:
            sample_mask = sample_mask[..., 0]
        augmented = self.transforms(image=sample_image, mask=sample_mask)
        sample_image = augmented['image']
        sample_mask = augmented['mask']
        sample_image = sample_image.transpose(2, 0, 1)  # channels first
        sample_mask = np.expand_dims(sample_mask, 0)
        data = {'features': torch.from_numpy(sample_image.copy()).float(),
                'mask': torch.from_numpy(sample_mask.copy()).float()}
        return data

def get_valid_transforms(crop_size=256):
    return A.Compose([
        A.Resize(crop_size, crop_size),
    ], p=1.0)

def light_training_transforms(crop_size=256):
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        A.OneOf([
            A.Transpose(),
            A.VerticalFlip(),
            A.HorizontalFlip(),
            A.RandomRotate90(),
            A.NoOp()
        ], p=1.0),
    ])

def medium_training_transforms(crop_size=256):
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        A.OneOf([
            A.Transpose(),
            A.VerticalFlip(),
            A.HorizontalFlip(),
            A.RandomRotate90(),
            A.NoOp()
        ], p=1.0),
        A.OneOf([
            A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
            A.NoOp()
        ], p=1.0),
    ])

def heavy_training_transforms(crop_size=256):
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        A.OneOf([
            A.Transpose(),
            A.VerticalFlip(),
            A.HorizontalFlip(),
            A.RandomRotate90(),
            A.NoOp()
        ], p=1.0),
        A.ShiftScaleRotate(p=0.75),
        A.OneOf([
            A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
            A.NoOp()
        ], p=1.0),
    ])

def get_training_transforms(transforms_type):
    if transforms_type == 'light':
        return light_training_transforms()
    elif transforms_type == 'medium':
        return medium_training_transforms()
    elif transforms_type == 'heavy':
        return heavy_training_transforms()
    else:
        raise NotImplementedError("Not implemented transformation configuration")
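The training call further below expects a loaders dictionary that the original snippet never constructs. Here is a minimal sketch of how it could be built; train_images/train_masks and valid_images/valid_masks are placeholder lists of file paths, and batch_size=8 is an arbitrary choice, neither taken from the original article.

# Minimal sketch (not from the original article): build the datasets with
# the transforms defined above and wrap them in the loaders dict that
# Catalyst's runner.train expects. Path lists and batch size are placeholders.
train_dataset = ChestXRayDataset(train_images, train_masks,
                                 get_training_transforms('heavy'))
valid_dataset = ChestXRayDataset(valid_images, valid_masks,
                                 get_valid_transforms())
loaders = {
    'train': DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4),
    'valid': DataLoader(valid_dataset, batch_size=8, shuffle=False, num_workers=4),
}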
Define the model and the loss function. Here we use a Unet++ with a timm-regnety_004 encoder and train it on the sum of DICE + BCE losses with the RAdam + Lookahead optimizer.
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import os
import segmentation_models_pytorch as smp
# The original snippet does not show where DiceLoss comes from; smp's
# losses module provides a DiceLoss with this mode='binary' signature.
from segmentation_models_pytorch.losses import DiceLoss
from skimage.io import imread
from sklearn.model_selection import train_test_split
from catalyst import dl, metrics, core, contrib, utils
from catalyst.dl import CriterionCallback, MetricAggregationCallback

encoder = 'timm-regnety_004'
model = smp.UnetPlusPlus(encoder, classes=1, in_channels=1)
# model.cuda()

learning_rate = 5e-3
encoder_learning_rate = 5e-3 / 10
# Train the pretrained encoder with a 10x smaller learning rate and its own
# weight decay; the decoder uses the base settings.
layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)}
model_params = utils.process_model_params(model, layerwise_params=layerwise_params)
base_optimizer = contrib.nn.RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
optimizer = contrib.nn.Lookahead(base_optimizer)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=10)

criterion = {
    "dice": DiceLoss(mode='binary'),
    "bce": nn.BCEWithLogitsLoss()
}
Define the callbacks and train!
callbacks = [
    # Each criterion is calculated separately.
    CriterionCallback(
        input_key="mask",
        prefix="loss_dice",
        criterion_key="dice"
    ),
    CriterionCallback(
        input_key="mask",
        prefix="loss_bce",
        criterion_key="bce"
    ),
    # And only then do we aggregate everything into one weighted loss:
    # loss = 1.0 * loss_dice + 0.8 * loss_bce.
    MetricAggregationCallback(
        prefix="loss",
        mode="weighted_sum",
        metrics={
            "loss_dice": 1.0,
            "loss_bce": 0.8
        },
    ),
    # Metrics. Note: IoUMetricsCallback is not defined or imported in this
    # snippet; it is assumed to come from elsewhere in the original notebook.
    IoUMetricsCallback(
        mode='binary',
        input_key='mask',
    )
]
runner = dl.SupervisedRunner(input_key="features", input_target_key="mask")
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=callbacks,
    logdir='../logs/xray_test_log',
    num_epochs=100,
    main_metric="loss",
    minimize_metric=True,
    verbose=True,
)
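The article stops at training. For completeness, a minimal inference sketch (not from the original article): the model outputs logits, so a sigmoid followed by a 0.5 threshold yields a binary mask.

# Minimal inference sketch (assumes the loaders dict built above).
model.eval()
with torch.no_grad():
    batch = next(iter(loaders['valid']))
    logits = model(batch['features'])  # raw logits, shape (B, 1, H, W)
    pred_mask = (torch.sigmoid(logits) > 0.5).float()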
Original English article:
https://towardsdatascience.com/the-best-approach-to-semantic-segmentation-of-biomedical-images-bbe4fd78733f