1
Preface
This article is part of a tutorial series on semantic segmentation with PyTorch deep learning.
The series covers:
Basic use of PyTorch
Explanation of semantic segmentation algorithms
2
Project background
A deep learning algorithm is just a means of solving a problem. Which network to train, which preprocessing to apply, and which loss and optimization method to use all depend on the specific task.
So let's look at today's task first.
It is the classic task from the UNet paper: medical image segmentation.
We choose it for today because it is simple and easy to work with.
Briefly: as shown in the animation, we are given an image of cell structures and need to separate the individual cells from one another.
There are only 30 training images, each with a resolution of 512x512. These images are electron micrographs of fruit flies.
With the task introduced, let's prepare to train the model.
3
UNet training
Training a deep learning model can be divided into three steps:
Data loading: how to load the data, how to define the labels, and which data augmentation methods to use are all decided in this step.
Model selection: the model is already prepared. It is the UNet network described in the previous article of this series.
Algorithm selection: which loss to use and which optimization algorithm to use.
Each step is fairly general, so we will explain it in the context of today's medical image segmentation task.
1. Data loading
A lot can be done in this step, but it comes down to how to load the images and how to define the labels. To make the algorithm more robust or to enlarge the data set, some data augmentation can also be applied.
Since we are dealing with data, let's first look at what the data is like and then decide how to handle it.
The data is ready and available here:
https://github.com/Jack-Cherish/Deep-Learning/tree/master/Pytorch-Seg/lesson-2/data
The data is divided into a training set and a test set, with 30 images each. The training set has labels and the test set does not.
How data loading is handled depends on the task and the data set. Our segmentation task needs little processing. However, because there are only 30 training images, we can use data augmentation to expand the data set.
PyTorch provides a framework for loading data, and we can use it here. Look at the pseudocode:
import torch

# ================================================================== #
#                Input pipeline for custom dataset                   #
# ================================================================== #

# You should build your custom dataset as below.
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        # TODO
        # 1. Initialize file paths or a list of file names.
        pass

    def __getitem__(self, index):
        # TODO
        # 1. Read one data item from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. with torchvision.transforms).
        # 3. Return a data pair (e.g. image and label).
        pass

    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0

# You can then use the prebuilt data loader.
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
                                           batch_size=64,
                                           shuffle=True)
This is a standard template. We use it to load the data, define the labels, and augment the data.
Create a dataset.py file. The code is as follows:
import glob
import os
import random

import cv2
import torch
from torch.utils.data import Dataset


class ISBI_Loader(Dataset):
    def __init__(self, data_path):
        # Collect the paths of all training images under data_path
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, 'image/*.png'))

    def augment(self, image, flipCode):
        # Augment with cv2.flip. flipCode 1 is a horizontal flip,
        # 0 a vertical flip, -1 a horizontal + vertical flip
        flip = cv2.flip(image, flipCode)
        return flip

    def __getitem__(self, index):
        # Look up the image path for this index
        image_path = self.imgs_path[index]
        # Derive label_path from image_path
        label_path = image_path.replace('image', 'label')
        # Read the training image and the label image
        image = cv2.imread(image_path)
        label = cv2.imread(label_path)
        # Convert both to single-channel images of shape (H, W)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        # Process the label: change the pixel value 255 to 1
        if label.max() > 1:
            label = label / 255
        # Random augmentation; 2 means no flip.
        # Flip while the arrays are still (H, W): on a (1, H, W) array
        # cv2.flip would treat the last axis as channels.
        flipCode = random.choice([-1, 0, 1, 2])
        if flipCode != 2:
            image = self.augment(image, flipCode)
            label = self.augment(label, flipCode)
        # Add the channel dimension: (H, W) -> (1, H, W)
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        return image, label

    def __len__(self):
        # Return the size of the training set
        return len(self.imgs_path)


if __name__ == "__main__":
    isbi_dataset = ISBI_Loader("data/train/")
    print("Number of data:", len(isbi_dataset))
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=2,
                                               shuffle=True)
    for image, label in train_loader:
        print(image.shape)
Run the code and you can see the following results:
An explanation of the code:
The __init__ function is the initialization function of this class. It finds all images under the specified path and stores their paths in the self.imgs_path list.
The __len__ function returns how many samples there are. After the class is instantiated, it is called through the len() function.
The __getitem__ function fetches one sample. This is where you read and process the data, and where preprocessing and data augmentation can be carried out. My processing here is very simple: I read the image and convert it to a single-channel image. Because the label pixels are 0 and 255, they are divided by 255 to become 0 and 1. Random data augmentation is applied as well.
The augment function is the data augmentation function, and it can do whatever processing you like. Here I only perform a simple flip.
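To make the three flip codes concrete, here is a small sketch on a toy image stored as a list of rows. The helper name `flip` is made up for this illustration; it only imitates the code convention of cv2.flip (0 vertical, positive horizontal, negative both) and is not how OpenCV implements it.

```python
def flip(image, flip_code):
    """Flip a 2-D image given as a list of rows, following cv2.flip's code convention."""
    if flip_code == 0:
        # Vertical flip: reverse the order of the rows
        return [row[:] for row in image[::-1]]
    if flip_code > 0:
        # Horizontal flip: reverse the pixels inside each row
        return [row[::-1] for row in image]
    # Negative code: flip both vertically and horizontally
    return [row[::-1] for row in image[::-1]]


toy_image = [[1, 2, 3],
             [4, 5, 6]]

horizontal = flip(toy_image, 1)   # [[3, 2, 1], [6, 5, 4]]
vertical = flip(toy_image, 0)     # [[4, 5, 6], [1, 2, 3]]
both = flip(toy_image, -1)        # [[6, 5, 4], [3, 2, 1]]
```

The image and its label must always be flipped with the same code, otherwise the mask no longer lines up with the cells it describes.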
In this class, you don't need to shuffle the data set yourself, and you don't need to worry about how to read the data in batches. After instantiating the class, we pass it to torch.utils.data.DataLoader, which lets us specify the batch size and whether to shuffle the data.
The DataLoader provided by PyTorch is very powerful. We can even specify how many worker processes load the data and whether batches are placed in pinned memory for faster transfer to CUDA devices. These options are outside the scope of this article, so we won't explain them.
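The batching and shuffling that the DataLoader takes over can be pictured with a few lines of plain Python. This is only a sketch of the idea, not PyTorch's implementation; `iterate_batches` and `toy_dataset` are names invented for the illustration, and the toy samples are strings rather than images.

```python
import random


def iterate_batches(dataset, batch_size, shuffle, seed=None):
    """Yield lists of samples, imitating only the batching and shuffling idea."""
    indices = list(range(len(dataset)))
    if shuffle:
        # Shuffle the indices, not the data itself
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # The dataset only has to support len() and indexing
        yield [dataset[i] for i in indices[start:start + batch_size]]


# 30 placeholder (image, label) pairs standing in for the training set
toy_dataset = [(f"image_{i}", f"label_{i}") for i in range(30)]
batches = list(iterate_batches(toy_dataset, batch_size=2, shuffle=True, seed=0))
print(len(batches))  # 15 batches of 2 samples each
```

This is why the dataset class only needs __getitem__ and __len__: everything else is driven by indices from outside.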
2. Model selection
The model we have chosen is the UNet network structure.
However, we need to adjust the network slightly. With the structure from the paper, the model output is somewhat smaller than the input image. If we used the paper's network as is, we would need a resize operation on the output. To skip this step, we can modify the network so that its output size exactly equals the input image size.
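The size difference can be checked with simple arithmetic. The sketch below assumes the usual U-Net layout (two 3x3 convolutions per block, four 2x2 poolings and four 2x upsamplings); the function names are made up for the illustration. With unpadded convolutions, as in the paper, each 3x3 convolution removes 2 pixels per dimension; with padding=1 it removes none.

```python
def double_conv(size, padding):
    # Two 3x3 convolutions; each changes the size by 2 * padding - 2
    return size + 2 * (2 * padding - 2)


def unet_output_size(size, padding, depth=4):
    """Spatial output size of a U-Net with `depth` down and up steps."""
    size = double_conv(size, padding)            # input block
    for _ in range(depth):                       # contracting path
        size = double_conv(size // 2, padding)   # 2x2 max-pool, then double conv
    for _ in range(depth):                       # expanding path
        size = double_conv(size * 2, padding)    # 2x upsampling, then double conv
    return size                                  # the final 1x1 conv keeps the size


paper_size = unet_output_size(572, padding=0)    # unpadded convs, 572x572 input
padded_size = unet_output_size(512, padding=1)   # padded convs, 512x512 input
print(paper_size, padded_size)  # 388 512
```

A 572x572 input shrinking to 388x388 matches the figure in the UNet paper, while padded convolutions keep a 512x512 input at 512x512.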
Create a unet_parts.py file and write the following code:
""" Parts of the U-Net model """
"""https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # padding=1 keeps the spatial size unchanged after each 3x3 convolution
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
                                         kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is NCHW: dims 2 and 3 are height and width
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        # Pad x1 so that it matches the skip connection x2 (plain ints for F.pad)
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
Create a unet_model.py file and write the following code:
""" Full assembly of the parts to form the complete network """
"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""

import torch.nn.functional as F

from .unet_parts import *


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


if __name__ == '__main__':
    net = UNet(n_channels=3, n_classes=1)
    print(net)
After this adjustment, the output size of the network is the same as the input size of the picture.
3. Algorithm selection
The choice of loss is very important. How well the loss suits the task affects how well the algorithm fits the data.
Which loss to choose also depends on the task. Today's task only needs to segment the cell edges, which is a very simple binary classification task, so we can use BCEWithLogitsLoss.
What is BCEWithLogitsLoss? It is a loss provided by PyTorch that computes the binary cross-entropy; it applies the Sigmoid to the raw network output internally.
Its formula is:
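The formula image from the original post is not available here. As a stand-in, the standard binary cross-entropy on logits is written out below, assuming the default mean reduction and no per-sample weights. Here x_n is the raw network output (logit) for pixel n, y_n is its 0/1 label, and N is the number of pixels.

```latex
\ell(x, y) = -\frac{1}{N} \sum_{n=1}^{N}
    \Big[ y_n \log \sigma(x_n) + (1 - y_n) \log\big(1 - \sigma(x_n)\big) \Big],
\qquad
\sigma(x) = \frac{1}{1 + e^{-x}}
```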
Readers of my machine learning tutorial series will recognize this formula. It is the loss function of logistic regression: the Sigmoid function squashes the output into [0, 1] so that it can be read as a class probability.
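To see what this loss computes on raw network outputs, here is a plain Python sketch. It is not PyTorch's code; the function names are invented for the illustration. It uses a common numerically stable rearrangement of the logistic loss, so the Sigmoid never has to be evaluated explicitly.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bce_naive(logits, targets):
    """Mean binary cross-entropy, written exactly as the textbook formula."""
    total = 0.0
    for x, y in zip(logits, targets):
        p = sigmoid(x)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(logits)


def bce_with_logits(logits, targets):
    """Same quantity in a stable form: max(x, 0) - x*y + log(1 + exp(-|x|))."""
    total = 0.0
    for x, y in zip(logits, targets):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


# Made-up logits and labels, for illustration only
toy_logits = [2.0, -1.0, 0.5, -3.0]
toy_labels = [1.0, 0.0, 0.0, 1.0]
loss_naive = bce_naive(toy_logits, toy_labels)
loss_stable = bce_with_logits(toy_logits, toy_labels)
```

Both functions return the same value for moderate logits. The second form stays finite for very large positive or negative logits, which is one reason to feed raw outputs to the loss instead of applying the Sigmoid first.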
Once the objective function, namely the loss, is fixed, how do we optimize it?
The simplest way is to approach a local extremum step by step with the familiar gradient descent algorithm.
However, this simple optimization algorithm converges slowly, so reaching a good solution takes a lot of effort.
Most optimization algorithms are gradient descent at heart. For example, the standard SGD is stochastic gradient descent, a refinement of plain gradient descent, and Momentum is SGD with a momentum term that accumulates past gradients with exponential decay.
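The momentum idea fits in a few lines. The sketch below uses one common formulation (a velocity that is an exponentially decayed sum of past gradients) on the toy objective f(w) = w^2. The function name, the learning rate and the starting point are assumptions chosen for the illustration, not values from this project.

```python
def sgd_momentum_step(param, grad, velocity, lr=0.1, momentum=0.9):
    """One update of gradient descent with momentum (one common formulation)."""
    # Old gradients decay by the factor `momentum` at every step
    velocity = momentum * velocity + grad
    param = param - lr * velocity
    return param, velocity


# Minimize f(w) = w^2, whose gradient is 2 * w
w, v = 5.0, 0.0
for _ in range(200):
    w, v = sgd_momentum_step(w, grad=2.0 * w, velocity=v)

print(abs(w) < 1e-3)  # True: w has moved close to the minimum at 0
```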
Beyond these basic optimization algorithms, there are also adaptive optimization algorithms. Their main feature is that each parameter has its own learning rate, which is adapted automatically throughout training, so that convergence is better.
This article uses the adaptive optimization algorithm RMSProp.
For reasons of space we will not go into detail here; a single article would not be enough to cover this optimization algorithm. To understand RMSProp, you first need to know AdaGrad, because RMSProp is an improvement on AdaGrad.
There are also more advanced optimization algorithms than RMSProp, such as Adam, which can be seen as Momentum + RMSProp with corrections.
In short, beginners only need to know that RMSProp is an adaptive optimization algorithm and a relatively advanced one.
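The phrase "each parameter has its own learning rate" can be made concrete with a sketch of the textbook RMSProp update, without the momentum and weight decay options used in the training code. The function name and the hyperparameter values (lr, alpha, eps) are assumptions for this illustration.

```python
import math


def rmsprop_step(params, grads, sq_avg, lr=0.01, alpha=0.99, eps=1e-8):
    """One textbook RMSProp update over lists of parameters."""
    new_params, new_sq_avg = [], []
    for p, g, s in zip(params, grads, sq_avg):
        # Running average of the squared gradient, kept per parameter
        s = alpha * s + (1 - alpha) * g * g
        # Dividing by its square root rescales the step for each parameter
        new_params.append(p - lr * g / (math.sqrt(s) + eps))
        new_sq_avg.append(s)
    return new_params, new_sq_avg


# Two parameters whose gradients differ by four orders of magnitude
params, sq_avg = [1.0, 1.0], [0.0, 0.0]
params, sq_avg = rmsprop_step(params, grads=[100.0, 0.01], sq_avg=sq_avg)
print(params)  # both parameters moved by roughly the same amount (about 0.1)
```

With plain gradient descent the first parameter would move 10,000 times further than the second; here the normalization gives both a step of similar size.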
Next, we can write the code for training UNet. Create train.py and write the following code:
import torch
import torch.nn as nn
from torch import optim

from model.unet_model import UNet
from utils.dataset import ISBI_Loader


def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
    # Load the training set
    isbi_dataset = ISBI_Loader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    # Define the RMSprop optimizer
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # Define the loss
    criterion = nn.BCEWithLogitsLoss()
    # Best loss so far, initialized to positive infinity
    best_loss = float('inf')
    # Train for the given number of epochs
    for epoch in range(epochs):
        # Training mode
        net.train()
        # Iterate over the training set batch by batch
        for image, label in train_loader:
            optimizer.zero_grad()
            # Copy the data to the device
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            # Forward pass with the current network parameters
            pred = net(image)
            # Compute the loss
            loss = criterion(pred, label)
            print('Loss/train', loss.item())
            # Save the parameters with the lowest loss value so far
            # (compare plain floats so no computation graph is kept alive)
            if loss.item() < best_loss:
                best_loss = loss.item()
                torch.save(net.state_dict(), 'best_model.pth')
            # Update the parameters
            loss.backward()
            optimizer.step()


if __name__ == "__main__":
    # Select the device: cuda if available, otherwise cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the network: single-channel input images, 1 output class
    net = UNet(n_channels=1, n_classes=1)
    # Copy the network to the device
    net.to(device=device)
    # Specify the training set path and start training
    data_path = "data/train/"
    train_net(net, device, data_path)
To keep the project clear and concise, we create a model folder for the model code, that is, our network structure code, unet_parts.py and unet_model.py.
We also create a utils folder for tool code, such as the data loading tool dataset.py.
This modular layout makes the code much easier to maintain.
train.py can be placed in the root directory of the project. A brief explanation of the code follows.
Since there are only 30 images, we do not split the data into a training set and a validation set. We save the network parameters with the lowest loss value on the training set as the best model parameters.
If there are no problems, you can see that loss is gradually converging.
4
Prediction
After the model is trained, we can use it to see the effect on the test set.
Create a predict.py file in the project root directory and write the following code:
import glob
import os

import cv2
import numpy as np
import torch

from model.unet_model import UNet

if __name__ == "__main__":
    # Select the device: cuda if available, otherwise cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the network: single-channel input images, 1 output class
    net = UNet(n_channels=1, n_classes=1)
    # Copy the network to the device
    net.to(device=device)
    # Load the model parameters
    net.load_state_dict(torch.load('best_model.pth', map_location=device))
    # Evaluation mode
    net.eval()
    # Read all image paths
    tests_path = glob.glob('data/test/*.png')
    # Traverse all images
    for test_path in tests_path:
        # Path for the saved result
        save_res_path = os.path.splitext(test_path)[0] + '_res.png'
        # Read the image
        img = cv2.imread(test_path)
        # Convert to grayscale (cv2.imread returns BGR, as in dataset.py)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # Reshape to an array with batch = 1, channel = 1 and size 512 * 512
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        # Convert to a tensor
        img_tensor = torch.from_numpy(img)
        # Copy the tensor to the device (cpu or cuda)
        img_tensor = img_tensor.to(device=device, dtype=torch.float32)
        # Predict. The network outputs logits, so apply the Sigmoid
        # to get probabilities before thresholding at 0.5
        with torch.no_grad():
            pred = torch.sigmoid(net(img_tensor))
        # Extract the result
        pred = np.array(pred.data.cpu()[0])[0]
        # Binarize the result
        pred[pred >= 0.5] = 255
        pred[pred < 0.5] = 0
        # Save the image
        cv2.imwrite(save_res_path, pred.astype(np.uint8))
After running, you can see the prediction results in the data/test directory:
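One detail about the thresholding step is worth a small sketch. A network trained with BCEWithLogitsLoss outputs raw logits, so a probability threshold of 0.5 corresponds to a logit threshold of 0, not 0.5. The helper names below are made up for this illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def to_mask_value(logit, threshold=0.5):
    """Map one raw network output to a mask pixel (255 or 0)."""
    # Convert the logit to a probability first, then threshold it
    return 255 if sigmoid(logit) >= threshold else 0


# A logit of 0.3 is a probability above 0.5, although 0.3 itself is below 0.5
print(to_mask_value(0.3))    # 255
print(to_mask_value(-0.3))   # 0
```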
Done!
5
Summary
This article mainly explains the three steps of training a model: data loading, model selection and algorithm selection.
This is a simple example. Training real vision tasks is much more involved. For example, when training a model you need to choose which checkpoint to save based on the model's accuracy on the validation set, and you will want TensorBoard support to watch how the loss converges, and so on.
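As a first step toward validation-based model selection, the sample indices can be split into a training part and a validation part. This is a minimal sketch with assumed names and an assumed 80/20 split; this article's code does not do this, since it trains on all 30 images.

```python
import random


def split_indices(n_samples, val_fraction=0.2, seed=0):
    """Return (train_indices, val_indices) from a seeded random split."""
    indices = list(range(n_samples))
    # A fixed seed keeps the split identical between runs
    random.Random(seed).shuffle(indices)
    n_val = max(1, int(n_samples * val_fraction))
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(30)
print(len(train_idx), len(val_idx))  # 24 6
```

The model would then be trained only on the training indices, and the checkpoint with the best score on the validation indices would be kept.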