分割模型,损失函数,训练流程...
一、网络模型
常见分割网络
1.U-Net 3d
// U-Net3d
class UNet3D(nn.Module):
    """3D U-Net assembled from ``ConvBlock``, ``DownBlock`` and ``UpBlock``.

    ``params`` is a dictionary with the following keys:

    * ``in_chns``: number of input channels.
    * ``feature_chns``: list of 4 or 5 channel counts, one per level.
    * ``class_num``: number of channels produced by the output convolution.
    * ``trilinear``: forwarded unchanged to every ``UpBlock``.
    * ``dropout``: sequence of dropout values, indexed like ``feature_chns``.
    * ``activation``: ``'sigmoid'`` or ``'softmax'``; any other value makes
      ``forward`` return the raw output of the last convolution.

    ``ConvBlock``, ``DownBlock`` and ``UpBlock`` must be available in the
    enclosing module.
    """

    def __init__(self, params):
        super(UNet3D, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.trilinear = self.params['trilinear']
        self.dropout = self.params['dropout']
        assert len(self.ft_chns) in (4, 5), \
            'feature_chns must contain 4 or 5 entries'

        # Encoder path.
        self.in_conv = ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
        self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1])
        self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
        self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
        if len(self.ft_chns) == 5:
            # The deepest level and its matching up-block only exist in the
            # 5-level configuration; ``forward`` skips both otherwise, and
            # ``ft_chns[4]`` would not exist for a 4-level configuration.
            self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
            self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3],
                               dropout_p=0.0, trilinear=self.trilinear)

        # Decoder path (no dropout).
        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2],
                           dropout_p=0.0, trilinear=self.trilinear)
        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1],
                           dropout_p=0.0, trilinear=self.trilinear)
        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0],
                           dropout_p=0.0, trilinear=self.trilinear)

        # NOTE(review): the original line carried the remark "bias = -2.19";
        # no bias initialisation is performed here - confirm whether one
        # was intended.
        self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class,
                                  kernel_size=3, padding=1)

        if self.params['activation'] == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif self.params['activation'] == 'softmax':
            # dim=1 is the channel axis of the Conv3d output; it is spelled
            # out because relying on the implicit dim is deprecated.
            self.activation = nn.Softmax(dim=1)
        else:
            # String sentinel kept for backward compatibility with code that
            # inspects ``self.activation``.
            self.activation = 'None'

    def forward(self, x):
        """Run the network on ``x`` and return the (optionally activated) output."""
        x0 = self.in_conv(x)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        if len(self.ft_chns) == 5:
            x4 = self.down4(x3)
            x = self.up1(x4, x3)
        else:
            x = x3
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.up4(x, x0)
        output = self.out_conv(x)
        if self.activation == 'None':
            return output
        return self.activation(output)
2.U-Net 2d
// parts of unet
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
""" Parts of the U-Net model """
"""https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py"""
class DoubleConv(nn.Module):
    """Two consecutive ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        stages = []
        # Same layer order as a hand-written Sequential, so the indices
        # inside ``double_conv`` (and hence state_dict keys) are 0..5.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features
class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        pooled_features = self.maxpool_conv(x)
        return pooled_features
class Up(nn.Module):
    """Decoder step: upsample, pad to the skip tensor's size, concat, ``DoubleConv``.

    Args:
        in_channels: channel count of the concatenated tensor fed to the
            double convolution.
        out_channels: channel count produced by the double convolution.
        bilinear: if true, upsample with parameter-free bilinear
            interpolation; otherwise use a learned ``ConvTranspose2d`` on
            ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
                                         kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with the skip tensor ``x2`` and fuse both."""
        x1 = self.up(x1)
        # Axis 2 is height and axis 3 is width. The differences are kept as
        # plain Python ints because F.pad expects integer padding amounts,
        # not tensors.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        # Padding order is (left, right, top, bottom); an odd difference puts
        # the extra row/column on the right/bottom.
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class OutConv(nn.Module):
    """Final 1x1 convolution mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected
// full unet
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from model.parts_unet import *
""" Full assembly of the parts to form the complete network """
"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""
class UNet(nn.Module):
    """2D U-Net: one input block, four ``Down`` steps, four ``Up`` steps, 1x1 head.

    Args:
        n_channels: number of channels of the input image.
        n_classes: number of channels of the returned logits.
        bilinear: forwarded to every ``Up`` block.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels, self.n_classes, self.bilinear = n_channels, n_classes, bilinear

        # Encoder.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)

        # Decoder; the first argument is the channel count after the skip
        # tensor has been concatenated.
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)

        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return the raw logits for input ``x`` (no activation is applied)."""
        # Collect every encoder output; the last one is the bottleneck and
        # the earlier ones serve as skip connections, consumed deepest first.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        x = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop())

        return self.outc(x)
if __name__ == '__main__':
    # Smoke check: build a 3-channel-input, single-output U-Net and print
    # its module tree.
    net = UNet(n_classes=1, n_channels=3)
    print(net)
