pytorch training model



This article belongs to a series of tutorials on semantic segmentation for pytoch in-depth learning.

The contents of this series of articles are:

  • Basic use of pytoch

  • Explanation of semantic segmentation algorithm

Since wechat does not allow external links, you need to click "read the original text" in the lower left corner of the footer to access the links in the text. All external links in the text have been marked in blue font.


Project background

Deep learning algorithm is nothing more than our method to solve a problem. What kind of network to train, what kind of preprocessing to carry out, and what Loss and optimization methods to adopt are all determined according to the specific task.

So let's take a look at today's task first.

Yes, it is the classic task in UNet paper: medical image segmentation.

Choose it as today's task because it is simple and easy to use.

Briefly describe this task: as shown in the dynamic diagram, give a cell structure diagram, and we will separate each cell from each other.

There are only 30 pieces of training data with a resolution of 512x512. These pictures are electron micrographs of fruit flies.

All right, after the task introduction, start to prepare the training model.


UNet training

To train a deep learning model, you can simply divide it into three steps:

  • Data loading: how to load data, how to define labels, and what data enhancement methods to use are all carried out in this step.

  • Model selection: we have prepared the model, which is the UNet network mentioned in the last article of this series.

  • Algorithm selection: algorithm selection is what loss we choose and what optimization algorithm we use.

Each step is relatively general. We will explain it in combination with today's medical image segmentation task.

1. Data loading

In this step, many things can be done. To put it bluntly, it is nothing more than how to load the picture and how to define the label. In order to increase the robustness of the algorithm or increase the data set, some data enhancement operations can be done.

Since we are dealing with data, let's first look at what the data is like, and then decide how to deal with it.

The data is ready and here it is: (click the original link to learn more)

The data is divided into training set and test set, with 30 pieces each. The training set has labels and the test set has no labels.

The processing of data loading depends on the task and data set. For our segmentation task, we don't need to do much processing. However, due to the small amount of data, only 30 pieces, we can use some data enhancement methods to expand our data set.

Pytorch provides us with a way to load data. We can use this framework to load our data. Look at the pseudo code:

  1. # ================================================================== #
  2. # Input pipeline for custom dataset #
  3. # ================================================================== #
  4. # You should build your custom dataset as below.
  5. class CustomDataset(
  6. def __init__(self):
  7. # TODO
  8. # 1. Initialize file paths or a list of file names.
  9. pass
  10. def __getitem__(self, index):
  11. # TODO
  12. # 1. Read one data from file (e.g. using numpy.fromfile,
  13. # 2. Preprocess the data (e.g. torchvision.Transform).
  14. # 3. Return a data pair (e.g. image and label).
  15. pass
  16. def __len__(self):
  17. # You should change 0 to the total size of your dataset.
  18. return 0
  19. # You can then use the prebuilt data loader.
  20. custom_dataset = CustomDataset()
  21. train_loader =,
  22. batch_size= 64,
  23. shuffle= True)

This is a standard template. We use this template to load data, define labels, and enhance data.

Create a dataset Py file. The code is as follows:

  1. import torch
  2. import cv2
  3. import os
  4. import glob
  5. from import Dataset
  6. import random
  7. class ISBI_Loader(Dataset):
  8. def __init__(self, data_path):
  9. # Initialize the function and read all data_ Picture under path
  10. self.data_path = data_path
  11. self.imgs_path = glob.glob(os.path.join(data_path, 'image/*.png'))
  12. def augment(self, image, flipCode):
  13. # Use CV2 Flip for data enhancement. filpCode is 1 horizontal flip, 0 vertical flip, - 1 horizontal + vertical flip
  14. flip = cv2.flip(image, flipCode)
  15. return flip
  16. def __getitem__(self, index):
  17. # Read pictures according to index
  18. image_path = self.imgs_path[index]
  19. # According to image_ Generate label from path_ path
  20. label_path = image_path.replace( 'image', 'label')
  21. # Read training pictures and label pictures
  22. image = cv2.imread(image_path)
  23. label = cv2.imread(label_path)
  24. # Convert data to a single channel picture
  25. image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  26. label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
  27. image = image.reshape( 1, image.shape[ 0], image.shape[ 1])
  28. label = label.reshape( 1, label.shape[ 0], label.shape[ 1])
  29. # Process the label and change the pixel value of 255 to 1
  30. if label.max() > 1:
  31. label = label / 255
  32. # Random data enhancement, no processing when it is 2
  33. flipCode = random.choice([ -1, 0, 1, 2])
  34. if flipCode != 2:
  35. image = self.augment(image, flipCode)
  36. label = self.augment(label, flipCode)
  37. return image, label
  38. def __len__(self):
  39. # Returns the size of the training set
  40. return len(self.imgs_path)
  41. if __name__ == "__main__":
  42. isbi_dataset = ISBI_Loader( "data/train/")
  43. print( "Number of data:", len(isbi_dataset))
  44. train_loader =,
  45. batch_size= 2,
  46. shuffle= True)
  47. for image, label in train_loader:
  48. print(image.shape)

