// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu
// TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory).
/**
 * From PyTorch:
 *
 * Copyright (c) 2016- Facebook, Inc (Adam Paszke)
 * Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
 * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
 * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
 * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
 * Copyright (c) 2011-2013 NYU (Clement Farabet)
 * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
 * Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
 * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
 *
 * From Caffe2:
 *
 * Copyright (c) 2016-present, Facebook Inc. All rights reserved.
 *
 * All contributions by Facebook:
 * Copyright (c) 2016 Facebook Inc.
 *
 * All contributions by Google:
 * Copyright (c) 2015 Google Inc.
 * All rights reserved.
 *
 * All contributions by Yangqing Jia:
 * Copyright (c) 2015 Yangqing Jia
 * All rights reserved.
 *
 * All contributions from Caffe:
 * Copyright(c) 2013, 2014, 2015, the respective contributors
 * All rights reserved.
 *
 * All other contributions:
 * Copyright(c) 2015, 2016 the respective contributors
 * All rights reserved.
 *
 * Caffe2 uses a copyright model similar to Caffe: each contributor holds
 * copyright over their contributions to Caffe2. The project versioning records
 * all such contribution and copyright details. If a contributor wants to further
 * mark their specific copyright on a particular contribution, they should
 * indicate their copyright solely in the commit message of the change when it is
 * committed.
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
 *    and IDIAP Research Institute nor the names of its contributors may be
 *    used to endorse or promote products derived from this software without
 *    specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <c10/cuda/CUDAGuard.h>

// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...)                   \
    switch(TYPE)                                                                    \
    {                                                                               \
        case at::ScalarType::Float:                                                 \
        {                                                                           \
            using scalar_t_##LEVEL = float;                                         \
            __VA_ARGS__;                                                            \
            break;                                                                  \
        }                                                                           \
        case at::ScalarType::Half:                                                  \
        {                                                                           \
            using scalar_t_##LEVEL = at::Half;                                      \
            __VA_ARGS__;                                                            \
            break;                                                                  \
        }                                                                           \
        case at::ScalarType::BFloat16:                                              \
        {                                                                           \
            using scalar_t_##LEVEL = at::BFloat16;                                  \
            __VA_ARGS__;                                                            \
            break;                                                                  \
        }                                                                           \
        default:                                                                    \
            AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");         \
    }
// #else
// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...)                \
//     switch(TYPE)                                                                 \
//     {                                                                            \
//         case at::ScalarType::Float:                                              \
//         {                                                                        \
//             using scalar_t_##LEVEL = float;                                      \
//             __VA_ARGS__;                                                         \
//             break;                                                               \
//         }                                                                        \
//         case at::ScalarType::Half:                                               \
//         {                                                                        \
//             using scalar_t_##LEVEL = at::Half;                                   \
//             __VA_ARGS__;                                                         \
//             break;                                                               \
//         }                                                                        \
//         default:                                                                 \
//             AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");      \
//     }
// #endif

#define ALIGN_BYTES 16

using Tensor = at::Tensor;
using TensorList = at::TensorList;
using ScalarType = at::ScalarType;
using at::acc_type;

template <typename T, typename AccumT, typename OutT>
struct LogSoftMaxForwardEpilogue {
  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
    : logsum(max_input + std::log(sum)) {}

  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)
    : logsum(max_log_sum_exp) {}

  __device__ __forceinline__ OutT operator()(T input) const {
    return static_cast<OutT>(input - logsum);
  }

  const AccumT logsum;
};

template <typename T, typename AccumT, typename OutT>
struct LogSoftMaxBackwardEpilogue {
  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
    : sum(sum) {}

  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
    return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);
  }

  const AccumT sum;
};

const int max_threads = 1024;

inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
  while (block_size < (max_block_size/2)) block_size *= 2;
  // Launch at least a single warp - the kernel assumes that.
  block_size = std::max(block_size, static_cast<uint64_t>(32));
  return dim3(block_size);
}

template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a + b;
  }
};

template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};
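// Added commentary (illustrative numbers, not from the original source): with fp16 or
// bf16 inputs, ILP = sizeof(float4) / 2 = 8, so for dim_size = 32768 the heuristic above
// gives max_block_size = min(32768 / 8, 1024) = 1024 and the doubling loop stops at 512
// threads; for a small dim_size = 128 it stops at 8, and the final std::max(..., 32)
// clamps the launch to one full warp, which the reductions below assume.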
////////////////////////////////////////////////////////////////////////////////
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
////////////////////////////////////////////////////////////////////////////////

template <typename T, typename AccumT>
struct MaxFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
    return ::max(max, (AccumT)v);
  }
};

template <typename T, typename AccumT>
struct AddFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    return sum + v;
  }
};

template <typename T, typename AccumT>
struct SumExpFloat
{
  __device__ __forceinline__ SumExpFloat(AccumT v)
    : max_k(v) {}

  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    return sum + std::exp(v - max_k);
  }

  const AccumT max_k;
};

template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
            const Reduction<AccumT>& r,
            AccumT defaultVal)
{
  // To avoid RaW races from chaining blockReduce calls together, we need a sync here
  __syncthreads();

  smem[threadIdx.x] = val;

  __syncthreads();

  AccumT warpVal = defaultVal;

  // First warp will perform per-warp reductions for the remaining warps
  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
  if (threadIdx.x < 32) {
    int lane = threadIdx.x % 32;
    if (lane < blockDim.x / 32) {
#pragma unroll
      for (int i = 0; i < 32; ++i) {
        warpVal = r(warpVal, smem[lane * 32 + i]);
      }
      __syncwarp(mask);
      smem[lane] = warpVal;
    }
  }

  __syncthreads();

  // First thread will perform a reduction of the above per-warp reductions
  AccumT blockVal = defaultVal;

  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / 32; ++i) {
      blockVal = r(blockVal, smem[i]);
    }
    smem[0] = blockVal;
  }

  // Sync and broadcast
  __syncthreads();
  return smem[0];
}

template <template<typename> class Reduction1, template<typename> class Reduction2, typename AccumT>
__device__ __forceinline__ void
blockReduce(AccumT* smem,
            AccumT* reducVal1, AccumT val1, const Reduction1<AccumT>& r1, AccumT defaultVal1,
            AccumT* reducVal2, AccumT val2, const Reduction2<AccumT>& r2, AccumT defaultVal2)
{
  // To avoid RaW races from chaining blockReduce calls together, we need a sync here
  __syncthreads();

  smem[threadIdx.x] = val1;
  smem[blockDim.x + threadIdx.x] = val2;

  __syncthreads();

  AccumT warpVal1 = defaultVal1;
  AccumT warpVal2 = defaultVal2;

  // First warp will perform per-warp reductions for the remaining warps
  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
  if (threadIdx.x < 32) {
    int lane = threadIdx.x % 32;
    if (lane < blockDim.x / 32) {
#pragma unroll
      for (int i = 0; i < 32; ++i) {
        warpVal1 = r1(warpVal1, smem[lane * 32 + i]);
        warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]);
      }
      __syncwarp(mask);
      smem[lane] = warpVal1;
      smem[lane + blockDim.x] = warpVal2;
    }
  }

  __syncthreads();

  // First thread will perform a reduction of the above per-warp reductions
  AccumT blockVal1 = defaultVal1;
  AccumT blockVal2 = defaultVal2;

  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / 32; ++i) {
      blockVal1 = r1(blockVal1, smem[i]);
      blockVal2 = r2(blockVal2, smem[i + blockDim.x]);
    }
    smem[0] = blockVal1;
    smem[blockDim.x] = blockVal2;
  }

  // Sync and broadcast
  __syncthreads();
  *reducVal1 = smem[0];
  *reducVal2 = smem[blockDim.x];

  __syncthreads();
}

template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
__device__ __forceinline__ AccumT
ilpReduce(int shift,
          T* data,
          int size,
          const Reduction<T, AccumT>& r,
          AccumT defaultVal)
{
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
  AccumT threadVal = defaultVal;
  int offset = threadIdx.x;

  // shift and do 1
  if(shift > 0){
    data -= shift;
    size += shift;
    if(threadIdx.x >= shift){
      threadVal = r(threadVal, data[offset]);
    }
    size -= blockDim.x;
    data += blockDim.x;
  }
  int last = size % (ILP * blockDim.x);

  T v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);

  for (; offset * ILP < (size - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(data)[offset];
    for (int j = 0; j < ILP; ++j) {
      threadVal = r(threadVal, v[j]);
    }
  }

  offset = size - last + threadIdx.x;
  // Epilogue
  for (; offset < size; offset += blockDim.x)
    threadVal = r(threadVal, data[offset]);

  return threadVal;
}
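// Added commentary: the two blockReduce overloads above stage per-thread partials in
// dynamic shared memory -- the single-output version uses blockDim.x AccumT slots, the
// two-output version 2 * blockDim.x -- so callers must size the kernel's dynamic shared
// memory accordingly. ilpReduce handles a misaligned head first (the `shift` logic walks
// the pointer back to an ALIGN_BYTES boundary and processes the first stripe element by
// element), then streams the aligned body with vectorized LoadT loads of ILP elements per
// access, and finishes with a scalar epilogue over the last size % (ILP * blockDim.x) elements.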
template <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2,
          int ILP, typename T, typename AccumT>
__device__ __forceinline__ void
ilpReduce(int shift,
          T* data,
          int size,
          AccumT* reducVal1,
          const Reduction1<T, AccumT>& r1,
          AccumT defaultVal1,
          AccumT* reducVal2,
          const Reduction2<T, AccumT>& r2,
          AccumT defaultVal2)
{
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;

  AccumT threadVal1 = defaultVal1;
  AccumT threadVal2 = defaultVal2;
  int offset = threadIdx.x;

  // shift and do 1
  if(shift > 0){
    data -= shift;
    size += shift;
    if(threadIdx.x >= shift){
      threadVal1 = r1(threadVal1, data[offset]);
      threadVal2 = r2(threadVal2, data[offset]);
    }
    size -= blockDim.x;
    data += blockDim.x;
  }
  int last = size % (ILP * blockDim.x);

  T v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);

  for (; offset * ILP < (size - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(data)[offset];

    for (int j = 0; j < ILP; ++j) {
      threadVal1 = r1(threadVal1, v[j]);
      threadVal2 = r2(threadVal2, v[j]);
    }
  }

  offset = size - last + threadIdx.x;
  // Epilogue
  for (; offset < size; offset += blockDim.x) {
    threadVal1 = r1(threadVal1, data[offset]);
    threadVal2 = r2(threadVal2, data[offset]);
  }

  *reducVal1 = threadVal1;
  *reducVal2 = threadVal2;
}

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
          template <typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyForward(
    accscalar_t *losses,
    outscalar_t *max_log_sum_exp,
    scalar_t *input,
    int64_t *labels,
    int64_t classes,
    const float smoothing,
    const int total_classes)
{
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);
  // forward pointers to batch[blockIdx.x]
  // each block handles a sample in the mini-batch
  input += blockIdx.x * classes;
  //output += blockIdx.x * classes;
  const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);
  int64_t label = labels[blockIdx.x];

  // find the max and sum
  accscalar_t threadMax, threadSum, max_k, sum_k;
  ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
      shift, input, classes,
      &threadMax, MaxFloat<scalar_t, accscalar_t>(), -at::numeric_limits<accscalar_t>::max(),
      &threadSum, AddFloat<scalar_t, accscalar_t>(), static_cast<accscalar_t>(0));
  blockReduce<Max, Add, accscalar_t>(
      sdata,
      &max_k, threadMax, Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max(),
      &sum_k, threadSum, Add<accscalar_t>(), static_cast<accscalar_t>(0));

  accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
      shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0));
  accscalar_t sumAll = blockReduce<Add, accscalar_t>(
      sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));

  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);

  // calculate per element loss with label smoothing
  // reserve max + log_sum_exp for bprop
  if (threadIdx.x == 0) {
    accscalar_t lse = max_k + std::log(sumAll);
    accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast<accscalar_t>(input[label])) : 0.f;
    losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing);
    max_log_sum_exp[blockIdx.x] = lse;
  }
}
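// Added commentary on the loss computed above: writing log p_j = x_j - lse and
// eps = smoothing, label-smoothed cross entropy over N = total_classes classes is
//   loss = (1 - eps) * (lse - x_label) + eps * (lse - (1/N) * sum_j x_j).
// Since log_prob = x_label - lse and sum_k = sum_j x_j, this is exactly
// (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing) when the row
// holds all N classes; when total_classes exceeds the local `classes` (e.g.
// tensor-parallel shards), each shard produces a partial value that the caller is
// expected to combine.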
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
apply(scalar_t *gradInput,
      scalar_t *logits,
      outscalar_t *max_log_sum_exp,
      outscalar_t *gradOutput,
      int64_t *labels,
      const float smoothing,
      int classes,
      const int total_classes)
{
  accscalar_t smooth_positives = 1.0 - smoothing;
  accscalar_t smooth_negatives = smoothing / total_classes;
  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
  int64_t label = labels[blockIdx.x];
  accscalar_t coeff = max_log_sum_exp[blockIdx.x];

  int offset = threadIdx.x;
  int last = classes % (ILP * blockDim.x);

  for (; offset < classes - last; offset += blockDim.x * ILP) {
    accscalar_t tmpLogits[ILP];

#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);
    }

#pragma unroll
    for (int j = 0; j < ILP; ++j)
      gradInput[offset + j * blockDim.x] = tmpGradOutput * (
          std::exp(tmpLogits[j] - coeff) - static_cast<accscalar_t>(
          (offset + j * blockDim.x == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
  }

  for (; offset < classes; offset += blockDim.x)
    gradInput[offset] = tmpGradOutput * (std::exp(
        static_cast<accscalar_t>(logits[offset]) - coeff) -
        static_cast<accscalar_t>((offset == label) ? 1 : 0) *
        smooth_positives - smooth_negatives);
}

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
aligned_apply(int shift,
              scalar_t *gradInput,
              scalar_t *logits,
              outscalar_t *max_log_sum_exp,
              outscalar_t *gradOutput,
              int64_t *labels,
              const float smoothing,
              int classes,
              const int total_classes)
{
  accscalar_t smooth_positives = 1.0 - smoothing;
  accscalar_t smooth_negatives = smoothing / total_classes;
  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
  int64_t label = labels[blockIdx.x];
  accscalar_t coeff = max_log_sum_exp[blockIdx.x];

  int offset = threadIdx.x;

  // shift and do 1
  if(shift > 0){
    logits -= shift;
    gradInput -= shift;
    classes += shift;
    if(threadIdx.x >= shift){
      gradInput[offset] = tmpGradOutput * (std::exp(
          static_cast<accscalar_t>(logits[offset]) - coeff) -
          static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
    }
    classes -= blockDim.x;
    gradInput += blockDim.x;
    logits += blockDim.x;
    shift -= blockDim.x;
  }

  int last = classes % (ILP * blockDim.x);

  typedef typename std::aligned_storage<ILP*sizeof(scalar_t), ILP*alignof(scalar_t)>::type LoadT;
  // input
  scalar_t v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);
  // output
  scalar_t r[ILP];
  LoadT* result = reinterpret_cast<LoadT*>(&r);

  for (; offset * ILP < (classes - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(logits)[offset];

#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      r[j] = tmpGradOutput * (std::exp(
          static_cast<accscalar_t>(v[j]) - coeff) -
          static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
    }
    reinterpret_cast<LoadT*>(gradInput)[offset] = *result;
  }

  offset = classes - last + threadIdx.x;
  for (; offset < classes; offset += blockDim.x)
    gradInput[offset] = tmpGradOutput * (std::exp(
        static_cast<accscalar_t>(logits[offset]) - coeff) -
        static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
        smooth_positives - smooth_negatives);
}
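// Added commentary on apply/aligned_apply above: both compute the same gradient,
//   dL/dx_j = gradOutput * (softmax_j - (1 - eps) * [j == label] - eps / N),
// with softmax_j = exp(x_j - max_log_sum_exp) reconstructed from the saved
// log-sum-exp, so the forward pass never materializes the probabilities.
// aligned_apply is the vectorized path taken when gradInput and logits share the same
// 16-byte alignment; apply is the non-vectorized fallback. When the host function below
// is asked to run inplace, gradInput aliases logits and the logits are overwritten.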
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
          template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyBackward(
    scalar_t *gradInput,
    scalar_t *logits,
    outscalar_t *max_log_sum_exp,
    outscalar_t *gradOutput,
    int64_t *labels,
    const float smoothing,
    int classes,
    const int total_classes)
{
  gradInput += blockIdx.x * classes;
  logits += blockIdx.x * classes;

  // Do vectorized load/store when input/output have same alignment
  const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
  const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
  if (shift == shift_){
    aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(
        shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes,
        total_classes <= 0 ? classes : total_classes);
  }
  else {
    apply<ILP, scalar_t, accscalar_t, outscalar_t>(
        gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes,
        total_classes <= 0 ? classes : total_classes);
  }
}

template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
        const Tensor & input_,
        const Tensor & labels_,
        const float smoothing,
        const int total_classes) {
  // For tensor parallel cross entropy with smoothing, we want to pass in the total number
  // of classes so that smoothing can be applied correctly. If total_classes=-1, use the
  // last dimension of the input tensor.
  AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long");

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::cuda::CUDAGuard device_guard{(char)input_.get_device()};

  auto input = input_.contiguous();
  Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float));
  Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
    std::is_same<acc_type<at::Half, true>, double>::value,
    "accscalar_t for half should be float or double");
  AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported");
  AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional");
  AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples");
  AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0");

  const int64_t dim = 1;
  int64_t outer_size = 1;
  int64_t dim_size = input.size(dim);
  int64_t inner_size = 1;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= input.size(i);
  for (int64_t i = dim + 1; i < input.dim(); ++i)
    inner_size *= input.size(i);
  // This kernel spawns a block per each element in the batch.
  // XXX: it assumes that inner_size == 1
  TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  dim3 grid(outer_size);

  using namespace at;
  DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy",
    using accscalar_t = at::acc_type<scalar_t_0, true>;
    const int ILP = sizeof(float4)/sizeof(scalar_t_0);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    // The forward kernel's two-output blockReduce needs 2 * blockDim.x accscalar_t slots.
    cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
      <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
        losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<accscalar_t>(),
        input.data_ptr<scalar_t_0>(), labels_.data_ptr<int64_t>(),
        dim_size, smoothing, total_classes <= 0 ? dim_size : total_classes
    );
  );

  C10_CUDA_CHECK(cudaGetLastError());

  std::vector<at::Tensor> ret = {losses, max_log_sum_exp};
  return ret;
}
template<template<typename, typename, typename> class Epilogue>
Tensor host_softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    at::Tensor &logits_,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    bool inplace,
    const int total_classes) {
  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::cuda::CUDAGuard device_guard{(char)grad_loss.get_device()};

  const int64_t dim = 1;
  Tensor gI = inplace ? logits_ : at::empty_like(logits_);
  if (grad_loss.numel() == 0) {
    return gI;
  }

  auto grad = grad_loss.contiguous();
  auto logits = logits_.contiguous();

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
    std::is_same<acc_type<at::Half, true>, double>::value,
    "accscalar_t for half should be float or double");
  if (grad.dim() == 0) grad = grad.view(1);

  AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported");
  AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional");
  AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0");
  AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples");
  AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples");

  int64_t outer_size = 1;
  int64_t dim_size = logits.size(dim);
  int64_t inner_size = 1;
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= logits.size(i);
  for (int64_t i = dim + 1; i < logits.dim(); ++i)
    inner_size *= logits.size(i);
  // See descriptions of kernels above.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  dim3 grid(outer_size);

  DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
    using accscalar_t = acc_type<scalar_t_0, true>;
    const int ILP = sizeof(float4)/sizeof(scalar_t_0);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
      <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
        gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(),
        max_log_sum_exp.data_ptr<accscalar_t>(),
        grad.data_ptr<accscalar_t>(), labels.data_ptr<int64_t>(),
        smoothing, dim_size, total_classes
    );
  );

  C10_CUDA_CHECK(cudaGetLastError());
  return gI;
}

std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels,
                                          const float smoothing, const int total_classes){
  return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, total_classes);
}

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace,
    const int total_classes) {
  AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float");
  return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(
      grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes);
}
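// Illustrative only -- not part of this translation unit. These entry points are
// typically exposed to Python through a separate binding file; a minimal sketch
// (file, module, and exported names here are assumptions, not this project's actual
// binding) could look like:
//
//   // xentropy_binding.cpp (hypothetical)
//   #include <torch/extension.h>
//
//   std::vector<at::Tensor> softmax_xentropy_cuda(
//       const at::Tensor &input, const at::Tensor &labels,
//       const float smoothing, const int total_classes);
//   at::Tensor softmax_xentropy_backward_cuda(
//       const at::Tensor &grad_loss, at::Tensor &logits,
//       const at::Tensor &max_log_sum_exp, const at::Tensor &labels,
//       const float smoothing, const bool inplace, const int total_classes);
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//       m.def("forward", &softmax_xentropy_cuda, "Label-smoothed softmax cross entropy (forward)");
//       m.def("backward", &softmax_xentropy_backward_cuda, "Label-smoothed softmax cross entropy (backward)");
//   }
//
// Expected shapes/dtypes, as enforced by the assertions above: input/logits are
// [batch, classes] in fp32/fp16/bf16, labels are int64 of shape [batch], grad_loss is
// fp32, and the forward returns fp32 {losses, max_log_sum_exp}, each of shape [batch].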