二、损失函数
1.Focal loss
def FocalLoss(predict, soft_y, softmax, alpha_num=0.01, gamma=2):
    """Binary focal loss computed from probabilities.

    Args:
        predict: predicted probabilities (values in [0, 1], as required by
            ``nn.BCELoss``).
        soft_y: target tensor with the same shape as ``predict``.
        softmax: unused; kept so existing call sites keep working.
        alpha_num: weight given to elements whose target equals 1; all
            other elements are weighted ``1 - alpha_num``.
        gamma: focusing exponent applied to ``1 - pt``.

    Returns:
        Scalar tensor: the mean of ``alpha * (1 - pt) ** gamma * bce`` over
        all elements, where ``pt = exp(-bce)``.
    """
    # The weighting map is a constant and must not take part in autograd.
    with torch.no_grad():
        alpha = torch.empty_like(predict).fill_(1 - alpha_num)
        alpha[soft_y == 1] = alpha_num

    crit = nn.BCELoss(reduction='none')
    ce_loss = crit(predict.float(), soft_y.float())
    # exp(-BCE) recovers the probability assigned to the true class.
    pt = torch.exp(-ce_loss)
    loss = alpha * torch.pow(1 - pt, gamma) * ce_loss
    return loss.mean()
2.Hausdorff loss
def hd_loss(seg, gt, seg_dtm, gt_dtm):
    """Distance-transform based Hausdorff loss for binary segmentation.

    All four tensors must have the same shape, e.g. ``(b, c, d, h, w)``
    with ``c = 1`` for the binary case. The product is element-wise, so any
    common rank is accepted.

    :param seg: segmentation result
    :param gt: ground truth
    :param seg_dtm: distance transform of the segmentation
    :param gt_dtm: distance transform of the ground truth
    :return: scalar tensor, the mean of
        ``(seg - gt) ** 2 * (seg_dtm ** 2 + gt_dtm ** 2)``
    """
    squared_error = (seg - gt) ** 2
    squared_distance = seg_dtm ** 2 + gt_dtm ** 2
    # Plain element-wise product; an einsum with five named axes would do
    # the same but restrict the inputs to exactly rank 5.
    return (squared_distance * squared_error).mean()
3.Dice loss
def dice_loss(predict, soft_y, softmax=False):
    """Soft Dice loss averaged over the batch.

    Args:
        predict: predictions; axis 0 is the batch axis.
        soft_y: targets with the same number of elements per sample.
        softmax: unused; kept so existing call sites keep working.

    Returns:
        Scalar tensor ``1 - mean(dice)``, where each sample's dice is
        ``(2 * sum(p * y) + smooth) / (sum(p) + sum(y) + smooth)``.
    """
    smooth = 1e-5
    num = predict.size(0)
    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors,
    # are accepted as well.
    p_vol = predict.reshape(num, -1)
    y_vol = soft_y.reshape(num, -1)
    intersection = (p_vol * y_vol).sum(1)
    dice_score = (2. * intersection + smooth) / (p_vol.sum(1) + y_vol.sum(1) + smooth)
    return 1 - dice_score.sum() / num