Run the code and you can see the following results:

Explain the code:

__ init__ Function is the initialization function of this class. It reads all picture data according to the specified picture path and stores it in self imgs_ Path list.

__ len__ How much data can the function return? After instantiation of this class, it is called through len() function.

__ getitem__ Function is a data acquisition function. In this function, you can write, read and process data, and some data preprocessing and data enhancement can be carried out here. My processing here is very simple. I just read the picture and process it into a single channel picture. At the same time, because the picture pixels of label are 0 and 255, it needs to be divided by 255 to become 0 and 1. At the same time, random data enhancement is carried out.

The augment function is a defined data enhancement function, which can be processed in any way. I just performed a simple rotation operation here.

In this class, you don't have to do some operations that disrupt the data set, and you don't have to worry about how to read the data according to the batch size. Because after instantiating this class, we can use torch utils. data. The dataloader method specifies the size of the batchsize and determines whether to disrupt the data.

The DataLoader provided by Pytorch is very powerful. We can even specify how many processes are used to load data and whether the data is loaded into CUDA memory. It is not covered in this article, so we won't explain it.

2. Model selection

The model we have chosen is # UNet network structure.

However, we need to fine tune the network. According to the structure of the paper, the size of the model output will be slightly smaller than the size of the image input. If we use the network structure of the paper, we need to do a resize operation after the result output. In order to save this step, we can modify the network so that the output size of the network is exactly equal to the input size of the picture.

Create file, write the following code:

  1. """ Parts of the U-Net model """
  2. """"""
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. class DoubleConv(nn.Module):
  7. """(convolution => [BN] => ReLU) * 2"""
  8. def __init__(self, in_channels, out_channels):
  9. super().__init__()
  10. self.double_conv = nn.Sequential(
  11. nn.Conv2d(in_channels, out_channels, kernel_size= 3, padding= 1),
  12. nn.BatchNorm2d(out_channels),
  13. nn.ReLU(inplace= True),
  14. nn.Conv2d(out_channels, out_channels, kernel_size= 3, padding= 1),
  15. nn.BatchNorm2d(out_channels),
  16. nn.ReLU(inplace= True)
  17. )
  18. def forward(self, x):
  19. return self.double_conv(x)
  20. class Down(nn.Module):
  21. """Downscaling with maxpool then double conv"""
  22. def __init__(self, in_channels, out_channels):
  23. super().__init__()
  24. self.maxpool_conv = nn.Sequential(
  25. nn.MaxPool2d( 2),
  26. DoubleConv(in_channels, out_channels)
  27. )
  28. def forward(self, x):
  29. return self.maxpool_conv(x)
  30. class Up(nn.Module):
  31. """Upscaling then double conv"""
  32. def __init__(self, in_channels, out_channels, bilinear=True):
  33. super().__init__()
  34. # if bilinear, use the normal convolutions to reduce the number of channels
  35. if bilinear:
  36. self.up = nn.Upsample(scale_factor= 2, mode= 'bilinear', align_corners= True)
  37. else:
  38. self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size= 2, stride= 2)
  39. self.conv = DoubleConv(in_channels, out_channels)
  40. def forward(self, x1, x2):
  41. x1 = self.up(x1)
  42. # input is CHW
  43. diffY = torch.tensor([x2.size()[ 2] - x1.size()[ 2]])
  44. diffX = torch.tensor([x2.size()[ 3] - x1.size()[ 3]])
  45. x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
  46. diffY // 2, diffY - diffY // 2])
  47. x =[x2, x1], dim= 1)
  48. return self.conv(x)
  49. class OutConv(nn.Module):
  50. def __init__(self, in_channels, out_channels):
  51. super(OutConv, self).__init__()
  52. self.conv = nn.Conv2d(in_channels, out_channels, kernel_size= 1)
  53. def forward(self, x):
  54. return self.conv(x)

