Note: "batch" here means a mini-batch.
There are two ways to batch sequences (text, logs):
- Fixed length batches (uniform length batches)
Every sequence in every batch has the same length. For example, seqs = [[1,2,3,3,4,5,6,7], [1,2,3], [2,4,1,2,3], [1,2,4,1]]
batch_size = 2
The maximum sequence length is 8, so every sequence shorter than 8 is padded with 0:
batch1 = [[1, 2, 3, 3, 4, 5, 6, 7], [1, 2, 3, 0, 0, 0, 0, 0]]
batch2 = [[2, 4, 1, 2, 3, 0, 0, 0], [1, 2, 4, 1, 0, 0, 0, 0]]
- Variable length batches
Within a batch, all sequences have the same length, but the length may differ from one batch to the next. For the example above, the sequences are first sorted by length, and each batch is then padded to the length of its own longest sequence:
batch1 = [[1, 2, 3, 0], [1, 2, 4, 1]]  # len = 4
batch2 = [[2, 4, 1, 2, 3, 0, 0, 0], [1, 2, 3, 3, 4, 5, 6, 7]]  # len = 8
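Both schemes can be sketched in a few lines of plain Python. The helper names `pad_to`, `fixed_length_batches`, and `variable_length_batches` are made up for this illustration, and the sketch assumes 0 is the padding value, as in the example above.

```python
def pad_to(seq, length, pad_value=0):
    """Right-pad seq with pad_value up to the given length."""
    return seq + [pad_value] * (length - len(seq))


def fixed_length_batches(seqs, batch_size, pad_value=0):
    """Pad every sequence to the global maximum length, then split into batches."""
    max_len = max(len(s) for s in seqs)
    padded = [pad_to(s, max_len, pad_value) for s in seqs]
    return [padded[i:i + batch_size] for i in range(0, len(padded), batch_size)]


def variable_length_batches(seqs, batch_size, pad_value=0):
    """Sort by length, split into batches, pad each batch to its own maximum."""
    ordered = sorted(seqs, key=len)
    batches = []
    for i in range(0, len(ordered), batch_size):
        chunk = ordered[i:i + batch_size]
        batch_len = max(len(s) for s in chunk)
        batches.append([pad_to(s, batch_len, pad_value) for s in chunk])
    return batches


seqs = [[1, 2, 3, 3, 4, 5, 6, 7], [1, 2, 3], [2, 4, 1, 2, 3], [1, 2, 4, 1]]
fixed = fixed_length_batches(seqs, batch_size=2)
variable = variable_length_batches(seqs, batch_size=2)
print(fixed)     # every row has length 8
print(variable)  # rows of length 4 in the first batch, 8 in the second
```

Sorting before batching is what keeps sequences of similar length together, so the per-batch maximum stays close to the real lengths.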
Why use variable length batches?
If the training data contains very short sequences, padding everything to one uniform length leaves the batches filled mostly with padding. This can lengthen training and hurt the model's predictions.
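To put numbers on this for the four example sequences (lengths 8, 3, 5, 4), the short calculation below compares the share of padded cells under both schemes. The function name `padding_fraction` is illustrative only.

```python
def padding_fraction(batches_of_lengths):
    """Fraction of cells that are padding when each batch is padded to its own maximum."""
    total_cells = 0
    real_tokens = 0
    for lengths in batches_of_lengths:
        total_cells += max(lengths) * len(lengths)
        real_tokens += sum(lengths)
    return (total_cells - real_tokens) / total_cells


lengths = [8, 3, 5, 4]

# Uniform length: everything is padded to the global maximum of 8.
uniform = padding_fraction([lengths])

# Variable length: sort, then split into batches of 2 -> [3, 4] and [5, 8].
ordered = sorted(lengths)
variable = padding_fraction([ordered[:2], ordered[2:]])

print(uniform)   # 12 padded cells out of 32 -> 0.375
print(variable)  # 4 padded cells out of 24 -> about 0.167
```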
How to implement variable length batches in PyTorch?
Answer: dynamic padding
Following the idea above, dynamic padding takes two main steps:
- First, sort the sequences by length.
- In each batch, use either the maximum sequence length or the 75th percentile of the lengths as the fixed length for that batch. The percentile guards against the extreme case where the maximum is very large.
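The length choice in the second step can be sketched with `np.percentile`. The helper name `batch_length` and the cap `max_len=50` are assumptions for this illustration. Note that sequences longer than the chosen length must then be truncated.

```python
import numpy as np


def batch_length(lens, percentile=75, max_len=50):
    """Pick the padded length for one batch: a percentile of the lengths, capped at max_len."""
    return min(int(np.percentile(lens, percentile)), max_len)


lens = [3, 4, 5, 40]  # one outlier that is much longer than the rest

print(batch_length(lens, percentile=100))  # 40: same as max(lens)
print(batch_length(lens, percentile=75))   # 13: the outlier no longer dictates the length
print(batch_length([100, 200], percentile=100))  # 50: capped by max_len
```

With the 75th percentile, the sequence of length 40 would be cut to 13 tokens, which trades some lost content for much less padding on the other three sequences.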
collate_fn is a parameter of DataLoader that merges a list of samples into one batch; the official PyTorch documentation has an introduction. The DataLoader signature is:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
    from torch.utils.data import Dataset, DataLoader
    import torch
    import numpy as np


    class MyDataset(Dataset):
        def __init__(self, seq, label):
            self.seq = seq
            self.label = label

        def __len__(self):
            return len(self.label)

        def __getitem__(self, index):
            return self.seq[index], self.label[index]


    def collate_fn(batch):
        """
        args:
            batch: list of (seq, label) pairs, one pair per sample
        return:
            output, out_label: LongTensors of shape (batch_size, seq_len)
        """
        percentile = 100
        dynamical_pad = True
        max_len = 50
        pad_index = 0

        # length of every sequence in this batch
        lens = [len(seq) for seq, _ in batch]

        if dynamical_pad:
            # dynamic padding: percentile of the batch lengths, capped at max_len
            seq_len = min(int(np.percentile(lens, percentile)), max_len)
            # or seq_len = max(lens)
        else:
            # fixed length padding
            seq_len = max_len
        print("collate_fn seq_len", seq_len)

        output = []
        out_label = []
        for seq, label in batch:
            # truncate to seq_len, then pad the remainder with pad_index
            seq = list(seq[:seq_len])
            label = list(label[:seq_len])
            padding = [pad_index] * (seq_len - len(seq))
            output.append(seq + padding)
            out_label.append(label + padding)

        output = torch.tensor(output, dtype=torch.long)
        out_label = torch.tensor(out_label, dtype=torch.long)
        return output, out_label


    batch_size = 2
    # dtype=object is required because the sequences have different lengths
    seqs = np.array([[1,0,3,3,4,5,6,0], [1,0,3], [2,4,0,2,3], [1,2,0,1]], dtype=object)
    label = np.array([[0,2,0,0,0,0,0,7], [0,2,0], [0,0,1,0,0], [0,0,4,0]], dtype=object)

    # sort by sequence length, longest first
    lens = np.array(list(map(len, seqs)))
    len_index = np.argsort(-1 * lens)
    seqs = seqs[len_index]
    label = label[len_index]

    mydataset = MyDataset(seqs, label)
    dl = DataLoader(mydataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    for batch in dl:
        print(batch)
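As an alternative sketch, PyTorch's `torch.nn.utils.rnn.pad_sequence` can do the per-batch padding inside a collate function. The function name `collate_with_pad_sequence` is made up here, and the sketch assumes each sample is a `(seq, label)` pair of plain Python lists. It always pads to the longest sequence in the batch and does not apply the percentile cap discussed above.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_with_pad_sequence(batch, pad_index=0):
    """Pad every (seq, label) pair in the batch to the batch's longest sequence."""
    seqs = [torch.tensor(seq, dtype=torch.long) for seq, _ in batch]
    labels = [torch.tensor(label, dtype=torch.long) for _, label in batch]
    # Keep the true lengths; they are needed later for packing or masking.
    lengths = torch.tensor([len(s) for s in seqs])
    padded_seqs = pad_sequence(seqs, batch_first=True, padding_value=pad_index)
    padded_labels = pad_sequence(labels, batch_first=True, padding_value=pad_index)
    return padded_seqs, padded_labels, lengths


batch = [([1, 0, 3], [0, 2, 0]), ([2, 4, 0, 2, 3], [0, 0, 1, 0, 0])]
x, y, lengths = collate_with_pad_sequence(batch)
print(x)        # shape (2, 5); the first row is padded with two zeros
print(lengths)  # true lengths 3 and 5
```

The function can be passed to a DataLoader through `collate_fn=collate_with_pad_sequence`. Returning the lengths alongside the padded tensors is a design choice that makes the packing step for RNNs straightforward.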
If you use an RNN-based model, look into the two functions pack_padded_sequence and pad_packed_sequence. They keep the padding values from being fed to the model during training.
pack_padded_sequence packs the padded batch so that only the real tokens are passed to the model. pad_packed_sequence then restores the model output to its original padded dimensions.
    # pack_padded_sequence so that padded items in the sequence won't be shown to the LSTM
    X = torch.nn.utils.rnn.pack_padded_sequence(x, X_lengths, batch_first=True)
    # now run through LSTM
    X, self.hidden = self.lstm(X, self.hidden)
    # undo the packing operation
    X, _ = torch.nn.utils.rnn.pad_packed_sequence(X, batch_first=True)
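The fragment above comes from inside a model class. A self-contained version might look like the following sketch, where the embedding size, hidden size, and token values are arbitrary choices for illustration. By default `pack_padded_sequence` expects the batch sorted by length in descending order (`enforce_sorted=True`), which matches the sorting step used for dynamic padding.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Padded batch of token ids: true lengths are 5 and 3, padded with 0.
tokens = torch.tensor([[2, 4, 1, 2, 3],
                       [1, 2, 3, 0, 0]])
lengths = torch.tensor([5, 3])  # CPU tensor, sorted in descending order

embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=0)
lstm = torch.nn.LSTM(input_size=4, hidden_size=6, batch_first=True)

x = embedding(tokens)  # shape (2, 5, 4)

# Pack so the LSTM never processes the padded positions.
packed = pack_padded_sequence(x, lengths, batch_first=True)
packed_out, (h_n, c_n) = lstm(packed)

# Undo the packing: back to a padded tensor plus the true lengths.
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

print(out.shape)    # torch.Size([2, 5, 6])
print(out_lengths)  # tensor([5, 3])
print(out[1, 3:])   # zeros: the padded positions of the shorter sequence
```

For batches that are not sorted by length, passing `enforce_sorted=False` to `pack_padded_sequence` lets PyTorch handle the reordering internally.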
The English reference also explains how to calculate the loss when padding is present and is worth a look. Some PyTorch loss functions, such as NLLLoss, can be told to ignore the padding value through their ignore_index argument.
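A minimal sketch of a padding-aware loss with `NLLLoss` follows. It assumes that class index 0 is reserved for padding in the targets. The random scores stand in for real model output.

```python
import torch

torch.manual_seed(0)
pad_index = 0  # assumption: target value 0 marks a padded position

# Log-probabilities for 4 positions over 5 classes (stand-in for model output).
log_probs = torch.log_softmax(torch.randn(4, 5), dim=1)
targets = torch.tensor([2, 0, 4, 0])  # positions 1 and 3 are padding

# ignore_index drops the padded positions from the loss and from the mean.
loss_fn = torch.nn.NLLLoss(ignore_index=pad_index)
loss = loss_fn(log_probs, targets)

# The same value computed by hand over the non-padded positions only.
rows = (targets != pad_index).nonzero(as_tuple=True)[0]
manual = -log_probs[rows, targets[rows]].mean()

print(loss, manual)
```

`torch.nn.CrossEntropyLoss` accepts the same `ignore_index` argument when the model outputs raw scores instead of log-probabilities.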