Preface
A while ago, our company published an article on how to implement an efficient Softmax CUDA kernel, but I had not fully understood some of the details. Recently I happened to need to build a similar Reduce+Scale kernel, and since the principle and mechanism are quite similar, I dug the article out and worked through it again.
Background
We define a ReduceScale operation:
Given a tensor of shape (N, C), first compute the absolute maximum (absMax) along the C dimension and record it as scale, then divide each row by its own scale, and output the result.
A simple piece of numpy code is as follows:
```python
import numpy as np

N = 1000
C = 128
x = np.random.randn(N, C)
scale = np.expand_dims(np.max(np.abs(x), axis=1), 1)
out = x / scale
print(out.shape)
```
BaseLine
The BaseLine directly calls BlockReduce from the cub library: one thread block processes one row of data, computes its AbsMaxVal, and then scales the row. The code is as follows:
```cpp
#include "cuda.h"
#include "cub/cub.cuh"

constexpr int kReduceBlockSize = 128;

template<typename T>
__device__ T abs_func(const T& a) {
  return abs(a);
}

template<typename T>
__device__ T max_func(const T a, const T b) {
  return a > b ? a : b;
}

template<typename T>
struct AbsMaxOp {
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
    return max_func(abs_func(a), abs_func(b));
  }
};

template<typename T>
__inline__ __device__ T BlockAllReduceAbsMax(T val) {
  typedef cub::BlockReduce<T, kReduceBlockSize> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T final_result;
  T result = BlockReduce(temp_storage).Reduce(val, AbsMaxOp<T>());
  if (threadIdx.x == 0) { final_result = result; }
  __syncthreads();
  return final_result;
}

template<typename T, typename IDX>
__global__ void ReduceScaleBlockKernel(T* x, IDX row_size, IDX col_size) {
  for (int32_t row = blockIdx.x, step = gridDim.x; row < row_size; row += step) {
    T thread_scale_factor = 0.0;
    for (int32_t col = threadIdx.x; col < col_size; col += blockDim.x) {
      IDX idx = row * col_size + col;
      T x_val = x[idx];
      thread_scale_factor = max_func(thread_scale_factor, abs_func(x_val));
    }
    T row_scale_factor = BlockAllReduceAbsMax<T>(thread_scale_factor);
    for (int32_t col = threadIdx.x; col < col_size; col += blockDim.x) {
      IDX idx = row * col_size + col;
      x[idx] /= row_scale_factor;
    }
  }
}
```
Among the parameters, x is the input data, row_size is the number of rows, and col_size is the number of columns.
The test machine is an A100 40GB. To make the difference between the results more obvious, we set the number of rows fairly large: the input shape is (55296 * 8, 128) and the data type is Float. How many thread blocks to launch is related to how grid_size and block_size are set in a CUDA kernel (see the earlier article on that topic); here we simply use a rough setting of grid_size = 55296 and block_size = 128, as in the hedged launch sketch below.
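As a minimal host-side sketch (the variable names device_ptr, row_size, and col_size are assumptions for illustration, not taken from the original benchmark code), the BaseLine launch under this configuration might look like:

```cpp
// Hedged launch sketch for the BaseLine kernel: 55296 blocks of kReduceBlockSize = 128 threads;
// each block walks over rows with a stride of gridDim.x.
constexpr int32_t row_size = 55296 * 8;
constexpr int32_t col_size = 128;
float* device_ptr = /* device buffer of row_size * col_size floats, assumed allocated elsewhere */ nullptr;
ReduceScaleBlockKernel<float, int32_t><<<55296, kReduceBlockSize>>>(device_ptr, row_size, col_size);
```

Then let's look at the results of ncu: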

The main metrics are: the kernel takes 577.95 us and achieves a memory throughput of 748.78 GB/s.
Let's analyze it step by step, following the points made in the Softmax optimization article:
Optimization 1: Data Pack
The earlier article "Efficient, easy to use, and extensible, I want them all: the design and optimization ideas of the OneFlow CUDA Elementwise template library" describes vectorized reads and writes in detail. CUDA supports at most 128-bit loads and stores, so when the data type is Float we can pack four consecutive floats together and read or write them in one instruction, improving throughput.
Readers familiar with this may point out that CUDA already has a type called float4 that does exactly this. True, but to support vectorization of other data types more flexibly, we use the shared-storage property of a union to implement a Pack class:
```cpp
template<typename T, int N>
struct GetPackType {
  using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
};

template<typename T, int N>
using PackType = typename GetPackType<T, N>::type;

template<typename T, int N>
union Pack {
  static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, "");
  __device__ Pack() {
    // do nothing
  }
  PackType<T, N> storage;
  T elem[N];
};
```
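As a quick illustrative sketch (the demo kernel PackCopyDemo is not part of the original post, just an assumption-labeled example), this is how Pack<float, 4> turns four scalar accesses into a single 128-bit load and store:

```cpp
// Illustrative only: each thread moves N elements with one vectorized load/store via Pack<T, N>.
// Assumes x and y are 16-byte aligned (true for buffers from cudaMalloc).
template<typename T, int N>
__global__ void PackCopyDemo(const T* x, T* y, int num_packs_total) {
  using LoadType = PackType<T, N>;  // aligned storage type, sizeof == N * sizeof(T)
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num_packs_total) {
    Pack<T, N> pack;
    pack.storage = reinterpret_cast<const LoadType*>(x)[i];  // one 128-bit load for N = 4 floats
    reinterpret_cast<LoadType*>(y)[i] = pack.storage;        // one 128-bit store
  }
}
// e.g. PackCopyDemo<float, 4><<<grid, block>>>(x, y, n / 4);
```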
Optimization 2: Data Caching
The operator logic needs to read the data once to compute the scale, and then read it again to apply the scale. Clearly we read the data twice, and the data lives in Global Memory, whose bandwidth is relatively low, so the extra read costs time.


A natural idea is to cache the data in registers or Shared Memory. Since we only implement the WarpReduce version here, we cache it in registers (for the other versions, refer to the Softmax optimization article mentioned at the beginning), which saves one read from Global Memory.
```cpp
template<typename T, typename IDX, int pack_size, int cols_per_thread>
__global__ void ReduceScaleWarpKernel(T* x, IDX row_size, IDX col_size) {
  // ...
  T buf[cols_per_thread];
  // ...
}
```
Optimization 3: Using a Warp to Process a Row of Data
Compared with the BaseLine, we use a warp as the unit of reduction. First, let's take a brief look at the WarpReduce implementation.
```cpp
template<typename T>
struct AbsMaxOp {
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
    return max_func(abs_func(a), abs_func(b));
  }
};

template<typename T>
__inline__ __device__ T WarpAbsMaxAllReduce(T val) {
  for (int lane_mask = kWarpSize / 2; lane_mask > 0; lane_mask /= 2) {
    val = AbsMaxOp<T>()(val, __shfl_xor_sync(0xffffffff, val, lane_mask));
  }
  return val;
}
```
This pattern often appears in other BlockReduce implementations. It relies on __shfl_xor_sync to perform the comparison; this shuffle instruction allows two threads in the same warp to read each other's registers directly.
```cpp
T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize);
```
- mask is a thread mask; usually all threads should participate in the computation, so mask is 0xffffffff.
- var is the register value, and laneMask is the mask used for the bitwise XOR.
- This introduces the concept of a Lane, which is the thread index within a warp.
The schematic diagram is as follows:

When laneMask = 16, its binary form is 0001 0000, and each thread in the warp XORs its own lane id with laneMask to find its partner.
For example:
- 0000 0000 xor 0001 0000 = 0001 0000 = 16
- 0000 0001 xor 0001 0000 = 0001 0001 = 17
- 0000 0010 xor 0001 0000 = 0001 0010 = 18
Continuing by analogy over lane_mask = 16, 8, 4, 2, 1, every lane ends up holding the absMax value of the whole warp.
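As a small sanity-check sketch (not from the original post; it assumes kWarpSize = 32 and reuses the AbsMaxOp, abs_func, max_func, and WarpAbsMaxAllReduce definitions above):

```cpp
#include <cstdio>

// Each lane starts with a distinct value; lane 7 holds the largest absolute value (-42).
// After WarpAbsMaxAllReduce, every lane of the warp should hold 42.
__global__ void WarpAbsMaxDemo() {
  float v = (threadIdx.x == 7) ? -42.0f : static_cast<float>(threadIdx.x);
  float m = WarpAbsMaxAllReduce<float>(v);
  if (threadIdx.x == 0) { printf("warp abs-max = %f\n", m); }  // expected: 42.000000
}
// launch with a single warp: WarpAbsMaxDemo<<<1, 32>>>();
```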
Next, we start writing the kernel. The template parameters are:
- T: the data type
- IDX: the index type
- pack_size: the number of elements packed together; for example, 4 floats can be packed into one pack, corresponding to pack_size = 4
- cols_per_thread: the number of elements each thread needs to process; for example, if a row has 128 elements and a warp has 32 threads, this is 128 / 32 = 4
```cpp
template<typename T, typename IDX, int pack_size, int cols_per_thread>
__global__ void ReduceScaleWarpKernel(T* x, IDX row_size, IDX col_size) {
  // ...
}
```
Like the BaseLine, our block size is still 128 threads; since a warp has 32 threads, we can organize a block as (32, 4), i.e. 4 warps per block, as in the sketch below.
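A minimal sketch of this block shape (an assumption about the host-side setup; the name block_dim matches the launch call later in the post):

```cpp
constexpr int kWarpSize = 32;  // assumed definition; the kernels above rely on this constant
// 32 lanes along x (threadIdx.x = lane id) and 4 warps along y (threadIdx.y = warp id within the block),
// giving 128 threads per block, the same block size as the BaseLine.
dim3 block_dim(kWarpSize, 4);
```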

According to this hierarchy, we can calculate:
- global_thread_group_id: the global index of the current warp
- num_total_thread_group: the total number of warps launched
- lane_id: the thread id within the warp
- num_packs: the number of packs, i.e. the number of elements each thread processes divided by pack_size
```cpp
const int32_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int32_t num_total_thread_group = gridDim.x * blockDim.y;
const int32_t lane_id = threadIdx.x;
using LoadStoreType = PackType<T, pack_size>;
using LoadStorePack = Pack<T, pack_size>;
T buf[cols_per_thread];
constexpr int num_packs = cols_per_thread / pack_size;
```
Since the number of launched warps is smaller than the number of rows, we introduce a for loop over rows.
Suppose cols = 256: each thread in the warp needs to process 256 / 32 = 8 elements, and 4 floats can be packed together, so each thread processes two packs. Therefore we also need a for loop over num_packs to make sure the whole row is read, as worked out in the sketch below:
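Working through the index arithmetic for this cols = 256 example (an illustrative calculation using the same formulas as the kernel below):

```cpp
// pack_size = 4, kWarpSize = 32, cols_per_thread = 256 / 32 = 8, num_packs = 8 / 4 = 2.
// col_offset = pack_idx * kWarpSize * pack_size + lane_id * pack_size
//   pack_idx = 0: col_offset = lane_id * 4        -> the 32 lanes cover columns [0, 128)
//   pack_idx = 1: col_offset = 128 + lane_id * 4  -> the 32 lanes cover columns [128, 256)
// So two packed 128-bit loads per lane read the entire 256-element row.
```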

After reading one pack at a time, we put it into the register buffer and compute the per-thread AbsMaxVal.
```cpp
for (IDX row_idx = global_thread_group_id; row_idx < row_size; row_idx += num_total_thread_group) {
  T thread_abs_max_val = 0.0;
  for (int pack_idx = 0; pack_idx < num_packs; pack_idx++) {
    const int32_t pack_offset = pack_idx * pack_size;
    const int32_t col_offset = pack_idx * kWarpSize * pack_size + lane_id * pack_size;
    const int32_t load_offset = (row_idx * col_size + col_offset) / pack_size;
    LoadStorePack load_pack;
    load_pack.storage = *(reinterpret_cast<LoadStoreType*>(x) + load_offset);
#pragma unroll
    for (int i = 0; i < pack_size; i++) {
      buf[pack_offset + i] = load_pack.elem[i];
      thread_abs_max_val = max_func(thread_abs_max_val, abs_func(buf[pack_offset + i]));
    }
  }
```
Then we call WarpAbsMaxAllReduce to obtain the AbsMaxVal across the warp and scale the cached values with it.
```cpp
  T warp_max_val = WarpAbsMaxAllReduce<T>(thread_abs_max_val);
#pragma unroll
  for (int col = 0; col < cols_per_thread; col++) {
    buf[col] = buf[col] / warp_max_val;
  }
```
Finally, mirroring the read at the beginning, we write the register values back; the index calculation logic is the same:
```cpp
  for (int pack_idx = 0; pack_idx < num_packs; pack_idx++) {
    const int32_t pack_offset = pack_idx * pack_size;
    const int32_t col_offset = pack_idx * pack_size * kWarpSize + lane_id * pack_size;
    const int32_t store_offset = (row_idx * col_size + col_offset) / pack_size;
    LoadStorePack store_pack;
#pragma unroll
    for (int i = 0; i < pack_size; i++) {
      store_pack.elem[i] = buf[pack_offset + i];
    }
    *(reinterpret_cast<LoadStoreType*>(x) + store_offset) = store_pack.storage;
  }
```
The complete code is as follows:
```cpp
template<typename T>
__inline__ __device__ T WarpAbsMaxAllReduce(T val) {
  for (int lane_mask = kWarpSize / 2; lane_mask > 0; lane_mask /= 2) {
    val = AbsMaxOp<T>()(val, __shfl_xor_sync(0xffffffff, val, lane_mask));
  }
  return val;
}

template<typename T, typename IDX, int pack_size, int cols_per_thread>
__global__ void ReduceScaleWarpKernel(T* x, IDX row_size, IDX col_size) {
  const int32_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
  const int32_t num_total_thread_group = gridDim.x * blockDim.y;
  const int32_t lane_id = threadIdx.x;
  using LoadStoreType = PackType<T, pack_size>;
  using LoadStorePack = Pack<T, pack_size>;
  T buf[cols_per_thread];
  constexpr int num_packs = cols_per_thread / pack_size;
  for (IDX row_idx = global_thread_group_id; row_idx < row_size; row_idx += num_total_thread_group) {
    T thread_abs_max_val = 0.0;
    for (int pack_idx = 0; pack_idx < num_packs; pack_idx++) {
      const int32_t pack_offset = pack_idx * pack_size;
      const int32_t col_offset = pack_idx * kWarpSize * pack_size + lane_id * pack_size;
      const int32_t load_offset = (row_idx * col_size + col_offset) / pack_size;
      LoadStorePack load_pack;
      load_pack.storage = *(reinterpret_cast<LoadStoreType*>(x) + load_offset);
#pragma unroll
      for (int i = 0; i < pack_size; i++) {
        buf[pack_offset + i] = load_pack.elem[i];
        thread_abs_max_val = max_func(thread_abs_max_val, abs_func(buf[pack_offset + i]));
      }
    }
    T warp_max_val = WarpAbsMaxAllReduce<T>(thread_abs_max_val);
#pragma unroll
    for (int col = 0; col < cols_per_thread; col++) {
      buf[col] = buf[col] / warp_max_val;
    }
    for (int pack_idx = 0; pack_idx < num_packs; pack_idx++) {
      const int32_t pack_offset = pack_idx * pack_size;
      const int32_t col_offset = pack_idx * pack_size * kWarpSize + lane_id * pack_size;
      const int32_t store_offset = (row_idx * col_size + col_offset) / pack_size;
      LoadStorePack store_pack;
#pragma unroll
      for (int i = 0; i < pack_size; i++) {
        store_pack.elem[i] = buf[pack_offset + i];
      }
      *(reinterpret_cast<LoadStoreType*>(x) + store_offset) = store_pack.storage;
    }
  }
}
```
For convenience of testing, we hard-code some template parameters when calling the kernel:
```cpp
constexpr int cols_per_thread = 128 / kWarpSize;
ReduceScaleWarpKernel<float, int32_t, 4, cols_per_thread>
    <<<55296, block_dim>>>(device_ptr, row_size, col_size);
```
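For completeness, here is a hedged end-to-end driver sketch (the buffer names, the all-ones input, and the use of cudaMalloc/cudaMemcpy are assumptions for illustration; the original post only shows the launch line):

```cpp
#include <cuda_runtime.h>
#include <vector>

int main() {
  constexpr int32_t row_size = 55296 * 8;
  constexpr int32_t col_size = 128;
  const size_t n = static_cast<size_t>(row_size) * col_size;

  std::vector<float> host(n, 1.0f);  // illustrative input data
  float* device_ptr = nullptr;
  cudaMalloc(&device_ptr, n * sizeof(float));
  cudaMemcpy(device_ptr, host.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  constexpr int cols_per_thread = 128 / kWarpSize;  // 4 elements per lane
  dim3 block_dim(kWarpSize, 4);                     // 4 warps per 128-thread block
  ReduceScaleWarpKernel<float, int32_t, 4, cols_per_thread>
      <<<55296, block_dim>>>(device_ptr, row_size, col_size);

  cudaMemcpy(host.data(), device_ptr, n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(device_ptr);
  return 0;
}
```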
Finally, let's take a look at the results of ncu:

The throughput reaches about 1.3 TB/s and the time drops to 333 us, roughly 73% faster than the BaseLine.
Summary
For more special cases, you can refer to the optimized Softmax code; only the first, warp-based scheme is implemented here. It looks simple enough, but it was still a bit tricky to get right while writing. I hope this post helps readers understand how to use warp-level operations.