Create file, write the following code:

  1. "" " Full assembly of the parts to form the complete network " ""
  2. "" "Refer" ""
  3. import torch.nn.functional as F
  4. from .unet_parts import *
  5. class UNet(nn.Module):
  6. def __init__(self, n_channels, n_classes, bilinear=True):
  7. super(UNet, self).__init_ _()
  8. self.n_channels = n_channels
  9. self.n_classes = n_classes
  10. self.bilinear = bilinear
  11. = DoubleConv(n_channels, 64)
  12. self.down1 = Down( 64, 128)
  13. self.down2 = Down( 128, 256)
  14. self.down3 = Down( 256, 512)
  15. self.down4 = Down( 512, 512)
  16. self.up1 = Up( 1024, 256, bilinear)
  17. self.up2 = Up( 512, 128, bilinear)
  18. self.up3 = Up( 256, 64, bilinear)
  19. self.up4 = Up( 128, 64, bilinear)
  20. self.outc = OutConv( 64, n_classes)
  21. def forward(self, x):
  22. x1 =
  23. x2 = self.down1(x1)
  24. x3 = self.down2(x2)
  25. x4 = self.down3(x3)
  26. x5 = self.down4(x4)
  27. x = self.up1(x5, x4)
  28. x = self.up2(x, x3)
  29. x = self.up3(x, x2)
  30. x = self.up4(x, x1)
  31. logits = self.outc(x)
  32. return logits
  33. if __name_ _ == '__main__':
  34. net = UNet(n_channels= 3, n_classes= 1)
  35. print(net)

After this adjustment, the output size of the network is the same as the input size of the picture.

3. Algorithm selection

It is very important to choose what Loss is. The quality of Loss selection will affect the effect of the algorithm fitting data.

What Loss to choose is also determined according to the task. Our task today only needs to segment the cell edge, which is a very simple binary classification task, so we can use BCEWithLogitsLoss.

What is BCEWithLogitsLoss? BCEWithLogitsLoss is a function provided by Pytorch to calculate the cross entropy of binary classification.

Its formula is:

Friends who have read my machine learning series tutorials must be familiar with this formula. It is the loss function of Logistic regression. It uses the Sigmoid function threshold in [0,1] to classify.

When the objective function, namely Loss, is determined, how to optimize this objective?

The simplest way is to gradually approach the local extreme value with the familiar gradient descent algorithm.

However, this simple optimization algorithm is slow to solve, that is, it is hard to find the optimal solution.

Various optimization algorithms are gradient descent in essence. For example, the most conventional SGD is the improved random gradient descent algorithm based on gradient descent, and Momentum is the SGD with Momentum introduced to accumulate the historical gradient in the form of exponential attenuation.

In addition to these most basic optimization algorithms, there are also adaptive parameter optimization algorithms. The biggest feature of this kind of algorithm is that each parameter has different learning rates, which can automatically adapt to these learning rates in the whole learning process, so as to achieve better convergence effect.

This paper chooses an adaptive optimization algorithm RMSProp.

Due to the limited space, we will not expand here. It is not enough to write a single article on this optimization algorithm. To understand RMSProp, you must first know what is AdaGrad, because RMSProp is based on the improvement of AdaGrad.

There are also more advanced optimization algorithms than RMSProp, such as Adam, which can be regarded as the modified Momentum+RMSProp algorithm.

In short, for beginners, you only need to know that RMSProp is an adaptive optimization algorithm, which is relatively advanced.

