PyTorch collate_fn for variable-length batches - dynamic padding

Note: "batch" here refers to a mini-batch

Two ways to batch sequence data (text, logs):

  1. Fixed-length batches (uniform-length batches)
    Every sequence in every batch is padded to the same length. For example, seqs = [[1,2,3,3,4,5,6,7], [1,2,3], [2,4,1,2,3], [1,2,4,1]]
    batch_size = 2
    The maximum sequence length is 8, so any sequence shorter than 8 is padded with 0:
batch1 = [[1, 2, 3, 3, 4, 5, 6, 7], [1, 2, 3, 0, 0, 0, 0, 0]]
batch2 = [[2, 4, 1, 2, 3, 0, 0, 0], [1, 2, 4, 1, 0, 0, 0, 0]]
  2. Variable-length batches
    Sequences within one batch share the same length, but the length may differ from batch to batch. For the example above, the sequences are first sorted by length and then each batch is padded to the length of its longest sequence (see the sketch after this list):
batch1 = [[1, 2, 3, 0], [1, 2, 4, 1]]                          # len = 4
batch2 = [[2, 4, 1, 2, 3, 0, 0, 0], [1, 2, 3, 3, 4, 5, 6, 7]]  # len = 8
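
A plain-Python sketch (not from the original post; batch size 2 and pad value 0 assumed) contrasting the two strategies on the example above:

# Minimal illustration of fixed-length vs. variable-length (per-batch) padding.
seqs = [[1, 2, 3, 3, 4, 5, 6, 7], [1, 2, 3], [2, 4, 1, 2, 3], [1, 2, 4, 1]]
batch_size = 2

def pad(seq, length, pad_value=0):
    return seq + [pad_value] * (length - len(seq))

# 1. Fixed length: every sequence is padded to the global maximum (8 here)
global_max = max(len(s) for s in seqs)
fixed_batches = [[pad(s, global_max) for s in seqs[i:i + batch_size]]
                 for i in range(0, len(seqs), batch_size)]

# 2. Variable length: sort by length first, then pad each batch to its own maximum
sorted_seqs = sorted(seqs, key=len)
variable_batches = []
for i in range(0, len(sorted_seqs), batch_size):
    chunk = sorted_seqs[i:i + batch_size]
    variable_batches.append([pad(s, max(len(s) for s in chunk)) for s in chunk])

print(fixed_batches)     # all sequences padded to length 8
print(variable_batches)  # first batch padded to length 4, second to length 8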

Why use variable-length batches?

If the training data contains very short sequences, padding everything to one uniform length makes the batches mostly padding (very sparse), which wastes training time and can hurt the model's predictions.

How to implement variable-length batches in PyTorch?

Answer: dynamic padding
Following the idea above, dynamic padding takes two main steps:

  1. First sort the data by sequence length.
  2. Within each batch, use the maximum sequence length, or a percentile of the lengths such as the 75th percentile, as the fixed length of that batch (the percentile guards against the extreme case where the maximum is very large); a small illustration follows this list.
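
For instance, a small illustration (the lengths below are made up) of why a percentile cutoff can be preferable to the raw maximum:

import numpy as np

lens = [12, 15, 14, 13, 200]          # one outlier of length 200
print(max(lens))                      # 200 -> every sequence padded to length 200
print(int(np.percentile(lens, 75)))   # 15  -> pad/truncate to the 75th percentile instead
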
The collate_fn parameter

collate_fn is a parameter of DataLoader used to build each batch from individual samples; see the official documentation:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

Code implementation

from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np

class MyDataset(Dataset):
    def __init__(self, seq, label):
        self.seq = seq
        self.label = label

    def __len__(self):
        return len(self.label)

    def __getitem__(self, index):
        return self.seq[index], self.label[index]


def collate_fn(batch):
    """
    args:
        batch: list of (sequence, label) pairs as returned by MyDataset.__getitem__

    return:
        padded sequences and padded labels, each a LongTensor of shape (batch_size, seq_len)
    """
    percentile = 100        # pad to this percentile of the batch's lengths
    dynamical_pad = True    # True: pad per batch; False: always pad to max_len
    max_len = 50            # hard upper bound on the padded length
    pad_index = 0           # value used for padding

    lens = [len(dat[0]) for dat in batch]

    # decide the padded length for this batch
    if dynamical_pad:
        # dynamical padding
        seq_len = min(int(np.percentile(lens, percentile)), max_len)
        # or seq_len = max(lens)
    else:
        # fixed length padding
        seq_len = max_len
    print("collate_fn seq_len", seq_len)

    output = []
    out_label = []
    for dat in batch:
        seq = dat[0][:seq_len]
        label = dat[1][:seq_len]

        padding = [pad_index for _ in range(seq_len - len(seq))]
        seq.extend(padding)
        label.extend(padding)

        output.append(seq)
        out_label.append(label)

    output = torch.tensor(output, dtype=torch.long)
    out_label = torch.tensor(out_label, dtype=torch.long)

    return output, out_label


batch_size = 2
# ragged (variable-length) sequences: use dtype=object so numpy keeps them as lists
seqs = np.array([[1,0,3,3,4,5,6,0], [1,0,3], [2,4,0,2,3], [1,2,0,1]], dtype=object)
label = np.array([[0,2,0,0,0,0,0,7], [0,2,0], [0,0,1,0,0], [0,0,4,0]], dtype=object)

# sort by descending sequence length so similar lengths land in the same batch
lens = np.array(list(map(len, seqs)))
len_index = np.argsort(-1 * lens)
seqs = seqs[len_index]
label = label[len_index]

mydataset = MyDataset(seqs, label)
dl = DataLoader(mydataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
for batch in dl:
    print(batch)
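
With the data sorted by descending length and percentile = 100, the first batch is padded to length 8 (its longest sequence) and the second to length 4.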


Code reference: [1]

If you use an RNN-based model, it is worth learning about the two functions pack_padded_sequence and pad_packed_sequence, which keep the padding values from being fed into the model during training.
pack_padded_sequence packs the padded batch before it is passed to the model, and pad_packed_sequence restores the model's output to the original padded dimensions.

Simple implementation

# pack_padded_sequence so that padded items in the sequence won't be shown to the LSTM
X = torch.nn.utils.rnn.pack_padded_sequence(x, X_lengths, batch_first=True)

# now run through LSTM
X, self.hidden = self.lstm(X, self.hidden)

# undo the packing operation
X, _ = torch.nn.utils.rnn.pad_packed_sequence(X, batch_first=True)
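
For a self-contained version, a minimal runnable sketch (the tensor sizes and the enforce_sorted flag are my own choices, not from the original post) could look like this:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

batch = torch.randn(2, 8, 5)    # (batch, max_len, features), already zero-padded
lengths = torch.tensor([8, 4])  # true lengths, sorted in descending order

lstm = nn.LSTM(input_size=5, hidden_size=3, batch_first=True)

# pack so the LSTM never sees the padded time steps
packed = pack_padded_sequence(batch, lengths, batch_first=True, enforce_sorted=True)
packed_out, (h, c) = lstm(packed)

# unpack back to a padded tensor of shape (batch, max_len, hidden)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(out.shape)    # torch.Size([2, 8, 3])
print(out_lengths)  # tensor([8, 4])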

See [2] for a Chinese reference.
See [3] for an English reference; it also covers how to compute the loss in the presence of padding and is worth a look. Some PyTorch loss functions, such as NLLLoss, let you specify a padding value to ignore (ignore_index).
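
A short sketch of that option (the padding index 0 and the tensor shapes are assumptions for illustration):

import torch
import torch.nn as nn

pad_index = 0
criterion = nn.NLLLoss(ignore_index=pad_index)  # padded positions contribute nothing to the loss

log_probs = torch.log_softmax(torch.randn(2, 4, 10), dim=-1)  # (batch, seq_len, num_classes)
targets = torch.tensor([[3, 5, 0, 0],                          # trailing zeros are padding
                        [7, 2, 9, 0]])

# NLLLoss expects the class dimension in position 1, so transpose to (batch, num_classes, seq_len)
loss = criterion(log_probs.transpose(1, 2), targets)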

Reference material

[1] https://www.kaggle.com/evilpsycho42/pytorch-batch-dynamic-padding-sort-pack
[2] https://blog.csdn.net/u011550545/article/details/89529977?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param
[3] https://towardsdatascience.com/taming-lstms-variable-sized-mini-batches-and-why-pytorch-is-good-for-your-health-61d35642972e
