# PyTorch implementation of the super-resolution network SRCNN

### Overall framework

SR stands for super-resolution, and CNN is the well-known convolutional neural network. As the name suggests, SRCNN is the first convolutional neural network applied to the field of super-resolution, and that is indeed the case.

Super-resolution is the process of enlarging a low-resolution (LR) image into a high-resolution (HR) one. As a pioneering work, SRCNN is relatively simple and consists of three steps:

1. The input LR image is enlarged to the target size by bicubic interpolation, giving $Y$.
2. A three-layer convolutional network fits the nonlinear mapping.
3. The network outputs the HR result $F(Y)$.

The training objective is to minimize the mean squared error, taken pixel by pixel, between the SR image $F(Y;\theta)$ and the original high-resolution image $X$:

$$L(\theta)=\frac{1}{n}\sum^n_{i=1}\Vert F(Y_i;\theta)-X_i\Vert^2$$

Here $n$ is the number of training samples. The parameter update rule is

$$\Delta_{i+1}=0.9\Delta_i-\eta\frac{\partial L}{\partial W^l_i},\quad W^l_{i+1}=W^l_i+\Delta_{i+1}$$

where $l$ indexes the layer, $i$ the iteration, and $\eta$ is the learning rate. The gradient term is subtracted so that each step moves the weights downhill on the loss.
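
To see the loss and the update rule working together, here is a minimal pure-Python sketch. It assumes a toy one-parameter "network" $F(Y;w)=wY$ and made-up data, neither of which comes from the original post, and it writes the update in the descent form, where the gradient term is subtracted from the momentum term.

```python
def loss_and_grad(w, ys, xs):
    """L(w) = (1/n) * sum_i (w * y_i - x_i)^2 and its derivative dL/dw."""
    n = len(ys)
    loss = sum((w * y - x) ** 2 for y, x in zip(ys, xs)) / n
    grad = sum(2 * (w * y - x) * y for y, x in zip(ys, xs)) / n
    return loss, grad


ys = [1.0, 2.0, 3.0]   # toy inputs (hypothetical)
xs = [2.0, 4.0, 6.0]   # toy targets, generated with w = 2
w, delta, eta = 0.0, 0.0, 0.01
losses = []
for _ in range(200):
    loss, grad = loss_and_grad(w, ys, xs)
    losses.append(loss)
    delta = 0.9 * delta - eta * grad   # momentum term minus the scaled gradient
    w = w + delta                      # W_{i+1} = W_i + Delta_{i+1}
print(round(w, 3), losses[0], losses[-1])
```

The weight settles near 2, the value used to generate the toy targets, and the loss falls from its starting value towards zero.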

### Network model

The network structure is as follows. As mentioned earlier, the network consists of three convolutional layers:

1. The first layer has dimensions $1\times9\times9\times64$: the input image has 1 channel, the convolution kernel is $9\times9$, and the output depth is 64.
2. The second layer has dimensions $64\times5\times5\times32$: 64 is the output depth of the previous layer, and 32 is the depth passed on to the next layer.
3. The third layer has dimensions $32\times5\times5\times1$. Its output is a single-channel image, the same as the input.
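
The layer dimensions above translate directly into a parameter count. The sketch below is only arithmetic; it assumes each layer also carries one bias per output channel (the default for `nn.Conv2d`), and the helper name `conv_params` is made up for illustration.

```python
# (in_channels, kernel_size, out_channels) for the three layers listed above
layers = [(1, 9, 64), (64, 5, 32), (32, 5, 1)]


def conv_params(c_in, k, c_out, bias=True):
    """Weights of a k x k convolution plus one bias per output channel."""
    return c_in * k * k * c_out + (c_out if bias else 0)


counts = [conv_params(*layer) for layer in layers]
total = sum(counts)
print(counts, total)  # prints [5248, 51232, 801] 57281
```

Almost all of the roughly 57 thousand parameters sit in the middle layer, which is why the model is so light by current standards.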

So this model is easy to implement:

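    # models.py
    from torch import nn


    class SRCNN(nn.Module):
        def __init__(self, nChannel=1):
            super(SRCNN, self).__init__()
            # The source lines are cut off after the channel counts. The kernel
            # sizes below follow the 9-5-5 layout described above; the padding
            # (kernel_size // 2) is an assumption that keeps the output the same
            # size as the input, as the same-size LR/HR training patches require.
            self.conv1 = nn.Conv2d(nChannel, 64, kernel_size=9, padding=9 // 2)
            self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
            self.conv3 = nn.Conv2d(32, nChannel, kernel_size=5, padding=5 // 2)
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            x = self.relu(self.conv1(x))  # patch extraction and representation
            x = self.relu(self.conv2(x))  # nonlinear mapping
            x = self.conv3(x)             # reconstruction, no activation
            return x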
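
The loss compares the network output with an HR patch of the same size as the input, so the three convolutions have to preserve the spatial size. The sketch below checks this with the standard output-size formula. It assumes stride 1, a padding of `kernel // 2` per layer, and the 33-pixel patch size used as the default in the data set code; the helper names are made up for illustration.

```python
def conv_out_size(size, kernel, padding=0, stride=1):
    """Spatial size after one convolution: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def srcnn_out_size(size, same_padding):
    """Push a side length through the 9-5-5 kernels of the three layers."""
    for kernel in (9, 5, 5):
        padding = kernel // 2 if same_padding else 0
        size = conv_out_size(size, kernel, padding)
    return size


padded = srcnn_out_size(33, same_padding=True)
unpadded = srcnn_out_size(33, same_padding=False)
print(padded, unpadded)  # prints 33 17
```

Without padding a 33-pixel patch would shrink to 17 pixels, so the HR labels would have to be cropped to match before computing the loss.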


### Data set

The training data set can be generated by hand, with the magnification factor stored in `scale`. Because the size of an original image may not be divisible by `scale`, the image has to be resized first, so generating the training data takes three steps:

1. Resize the original image by bicubic interpolation so that its width and height are divisible by `scale`; this is the high-resolution image data HR.
2. Shrink HR by a factor of `scale` with bicubic interpolation; this is the raw low-resolution image.
3. Enlarge the raw low-resolution image by bicubic interpolation back to the size of HR; this is the low-resolution image data LR.
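
As a worked example of the size planning in these steps, the sketch below computes the HR and raw LR sizes for one image. The 500 × 334 input size and the helper name `plan_sizes` are hypothetical.

```python
def plan_sizes(width, height, scale):
    """Return the HR size (divisible by scale) and the raw LR size."""
    lr_w, lr_h = width // scale, height // scale
    return (lr_w * scale, lr_h * scale), (lr_w, lr_h)


hr_size, lr_size = plan_sizes(500, 334, scale=3)
print(hr_size, lr_size)  # prints (498, 333) (166, 111)
```

At most `scale - 1` pixels are lost along each axis, and the LR image of step 3 ends up with exactly the HR size again.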

Finally, the images are cut into patches and packed with h5py. The generation code is:

    import glob

    import h5py
    import numpy as np
    import PIL.Image as pImg


    def rgb2gray(img):
        # Luma-style conversion for 0-255 RGB input. The coefficients are kept
        # as in the source; the standard BT.601 red coefficient is 65.738.
        return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.


    # imgPath is the image folder; h5Path is the output file; scale is the magnification
    # pSize is the patch size; pStride is the stride between patches
    def setTrianData(imgPath, h5Path, scale=3, pSize=33, pStride=14):
        h5_file = h5py.File(h5Path, 'w')
        lrPatches, hrPatches = [], []  # Low-resolution and high-resolution patches
        for p in sorted(glob.glob(f'{imgPath}/*')):
            hr = pImg.open(p).convert('RGB')
            lrWidth, lrHeight = hr.width // scale, hr.height // scale
            # width and height are the HR size, divisible by scale
            width, height = lrWidth * scale, lrHeight * scale
            hr = hr.resize((width, height), resample=pImg.BICUBIC)
            lr = hr.resize((lrWidth, lrHeight), resample=pImg.BICUBIC)
            lr = lr.resize((width, height), resample=pImg.BICUBIC)
            hr = np.array(hr).astype(np.float32)
            lr = np.array(lr).astype(np.float32)
            hr = rgb2gray(hr)
            lr = rgb2gray(lr)
            # Cut both images into aligned pSize x pSize patches
            for i in range(0, height - pSize + 1, pStride):
                for j in range(0, width - pSize + 1, pStride):
                    lrPatches.append(lr[i:i + pSize, j:j + pSize])
                    hrPatches.append(hr[i:i + pSize, j:j + pSize])
        h5_file.create_dataset('lr', data=np.array(lrPatches))
        h5_file.create_dataset('hr', data=np.array(hrPatches))
        h5_file.close()
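
The two nested `range` loops decide how many patches each image contributes. The sketch below counts them with the same bounds and defaults as the code above. The 498 × 333 image size is hypothetical, `count_patches` is a made-up helper, and the storage figure assumes float32 patches.

```python
def count_patches(width, height, pSize=33, pStride=14):
    """Number of pSize x pSize patches produced by the two range loops."""
    rows = len(range(0, height - pSize + 1, pStride))
    cols = len(range(0, width - pSize + 1, pStride))
    return rows * cols


n = count_patches(498, 333)
# One LR and one HR patch per position, 4 bytes per float32 pixel
bytes_per_pair = 2 * 33 * 33 * 4
print(n, bytes_per_pair)  # prints 748 8712
```

An image smaller than `pSize` along either axis yields no patches at all, because the corresponding `range` is empty.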


Taking the commonly used T91 data set as an example, the method above produces an h5 file of about 181 MB.

Do the same for the evaluation data.

Once the training data is ready, a reading class is needed so that torch's `DataLoader` can consume it. A `DataLoader` wraps a `Dataset`, so the new reading class has to inherit from `Dataset` and implement its two member methods, `__getitem__` and `__len__`.

These two methods only look intimidating. With a slightly deeper knowledge of Python you will know that `__getitem__` is the method behind square-bracket indexing (as with a dictionary), and `__len__` supplies the return value of the `len` function.

    import h5py
    import numpy as np
    from torch.utils.data import Dataset


    class DataSet(Dataset):
        def __init__(self, h5_file):
            super(DataSet, self).__init__()
            self.h5_file = h5_file

        def __getitem__(self, idx):
            # Scale pixel values to [0, 1] and add a leading channel axis
            with h5py.File(self.h5_file, 'r') as f:
                return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)

        def __len__(self):
            with h5py.File(self.h5_file, 'r') as f:
                return len(f['lr'])
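
The role of the two methods is easiest to see on a toy class with no torch or h5py involved. The `Squares` class below is a hypothetical example, not part of the original code.

```python
class Squares:
    """Toy container: item idx is idx squared."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        # Called for squares[idx]
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx * idx

    def __len__(self):
        # Called for len(squares)
        return self.n


squares = Squares(5)
print(len(squares), squares[3], list(squares))  # prints 5 9 [0, 1, 4, 9, 16]
```

A `DataLoader` relies on the same two hooks: it asks the data set for its length and then fetches samples by index.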


### Training

Training first needs a little preparation: the data set must be ready, the relevant folders must be created, and, once the model is built, an optimization method must be chosen. Decide whether the training device is the CPU or CUDA, then load the data set and the model onto that device.

Data preparation

    import os
    import copy
    import torch
    from torch import nn
    import torch.optim as optim
    import torch.backends.cudnn as cudnn
    from torch.utils.data import DataLoader
    from models import SRCNN
    # Assumption: the reading class above is saved as datasets.py
    from datasets import DataSet

    trainFile = "91-image.h5"
    evalFile = "Set5.h5"
    outPath = "outputs"
    scale = 3
    bSize = 16
    nEpoch = 400
    nWorker = 8     # Number of DataLoader worker processes
    seed = 42       # Random number seed

    os.makedirs(outPath, exist_ok=True)   # Create the output folder

    cudnn.benchmark = True
    # Choose whether the training device is the CPU or CUDA
    device = torch.device(
        'cuda:0' if torch.cuda.is_available() else 'cpu')

    # The DataLoader calls are reconstructed; the source shows only their arguments
    trainData = DataSet(trainFile)
    trainLoader = DataLoader(dataset=trainData,
                             batch_size=bSize,
                             shuffle=True,      # Shuffle the samples every epoch
                             num_workers=nWorker,
                             pin_memory=True,   # Speeds up copies to CUDA
                             drop_last=True)

    evalDatas = DataSet(evalFile)
    # Assumption: evaluation runs one sample at a time
    evalLoader = DataLoader(dataset=evalDatas, batch_size=1)


Model preparation

    # Model and device
    lr = 1e-4       # Learning rate
    torch.manual_seed(seed)     # Set the random number seed
    model = SRCNN().to(device)  # Load the model onto the device
    criterion = nn.MSELoss()    # Set the loss function
    # Assumption: the optimizer constructor is missing from the source. Adam is
    # used here; optim.SGD with momentum=0.9 would match the update rule above.
    optimizer = optim.Adam([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        # The last layer uses a learning rate ten times smaller
        {'params': model.conv3.parameters(), 'lr': lr * 0.1}
    ], lr=lr)


Training

    def initPSNR():
        return {'avg': 0, 'sum': 0, 'count': 0}

    def updatePSNR(psnr, val, n=1):
        # Running average weighted by the number of samples n
        s = psnr['sum'] + val * n
        c = psnr['count'] + n
        return {'avg': s / c, 'sum': s, 'count': c}

    bestWeights = copy.deepcopy(model.state_dict())  # Best model weights
    bestEpoch = 0   # Epoch with the best result
    bestPSNR = 0.0  # Best PSNR

    # Main training loop
    for epoch in range(nEpoch):
        model.train()
        epochLosses = initPSNR()

        for data in trainLoader:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            preds = model(inputs)
            loss = criterion(preds, labels)
            epochLosses = updatePSNR(epochLosses, loss.item(), len(inputs))
            optimizer.zero_grad()   # Clear the gradients of the previous step
            loss.backward()         # Back propagation
            optimizer.step()        # Update network parameters from the gradients
        print(f"{epochLosses['avg']:.6f}")

        torch.save(model.state_dict(),
                   os.path.join(outPath, f'epoch_{epoch}.pth'))

        model.eval()    # Switch to evaluation mode
        psnr = initPSNR()

        for data in evalLoader:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Inside no_grad, results do not require gradients, which turns off autograd
            with torch.no_grad():
                # clamp limits the predictions to the range 0 to 1
                preds = model(inputs).clamp(0.0, 1.0)

            tmp_psnr = 10. * torch.log10(
                1. / torch.mean((preds - labels) ** 2))
            psnr = updatePSNR(psnr, tmp_psnr.item(), len(inputs))

        print(f"eval psnr: {psnr['avg']:.2f}")

        if psnr['avg'] > bestPSNR:
            bestEpoch = epoch
            bestPSNR = psnr['avg']
            bestWeights = copy.deepcopy(model.state_dict())

    print(f'best epoch: {bestEpoch}, psnr: {bestPSNR:.2f}')
    torch.save(bestWeights, os.path.join(outPath, 'best.pth'))
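
The evaluation step turns a mean squared error into PSNR and averages it over batches. The pure-Python sketch below reproduces that arithmetic for pixel values in the range 0 to 1. The MSE values, the batch sizes, and the helper names are hypothetical.

```python
import math


def psnr_from_mse(mse, peak=1.0):
    """PSNR in dB for pixel values in [0, peak]."""
    return 10.0 * math.log10(peak ** 2 / mse)


def weighted_mean(values, counts):
    """Average of per-batch values weighted by batch size."""
    return sum(v * c for v, c in zip(values, counts)) / sum(counts)


batch_psnr = [psnr_from_mse(0.01), psnr_from_mse(0.001)]
avg = weighted_mean(batch_psnr, counts=[1, 3])
print(batch_psnr, avg)
```

Each tenfold drop in MSE adds 10 dB, so small differences in PSNR between epochs correspond to small relative changes in error.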


The final result is [figure not preserved].

Posted by thefollower on Thu, 05 May 2022 05:02:16 +0300