Next, we can start to write the code for training UNet and create train Py write the following code:

  1. from model.unet_model import UNet
  2. from utils.dataset import ISBI_Loader
  3. from torch import optim
  4. import torch.nn as nn
  5. import torch
  6. def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
  7. # Load training set
  8. isbi_dataset = ISBI_Loader(data_path)
  9. train_loader =,
  10. batch_size=batch_size,
  11. shuffle= True)
  12. # Define RMSprop algorithm
  13. optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay= 1e-8, momentum= 0.9)
  14. # Define Loss algorithm
  15. criterion = nn.BCEWithLogitsLoss()
  16. # best_loss statistics, initialized to positive infinity
  17. best_loss = float( 'inf')
  18. # epochs training
  19. for epoch in range(epochs):
  20. # Training mode
  21. net.train()
  22. # According to batch_ Start training
  23. for image, label in train_loader:
  24. optimizer.zero_grad()
  25. # Copy data to device
  26. image =, dtype=torch.float32)
  27. label =, dtype=torch.float32)
  28. # Use the network parameters to output the prediction results
  29. pred = net(image)
  30. # Calculate loss
  31. loss = criterion(pred, label)
  32. print( 'Loss/train', loss.item())
  33. # Save the network parameters with the lowest loss value
  34. if loss < best_loss:
  35. best_loss = loss
  36., 'best_model.pth')
  37. # Update parameters
  38. loss.backward()
  39. optimizer.step()
  40. if __name__ == "__main__":
  41. # Select the device, cuda with cuda, cpu without cuda
  42. device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu')
  43. # Load network, picture, single channel 1, classified as 1.
  44. net = UNet(n_channels= 1, n_classes= 1)
  45. # Copy the network to deivce
  47. # Specify the training set address to start training
  48. data_path = "data/train/"
  49. train_net(net, device, data_path)

In order to make the project more clear and concise, we create a model folder and put the code related to the model, that is, our network structure code, and

Create a utils folder and put the code related to the tool, such as the data loading tool dataset py.

This modular management greatly improves the maintainability of the code. can be placed in the root directory of the project. Simply explain the code.

Since there are only 30 pieces of data, we do not distinguish between the training set and the verification set. We save the network parameter with the lowest loss value in the training set as the best model parameter.

If there are no problems, you can see that loss is gradually converging.



After the model is trained, we can use it to see the effect on the test set.

Create predict. In the project root directory Py file, write the following code:

  1. import glob
  2. import numpy as np
  3. import torch
  4. import os
  5. import cv2
  6. from model.unet_model import UNet
  7. if __name__ == "__main__":
  8. # Select the device, cuda with cuda, cpu without cuda
  9. device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu')
  10. # Load network, picture, single channel, classified as 1.
  11. net = UNet(n_channels= 1, n_classes= 1)
  12. # Copy the network to deivce
  14. # Load model parameters
  15. net.load_state_dict(torch.load( 'best_model.pth', map_location=device))
  16. # Test mode
  17. net.eval()
  18. # Read all picture paths
  19. tests_path = glob.glob( 'data/test/*.png')
  20.      #Traverse all pictures
  21. for test_path in tests_path:
  22. # Save result address
  23. save_res_path = test_path.split( '.')[ 0] + '_res.png'
  24. # Read picture
  25. img = cv2.imread(test_path)
  26. # Convert to grayscale
  27. img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  28. # Convert to an array with batch = 1, channel = 1 and size of 512 * 512
  29. img = img.reshape( 1, 1, img.shape[ 0], img.shape[ 1])
  30. # Turn to tensor
  31. img_tensor = torch.from_numpy(img)
  32. # Copy the tensor to the device. Only using cpu means copying to cpu, and cuda means copying to cuda.
  33. img_tensor =, dtype=torch.float32)
  34. # forecast
  35. pred = net(img_tensor)
  36. # Extraction results
  37. pred = np.array([ 0])[ 0]
  38. # Processing results
  39. pred[pred >= 0.5] = 255
  40. pred[pred < 0.5] = 0
  41. # Save picture
  42. cv2.imwrite(save_res_path, pred)

After running, you can see the prediction results in the data/test directory:

be accomplished!



  • This paper mainly explains the three steps of training model: data loading, model selection and algorithm selection.

  • This is a simple example. Training normal visual tasks is much more complex. For example, when training a model, you need to choose which model to save according to the accuracy of the model in the verification set; We need to support tensorboard to facilitate us to observe the convergence of loss and so on.

PS: if you think this chapter is helpful to you, please pay attention, comment and praise!

More wonderful content (please click the picture to read)

Official account: AI snail car

Maintain humility, self-discipline and progress

Personal wechat

Note: nickname + school / company + direction

If there is no note, don't pull the group!

Pull you into the AI snail car communication group

Order one to watch,kiss you!


Tags: Deep Learning

Posted by fireant on Sun, 08 May 2022 08:05:32 +0300