Image Classification Dataset (Fashion-MNIST)
The most commonly used image classification dataset is MNIST, a dataset of handwritten digits. However, most models already exceed 95% classification accuracy on MNIST. To see the differences between algorithms more clearly, we will use Fashion-MNIST, whose images are more complex. It is also relatively small, only a few tens of megabytes, so computers without a GPU can handle it.
In this section we will use the torchvision package, which accompanies the PyTorch deep learning framework and is mainly used to build computer vision models. torchvision consists mainly of the following parts:
torchvision.datasets: functions for loading data and interfaces to common datasets;
torchvision.models: common model architectures (including pretrained models), such as AlexNet, VGG, and ResNet;
torchvision.transforms: common image transformations, such as cropping and rotation;
torchvision.utils: other useful utilities.
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..")  # to import d2lzh_pytorch from the parent directory
import d2lzh_pytorch as d2l
Next, we download the dataset through torchvision.datasets. The first call automatically fetches the data from the Internet. The parameter train selects the training set or the test set. The test set is used only to evaluate the model's performance, never to train the model.
We also pass transform=transforms.ToTensor() so that every sample is converted to a Tensor; without it, each image is returned as a PIL image. transforms.ToTensor() converts a PIL image of size (H x W x C) with values in [0, 255], or a NumPy array of data type np.uint8, into a torch.FloatTensor (float32) of size (C x H x W) with values in [0.0, 1.0].
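The conversion can be illustrated without PyTorch. Below is a minimal pure-Python mimic of what ToTensor() does to uint8 pixel data, assuming images are stored as nested lists. The name to_tensor_like is hypothetical and this is not torchvision's implementation; it only divides each value by 255 and moves the channel axis from last to first.

```python
def to_tensor_like(image_hwc):
    """Hypothetical pure-Python mimic of ToTensor() on uint8-style data.

    image_hwc: nested lists indexed [h][w][c], integer values in 0..255.
    Returns nested lists indexed [c][h][w], float values in [0.0, 1.0].
    """
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    # Outer loop over channels moves C to the front: (H x W x C) -> (C x H x W).
    return [
        [[image_hwc[h][w][c] / 255.0 for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


# A tiny 2 x 2 RGB "image" in (H x W x C) layout.
img = [
    [[0, 128, 255], [255, 0, 0]],
    [[0, 255, 0], [51, 51, 51]],
]
t = to_tensor_like(img)
print(len(t), len(t[0]), len(t[0][0]))  # 3 2 2, i.e. (C x H x W)
print(t[0])  # first channel: [[0.0, 1.0], [0.0, 0.2]]
```

For Fashion-MNIST the same idea yields a shape of (1 x 28 x 28), since each image has a single gray channel.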
Note: pixel values are integers from 0 to 255, exactly the range that uint8 can represent. Some image functions, including transforms.ToTensor(), expect uint8 input by default. If the input is not uint8, they may not raise an error, but you may not get the result you expect. So if you represent image data by pixel values (integers from 0 to 255), set its data type to uint8 to avoid unnecessary bugs.
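A minimal sketch of why the data type matters, assuming (as torchvision's ToTensor does) that rescaling to [0, 1] is applied only when the input is uint8. The helper scale_like_to_tensor is hypothetical, and the data type is passed as a string label rather than read from a real array.

```python
def scale_like_to_tensor(pixels, dtype):
    """Hypothetical mimic of the dtype rule: rescale only uint8 input.

    pixels: flat list of pixel values.
    dtype: string label standing in for the array's element type.
    """
    if dtype == "uint8":
        # uint8 input is treated as 0..255 and mapped into [0.0, 1.0].
        return [p / 255.0 for p in pixels]
    # Any other dtype passes through unchanged: no error, but no rescaling.
    return [float(p) for p in pixels]


raw = [0, 51, 255]
print(scale_like_to_tensor(raw, "uint8"))    # [0.0, 0.2, 1.0]
print(scale_like_to_tensor(raw, "float32"))  # [0.0, 51.0, 255.0]
```

Both calls succeed, which is exactly the danger: the second result still holds values up to 255, and nothing warns you.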
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())
mnist_train and mnist_test above are instances of subclasses of torch.utils.data.Dataset, so we can use len() to get the size of each dataset and indexing to access a specific sample. The training set and the test set contain 6,000 and 1,000 images per category, respectively. Because there are 10 categories, the training set and the test set contain 60,000 and 10,000 samples, respectively.
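The len() and indexing behavior comes from the map-style dataset protocol: a class that defines __len__ and __getitem__. The hypothetical ToyDataset below is a plain-Python sketch of that protocol; it does not subclass torch.utils.data.Dataset, so it runs without PyTorch, and its "images" are placeholder strings. It reproduces the 10 x 6,000 and 10 x 1,000 sample counts.

```python
class ToyDataset:
    """Hypothetical stand-in for a torch.utils.data.Dataset subclass.

    Defining __len__ and __getitem__ is what makes len(ds) and ds[i] work.
    """

    def __init__(self, num_classes, per_class):
        # Store (feature, label) pairs; each feature is a placeholder string.
        self.samples = [
            (f"image-{label}-{k}", label)
            for label in range(num_classes)
            for k in range(per_class)
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]


train = ToyDataset(num_classes=10, per_class=6000)
test = ToyDataset(num_classes=10, per_class=1000)
print(len(train), len(test))  # 60000 10000
feature, label = train[0]
print(feature, label)  # image-0-0 0
```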
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))
feature, label = mnist_train[0]  # index the dataset to get the first sample
print(feature.shape, label)  # Channel x Height x Width
The variable feature holds an image whose height and width are both 28 pixels. Because we used transforms.ToTensor(), each pixel value is a 32-bit floating-point number in [0.0, 1.0]. Note that the shape of feature is (C x H x W), not (H x W x C). The first dimension is the number of channels, which is 1 because the images are grayscale. The last two dimensions are the image's height and width.
Fashion-MNIST has 10 categories: t-shirt, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag, and ankle boot. The following function converts numeric labels into the corresponding text labels.
# This function is saved in the d2lzh_pytorch package for later use
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
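A quick check of the label function on plain Python integers, so it runs without PyTorch. Because of int(i), the same function should also accept single-element tensors, such as the elements of a 1-D label tensor.

```python
def get_fashion_mnist_labels(labels):
    # Position i in this list is the text name of numeric class i.
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    # int() converts each label (int or single-element tensor) to a list index.
    return [text_labels[int(i)] for i in labels]


print(get_fashion_mnist_labels([0, 9, 5, 2]))
# ['t-shirt', 'ankle boot', 'sandal', 'pullover']
```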
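X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])  # image tensor of the i-th training sample
    y.append(mnist_train[i][1])  # numeric label of the i-th training sample
# Display the first 10 training images with their text labels
# (show_fashion_mnist is a plotting helper not defined in this section)
show_fashion_mnist(X, get_fashion_mnist_labels(y))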