Using TPU to implement ResNet50 in PyTorch

Author | Dr. Vaibhav Kumar  Compiled by | VK  Source | Analytics India Magazine

PyTorch has advanced computer vision and deep learning by providing a large set of powerful tools and techniques.

Deep learning for computer vision has to process large image data sets, so an accelerated environment is needed to reach an acceptable level of accuracy in a reasonable amount of time.

PyTorch supports this through XLA (Accelerated Linear Algebra), a linear algebra compiler that can target several types of hardware, including GPUs and TPUs. The PyTorch/XLA environment integrates with Google Cloud TPUs to achieve faster execution.

In this article, we use a TPU from PyTorch to implement ResNet50, a deep convolutional neural network.

The model is trained and tested in the PyTorch/XLA environment on the CIFAR10 classification task. We also measure the time taken to train for 50 epochs.

Implementation of ResNet50 in PyTorch

To take advantage of the TPU, this implementation is carried out in Google Colab. First, select TPU as the hardware accelerator under Notebook settings.

After selecting the TPU, we verify the environment with the following lines:

import os
assert os.environ['COLAB_TPU_ADDR']

If a TPU is enabled, this executes successfully; otherwise it raises KeyError: 'COLAB_TPU_ADDR'. You can also check the TPU by printing its address.

TPU_Path = 'grpc://'+os.environ['COLAB_TPU_ADDR']
print('TPU Address:', TPU_Path)
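
If a readable message is preferred over a KeyError, the same check can be written without an assertion. The following is a minimal sketch; the helper name tpu_address and its optional env parameter are assumptions introduced here for illustration and are not part of Colab or PyTorch/XLA.

```python
import os

def tpu_address(env=None):
    """Return the gRPC address of the Colab TPU, or None if no TPU is attached."""
    env = os.environ if env is None else env
    # COLAB_TPU_ADDR is the variable checked by the assertion above.
    addr = env.get('COLAB_TPU_ADDR')
    return 'grpc://' + addr if addr else None

address = tpu_address()
if address is None:
    print('No TPU found: select TPU under Notebook settings and reconnect.')
else:
    print('TPU Address:', address)
```

Passing a dictionary as env makes the helper easy to try outside Colab, for example tpu_address({'COLAB_TPU_ADDR': '10.0.0.2:8470'}).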

In the next step, we install the XLA environment to speed up execution. We implemented a convolutional neural network in the previous article.

VERSION = "20200516"
!curl -o
!python --version $VERSION

Now we import all the necessary libraries.

from matplotlib import pyplot as plt
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torchvision
from torchvision import datasets, transforms
from google.colab.patches import cv2_imshow
import cv2

After importing the libraries, we define and initialize the required parameters.

# Define parameters
FLAGS = {}
FLAGS['data_dir'] = "/tmp/cifar"
FLAGS['batch_size'] = 128
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 0.02
FLAGS['momentum'] = 0.9
FLAGS['num_epochs'] = 50
FLAGS['num_cores'] = 8
FLAGS['log_steps'] = 20
FLAGS['metrics_debug'] = False
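
To make these settings concrete, the sketch below derives a few quantities from them. The training code later in this article multiplies the learning rate by the number of participating processes (xm.xrt_world_size()), which is 8 when all cores are used. The step counts additionally assume that the 50,000 CIFAR10 training images are split evenly across the cores and that an incomplete final batch is dropped; those two points are assumptions of this example.

```python
# Values copied from FLAGS above.
batch_size = 128   # batch size per TPU core
num_cores = 8
base_lr = 0.02
log_steps = 20

TRAIN_IMAGES = 50_000  # size of the CIFAR10 training split

global_batch = batch_size * num_cores            # images consumed per step by all cores
scaled_lr = base_lr * num_cores                  # linear learning-rate scaling
images_per_core = TRAIN_IMAGES // num_cores      # share of the data seen by each core
steps_per_epoch = images_per_core // batch_size  # assumes the incomplete last batch is dropped
logs_per_epoch = len(range(0, steps_per_epoch, log_steps))  # steps 0, 20, 40

print(global_batch, round(scaled_lr, 4), images_per_core,
      steps_per_epoch, logs_per_epoch)
# prints: 1024 0.16 6250 48 3
```

Scaling the learning rate with the number of cores is a common heuristic: the effective batch is eight times larger, so each update averages over more samples and can take a proportionally larger step.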

In the next step, we will define the ResNet50 model.

class BasicBlock(nn.Module):
  expansion = 1

  def __init__(self, in_planes, planes, stride=1):
    super(BasicBlock, self).__init__()
    self.conv1 = nn.Conv2d(
        in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes)
    self.conv2 = nn.Conv2d(
        planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(planes)

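    # Identity shortcut by default; use a 1x1 convolution when the shape changes.
    self.shortcut = nn.Sequential()
    if stride != 1 or in_planes != self.expansion * planes:
      self.shortcut = nn.Sequential(
          nn.Conv2d(
              in_planes,
              self.expansion * planes,
              kernel_size=1,
              stride=stride,
              bias=False), nn.BatchNorm2d(self.expansion * planes))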

  def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    out += self.shortcut(x)
    out = F.relu(out)
    return out

class ResNet(nn.Module):

  def __init__(self, block, num_blocks, num_classes=10):
    super(ResNet, self).__init__()
    self.in_planes = 64

    self.conv1 = nn.Conv2d(
        3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
    self.linear = nn.Linear(512 * block.expansion, num_classes)

  def _make_layer(self, block, planes, num_blocks, stride):
    strides = [stride] + [1] * (num_blocks - 1)
    layers = []
    for stride in strides:
      layers.append(block(self.in_planes, planes, stride))
      self.in_planes = planes * block.expansion
    return nn.Sequential(*layers)

  def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.layer4(out)
    out = F.avg_pool2d(out, 4)
    out = torch.flatten(out, 1)
    out = self.linear(out)
    return F.log_softmax(out, dim=1)

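def ResNet50():
  # The network has four stages, so four block counts are needed.
  # Note: a standard ResNet50 is built from three-layer bottleneck blocks,
  # whereas BasicBlock has two convolutional layers.
  return ResNet(BasicBlock, [3, 4, 6, 3])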
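
Two details of the model are easier to see with a little arithmetic: why the forward pass pools with a window of 4, and how the number in a ResNet name relates to the block counts. The sketch below is plain Python and does not use PyTorch; the helper names conv_out, feature_map_sizes and depth are made up for this illustration. It assumes the usual convention of counting the stem convolution, the convolutions inside the blocks and the final linear layer, and the standard block counts [3, 4, 6, 3].

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Output width (or height) of a square convolution."""
    return (size + 2 * padding - kernel) // stride + 1

def feature_map_sizes(input_size=32, stage_strides=(1, 2, 2, 2)):
    """Spatial size after the stem convolution and after each of the four stages."""
    size = conv_out(input_size, stride=1)     # stem: 3x3 convolution, stride 1
    sizes = [size]
    for stride in stage_strides:
        # Only the first block of a stage uses the stage stride;
        # the remaining blocks keep the size unchanged.
        size = conv_out(size, stride=stride)
        sizes.append(size)
    return sizes

def depth(num_blocks, convs_per_block):
    """Stem convolution + convolutions in all blocks + final linear layer."""
    return 1 + convs_per_block * sum(num_blocks) + 1

print(feature_map_sizes())      # [32, 32, 16, 8, 4]
print(depth([3, 4, 6, 3], 2))   # 34, with two-convolution basic blocks
print(depth([3, 4, 6, 3], 3))   # 50, with three-convolution bottleneck blocks
```

A 32x32 CIFAR10 image therefore reaches the last stage as a 4x4 feature map, which is why F.avg_pool2d(out, 4) reduces it to a single value per channel before the linear layer. The depth count also shows that the name ResNet50 conventionally refers to bottleneck blocks, while the same block counts with BasicBlock give a 34-layer network.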

The following code defines the functions for loading the CIFAR10 data set, preparing the training and test loaders, and running the training and test loops.

SERIAL_EXEC = xmp.MpSerialExecutor()
# Model weights are instantiated only once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(ResNet50())

def train_resnet50():

  def get_dataset():
    norm = transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
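    # Training images are randomly cropped; both splits are converted
    # to tensors and normalized.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        norm,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        norm,
    ])
    train_dataset = datasets.CIFAR10(
        root=FLAGS['data_dir'], train=True, download=True,
        transform=transform_train)
    test_dataset = datasets.CIFAR10(
        root=FLAGS['data_dir'], train=False, download=True,
        transform=transform_test)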

    return train_dataset, test_dataset

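  # Run the download through the serial executor so that the processes
  # do not all download the same data at once.
  train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

  # Give each TPU core its own shard of the training data.
  train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True)
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=FLAGS['batch_size'],
      sampler=train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True)
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=FLAGS['batch_size'],
      shuffle=False,
      num_workers=FLAGS['num_workers'],
      drop_last=True)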

  # Scale learning rate
  learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()

  # Get loss function, optimizer, and model
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)
  optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                        momentum=FLAGS['momentum'], weight_decay=5e-4)
  loss_fn = nn.NLLLoss()

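  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      # Reduce the gradients across the TPU cores, then update the weights.
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS['batch_size'])
      if x % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.2f} Time={}'.format(xm.get_ordinal(), x, loss.item(), time.asctime()), flush=True)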

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    # Switch batch normalization to inference mode for evaluation.
    model.eval()
    data, pred, target = None, None, None
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    print('[xla:{}] Accuracy={:.2f}%'.format(
        xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target

  # Cycle of training and evaluation
  accuracy = 0.0
  data, pred, target = None, None, None
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    xm.master_print("Finished training epoch {}".format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy, data, pred, target  = test_loop_fn(para_loader.per_device_loader(device))
    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)

  return accuracy, data, pred, target
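
The model returns log-probabilities (F.log_softmax) and is trained with nn.NLLLoss; together the two compute the usual cross-entropy loss, and the test loop takes the class with the highest score as the prediction. The sketch below reproduces both calculations in plain Python on made-up scores, without PyTorch. The function names and the example numbers are illustrative assumptions only.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax of one row of scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]

def nll_loss(log_probs, targets):
    """Mean negative log-likelihood (the default 'mean' reduction of nn.NLLLoss)."""
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)

def accuracy(log_probs, targets):
    """Percentage of rows whose highest-scoring class equals the target."""
    preds = [row.index(max(row)) for row in log_probs]
    correct = sum(int(p == t) for p, t in zip(preds, targets))
    return 100.0 * correct / len(targets)

# Made-up scores for three samples and three classes.
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.5, 0.0]]
targets = [0, 2, 0]

log_probs = [log_softmax(row) for row in logits]
print('loss = {:.4f}'.format(nll_loss(log_probs, targets)))
print('accuracy = {:.2f}%'.format(accuracy(log_probs, targets)))  # 66.67%
```

Here the first two samples are classified correctly and the third is not, so the accuracy is two out of three.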

Now we start training ResNet50. Training runs for the 50 epochs defined in the parameters. We record the time before training starts and print the total time once training ends.

start_time = time.time()
# Start the training process
def training(rank, flags):
  global FLAGS
  FLAGS = flags
  accuracy, data, pred, target = train_resnet50()
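  if rank == 0:
    # Retrieve the tensors from TPU core 0 and plot them.
    # plot_results is a plotting helper that must be defined before this cell is run.
    plot_results(data.cpu(), pred.cpu(), target.cpu())

xmp.spawn(training, args=(FLAGS,), nprocs=FLAGS['num_cores'],
          start_method='fork')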

After the training, we print the time spent in the training process.

end_time = time.time()
print("Time taken = ", end_time-start_time)

Finally, at the end of the training process, we visualize the model's predictions on the sample test data.
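
The training function above calls a plot_results helper whose definition does not appear in the listings. The following is a minimal sketch of what such a helper might look like, not the author's code: the argument order (images, preds, targets) follows the call in the training function, while the caption helper, the max_images parameter and the figure layout are assumptions made here. It undoes the normalization applied when the data set was loaded so that the images display with natural colors, and it would need to be defined before the training cell is run.

```python
CIFAR10_CLASSES = ('airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')
# Normalization constants used when the data set was loaded.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)

def caption(pred_idx, target_idx):
    """Title shown above one image."""
    return 'pred: {} / true: {}'.format(
        CIFAR10_CLASSES[pred_idx], CIFAR10_CLASSES[target_idx])

def plot_results(images, preds, targets, max_images=8):
    """Show the first few test images with their predicted and true labels."""
    # Imported here so that caption() can be used without the plotting libraries.
    import numpy as np
    from matplotlib import pyplot as plt

    images = np.asarray(images)                # (N, 3, 32, 32), normalized
    preds = np.asarray(preds).reshape(-1)      # (N, 1) -> (N,)
    targets = np.asarray(targets).reshape(-1)
    n = min(max_images, len(images))
    fig, axes = plt.subplots(1, n, figsize=(2 * n, 2.5), squeeze=False)
    for i in range(n):
        # Undo the normalization and move the channels last for imshow.
        img = images[i].transpose(1, 2, 0) * np.array(STD) + np.array(MEAN)
        axes[0][i].imshow(np.clip(img, 0.0, 1.0))
        axes[0][i].set_title(caption(int(preds[i]), int(targets[i])), fontsize=8)
        axes[0][i].axis('off')
    plt.show()
```

For example, caption(3, 5) gives 'pred: cat / true: dog', which marks a misclassified image.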

Posted by XaeroDegreaz on Wed, 25 May 2022 23:01:11 +0300