三、训练一个U-Net
1.数据加载
2.模型选择,这里为U-Net
3.损失函数
4.训练
5.预测
//dataset
# !/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
from torch.utils.data import Dataset
import glob
import cv2
import random
import os
'''data loader'''
# inherit class Dataset():
class DataLoader(Dataset):
    """Dataset of image/label PNG pairs, returned as single-channel arrays.

    Images are collected from ``<data_path>/image/``. The label path of an
    image is obtained by replacing ``'image'`` with ``'label'`` in the image
    path.
    """

    def __init__(self, data_path):
        """Collect the image paths found below ``data_path``."""
        self.data_path = data_path
        # Sorted so that a given index always refers to the same file.
        self.imgs_path = sorted(glob.glob(os.path.join(data_path, 'image/*png')))

    def augment(self, image, flipCode):
        """Flip ``image`` with ``cv2.flip``.

        ``flipCode``: 1 = horizontal, 0 = vertical, -1 = both.
        """
        return cv2.flip(image, flipCode)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``, each shaped ``(1, H, W)``.

        Raises:
            FileNotFoundError: if the image or its label cannot be read.
        """
        image_path = self.imgs_path[index]
        # NOTE(review): this replaces every occurrence of 'image' in the
        # path, including ones in parent directories or the file name.
        label_path = image_path.replace('image', 'label')

        # Read as grayscale: the arrays are reshaped to a single channel
        # below, which is impossible for the 3-channel default of imread.
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError('cannot read image: ' + image_path)
        if label is None:
            raise FileNotFoundError('cannot read label: ' + label_path)

        # Randomly flip image and label together; 2 means "no flip". The
        # flip runs on the 2-D arrays so the flip codes keep their meaning.
        flipCode = random.choice([-1, 0, 1, 2])
        if flipCode != 2:
            image = self.augment(image, flipCode)
            label = self.augment(label, flipCode)

        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        # Masks stored as 0/255 are rescaled to 0/1.
        if label.max() > 1:
            label = label / 255
        return image, label

    def __len__(self):
        """Return the number of images found."""
        return len(self.imgs_path)
if __name__ == '__main__':
    # Smoke check: load the training set and print the shape of each batch.
    liunj_dataset = DataLoader('data/train/')
    print('number of train data:', len(liunj_dataset))
    # Pass the dataset *instance*; handing over the class itself fails as
    # soon as the loader asks for its length.
    train_loader = torch.utils.data.DataLoader(dataset=liunj_dataset,
                                               batch_size=2,
                                               shuffle=True)
    for image, label in train_loader:
        print(image.shape)
//train_model
# -*- coding: utf-8 -*-
from model.full_unet import UNet
from dataset import DataLoader
from torch import optim
import torch.nn as nn
import torch
import time
import matplotlib.pyplot as plt
def train(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
    """Train ``net`` on the dataset under ``data_path``.

    Uses RMSprop and ``BCEWithLogitsLoss``. Whenever a batch reaches a new
    lowest loss, the current weights are written to ``best_model.pth``.

    Args:
        net: the model to train; it must already live on ``device``.
        device: device the batches are moved to.
        data_path: root folder handed to the project's ``DataLoader`` dataset.
        epochs: number of passes over the training set.
        batch_size: batch size of the loader.
        lr: learning rate of the optimizer.
    """
    liunj_dataset = DataLoader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=liunj_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    criterion = nn.BCEWithLogitsLoss()
    # Lowest batch loss seen so far, kept as a plain float so no tensor
    # (and its autograd history) is held across iterations.
    best_loss = float('inf')

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 50)
        net.train()
        for image, label in train_loader:
            optimizer.zero_grad()
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            pred = net(image)
            loss = criterion(pred, label)
            loss_value = loss.item()
            print('Loss/train', loss_value)
            # Saved before the optimizer step: these are the weights that
            # produced the loss just measured.
            if loss_value < best_loss:
                best_loss = loss_value
                torch.save(net.state_dict(), 'best_model.pth')
            loss.backward()
            optimizer.step()
# if step % 10 == 0:
# print('Epoch {}/{} | Current step: {} | Loss: {} | Acc: {} | AllocMem (Mb): {}' \
# .format(epoch + 1, epochs, step, loss, acc, torch.cuda.memory_allocated()/1024/1024)
# )
# # current step: {} / {}
# epoch_loss = running_loss / len(dataloader.dataset)
# epoch_acc = running_acc / len(dataloader.dataset)
#
# train_loss.append(epoch_loss) if phase == 'Train' else valid_loss.append(epoch_loss)
#
# time_elapsed = time.time() - start
# print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
# def acc_metric(predb, yb):
# return (predb.argmax(dim=1) == yb.cuba()).float().mean()
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # train() uses BCEWithLogitsLoss, which needs the prediction to have the
    # same shape as the label. The dataset in this article reshapes images
    # and masks to a single channel, hence one input and one output channel.
    net = UNet(n_channels=1, n_classes=1)
    net.to(device=device)
    data_path = 'data/train/'
    train(net, device, data_path)
    # TODO: the plotting code below is known to be broken - train() does not
    # return `loss`, `train_epochs_loss` or `valid_epochs_loss`, so running
    # it raises NameError. It is kept disabled until the loss history is
    # recorded; TensorBoard could be used to visualise training instead.
    # plt.figure(figsize=(12, 4))
    # plt.subplot(121)
    # plt.plot(loss[:])
    # plt.title("train_loss")
    # plt.subplot(122)
    # plt.plot(train_epochs_loss[1:], '-o', label="train_loss")
    # plt.plot(valid_epochs_loss[1:], '-o', label="valid_loss")
    # plt.title("epochs_loss")
    # plt.legend()
    # plt.show()
# -*- coding: utf-8 -*-
import glob
import numpy as np
import torch
import cv2
from model.full_unet import UNet
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # One input channel because each image is fed as a (1, 1, H, W) tensor
    # below; one output channel because only pred[0][0] is thresholded.
    # Both values must match the network that produced best_model.pth.
    net = UNet(n_channels=1, n_classes=1)
    net.to(device=device)
    net.load_state_dict(torch.load('best_model.pth', map_location=device))
    net.eval()

    test_paths = glob.glob('data/test/*.png')
    for test_path in test_paths:
        # Skip masks written by an earlier run of this script.
        if test_path.endswith('_res.png'):
            continue
        # rsplit removes only the extension, so dots elsewhere in the path
        # do not truncate the output name.
        save_res_path = test_path.rsplit('.', 1)[0] + '_res.png'

        img = cv2.imread(test_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            print('cannot read image, skipped:', test_path)
            continue
        # Add batch and channel axes: (H, W) -> (1, 1, H, W).
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        img_tensor = torch.from_numpy(img).to(device=device, dtype=torch.float32)

        # The network returns logits; sigmoid turns them into probabilities
        # so that 0.5 is a meaningful threshold. No gradients are needed.
        with torch.no_grad():
            prob = torch.sigmoid(net(img_tensor))
        pred = np.array(prob.cpu()[0][0])

        mask = np.where(pred >= 0.5, 255, 0).astype(np.uint8)
        cv2.imwrite(save_res_path, mask)
版权声明:本文为biejieyu1016原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。