- // Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu
- // TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory).
- /**
- * From PyTorch:
- *
- * Copyright (c) 2016- Facebook, Inc (Adam Paszke)
- * Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
- * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
- * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
- * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
- * Copyright (c) 2011-2013 NYU (Clement Farabet)
- * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
- * Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
- * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
- *
- * From Caffe2:
- *
- * Copyright (c) 2016-present, Facebook Inc. All rights reserved.
- *
- * All contributions by Facebook:
- * Copyright (c) 2016 Facebook Inc.
- *
- * All contributions by Google:
- * Copyright (c) 2015 Google Inc.
- * All rights reserved.
- *
- * All contributions by Yangqing Jia:
- * Copyright (c) 2015 Yangqing Jia
- * All rights reserved.
- *
- * All contributions from Caffe:
- * Copyright(c) 2013, 2014, 2015, the respective contributors
- * All rights reserved.
- *
- * All other contributions:
- * Copyright(c) 2015, 2016 the respective contributors
- * All rights reserved.
- *
- * Caffe2 uses a copyright model similar to Caffe: each contributor holds
- * copyright over their contributions to Caffe2. The project versioning records
- * all such contribution and copyright details. If a contributor wants to further
- * mark their specific copyright on a particular contribution, they should
- * indicate their copyright solely in the commit message of the change when it is
- * committed.
- *
- * All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- *
- * 1. Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- *
- * 2. Redistributions in binary form must reproduce the above copyright
- * notice, this list of conditions and the following disclaimer in the
- * documentation and/or other materials provided with the distribution.
- *
- * 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
- * and IDIAP Research Institute nor the names of its contributors may be
- * used to endorse or promote products derived from this software without
- * specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
- * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
- * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
- * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
- * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
- * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
- * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
- * POSSIBILITY OF SUCH DAMAGE.
- */
- #include <ATen/ATen.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <c10/cuda/CUDAGuard.h>
- #include <ATen/AccumulateType.h>
- #include <ATen/cuda/NumericLimits.cuh>
- // https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
- // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
- switch(TYPE) \
- { \
- case at::ScalarType::Float: \
- { \
- using scalar_t_##LEVEL = float; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::Half: \
- { \
- using scalar_t_##LEVEL = at::Half; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::BFloat16: \
- { \
- using scalar_t_##LEVEL = at::BFloat16; \
- __VA_ARGS__; \
- break; \
- } \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
- }
- // #else
- // #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
- // switch(TYPE) \
- // { \
- // case at::ScalarType::Float: \
- // { \
- // using scalar_t_##LEVEL = float; \
- // __VA_ARGS__; \
- // break; \
- // } \
- // case at::ScalarType::Half: \
- // { \
- // using scalar_t_##LEVEL = at::Half; \
- // __VA_ARGS__; \
- // break; \
- // } \
- // default: \
- // AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
- // }
- // #endif
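- // Illustrative sketch (not part of this file's API): the TYPE argument selects the branch
- // and the concrete element type becomes visible in the body as scalar_t_##LEVEL, so a
- // hypothetical caller could look like
- //
- //   DISPATCH_FLOAT_AND_HALF_AND_BF16(tensor.scalar_type(), 0, "my_op",
- //       using accscalar_t = at::acc_type<scalar_t_0, true>;
- //       my_kernel<scalar_t_0, accscalar_t><<<grid, block, 0, stream>>>(
- //           tensor.data_ptr<scalar_t_0>()));
- //
- // The real call sites are host_softmax_xentropy and host_softmax_xentropy_backward below.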
- #define ALIGN_BYTES 16
- using Tensor = at::Tensor;
- using TensorList = at::TensorList;
- using ScalarType = at::ScalarType;
- using at::acc_type;
- template<typename T, typename AccumT, typename OutT>
- struct LogSoftMaxForwardEpilogue {
- __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
- : logsum(max_input + std::log(sum)) {}
- __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)
- : logsum(max_log_sum_exp) {}
- __device__ __forceinline__ OutT operator()(T input) const {
- return static_cast<OutT>(input - logsum);
- }
- const AccumT logsum;
- };
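- // For a row x with running maximum m and sum = sum_j exp(x_j - m), logsum equals
- // m + log(sum) = logsumexp(x), so operator() returns x_i - logsumexp(x) = log_softmax(x)_i.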
- template<typename T, typename AccumT, typename OutT>
- struct LogSoftMaxBackwardEpilogue {
- __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
- : sum(sum) {}
- __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
- return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);
- }
- const AccumT sum;
- };
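- // Standard log-softmax backward: with output holding log-probabilities and sum the
- // accumulated sum_j gradOutput_j, operator() returns gradOutput_i - exp(output_i) * sum.
- // Note that in this file the backward kernel computes the fused cross-entropy gradient
- // directly in apply()/aligned_apply() and never instantiates this epilogue.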
- const int max_threads = 1024;
- inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
- uint64_t block_size = 1;
- uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
- while (block_size < (max_block_size/2)) block_size *= 2;
- // Launch at least a single warp - the kernel assumes that.
- block_size = std::max(block_size, static_cast<uint64_t>(32));
- return dim3(block_size);
- }
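- // Worked example (illustrative): for fp32 logits ILP = sizeof(float4)/sizeof(float) = 4, so
- // dim_size = 32768 gives max_block_size = min(32768/4, 1024) = 1024 and the doubling loop
- // stops at 512 threads; very small rows still get at least one full warp of 32 threads.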
- template<typename T>
- struct Add {
- __device__ __forceinline__ T operator()(T a, T b) const {
- return a + b;
- }
- };
- template<typename T>
- struct Max {
- __device__ __forceinline__ T operator()(T a, T b) const {
- return a < b ? b : a;
- }
- };
- ////////////////////////////////////////////////////////////////////////////////
- // Regular kernel (fast when dim_size is large; requires inner_size == 1)
- ////////////////////////////////////////////////////////////////////////////////
- template <typename T, typename AccumT>
- struct MaxFloat
- {
- __device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
- return ::max(max, (AccumT)v);
- }
- };
- template<typename T, typename AccumT>
- struct AddFloat
- {
- __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
- return sum + v;
- }
- };
- template<typename T, typename AccumT>
- struct SumExpFloat
- {
- __device__ __forceinline__ SumExpFloat(AccumT v)
- : max_k(v) {}
- __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
- return sum + std::exp(v - max_k);
- }
- const AccumT max_k;
- };
- template <template<typename> class Reduction, typename AccumT>
- __device__ __forceinline__ AccumT
- blockReduce(AccumT* smem, AccumT val,
- const Reduction<AccumT>& r,
- AccumT defaultVal)
- {
- // To avoid RaW races from chaining blockReduce calls together, we need a sync here
- __syncthreads();
- smem[threadIdx.x] = val;
- __syncthreads();
- AccumT warpVal = defaultVal;
- // First warp will perform per-warp reductions for the remaining warps
- uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
- if (threadIdx.x < 32) {
- int lane = threadIdx.x % 32;
- if (lane < blockDim.x / 32) {
- #pragma unroll
- for (int i = 0; i < 32; ++i) {
- warpVal = r(warpVal, smem[lane * 32 + i]);
- }
- __syncwarp(mask);
- smem[lane] = warpVal;
- }
- }
- __syncthreads();
- // First thread will perform a reduction of the above per-warp reductions
- AccumT blockVal = defaultVal;
- if (threadIdx.x == 0) {
- for (int i = 0; i < blockDim.x / 32; ++i) {
- blockVal = r(blockVal, smem[i]);
- }
- smem[0] = blockVal;
- }
- // Sync and broadcast
- __syncthreads();
- return smem[0];
- }
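- // The reduction above runs in three stages: every thread writes its partial value to shared
- // memory (smem must hold at least blockDim.x accumulators), each lane of the first warp
- // folds one 32-element strip, and thread 0 folds the per-warp results and broadcasts smem[0].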
- template <template<typename> class Reduction1, template<typename> class Reduction2, typename AccumT>
- __device__ __forceinline__ void
- blockReduce(AccumT* smem,
- AccumT* reducVal1,
- AccumT val1,
- const Reduction1<AccumT>& r1,
- AccumT defaultVal1,
- AccumT* reducVal2,
- AccumT val2,
- const Reduction2<AccumT>& r2,
- AccumT defaultVal2)
- {
- // To avoid RaW races from chaining blockReduce calls together, we need a sync here
- __syncthreads();
- smem[threadIdx.x] = val1;
- smem[blockDim.x + threadIdx.x] = val2;
- __syncthreads();
- AccumT warpVal1 = defaultVal1;
- AccumT warpVal2 = defaultVal2;
- // First warp will perform per-warp reductions for the remaining warps
- uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
- if (threadIdx.x < 32) {
- int lane = threadIdx.x % 32;
- if (lane < blockDim.x / 32) {
- #pragma unroll
- for (int i = 0; i < 32; ++i) {
- warpVal1 = r1(warpVal1, smem[lane * 32 + i]);
- warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]);
- }
- __syncwarp(mask);
- smem[lane] = warpVal1;
- smem[lane + blockDim.x] = warpVal2;
- }
- }
- __syncthreads();
- // First thread will perform a reduction of the above per-warp reductions
- AccumT blockVal1 = defaultVal1;
- AccumT blockVal2 = defaultVal2;
- if (threadIdx.x == 0) {
- for (int i = 0; i < blockDim.x / 32; ++i) {
- blockVal1 = r1(blockVal1, smem[i]);
- blockVal2 = r2(blockVal2, smem[i + blockDim.x]);
- }
- smem[0] = blockVal1;
- smem[blockDim.x] = blockVal2;
- }
- // Sync and broadcast
- __syncthreads();
- *reducVal1 = smem[0];
- *reducVal2 = smem[blockDim.x];
- __syncthreads();
- }
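- // This two-output variant interleaves both reductions over a single buffer of
- // 2 * blockDim.x accumulators (val1 in smem[0..blockDim.x), val2 in the upper half), which
- // is why the forward kernel below is launched with 2 * block.x * sizeof(accscalar_t) bytes
- // of dynamic shared memory.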
- template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
- __device__ __forceinline__ AccumT
- ilpReduce(int shift,
- T* data,
- int size,
- const Reduction<T, AccumT>& r,
- AccumT defaultVal)
- {
- typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
- AccumT threadVal = defaultVal;
- int offset = threadIdx.x;
- // Handle a misaligned start: step the data pointer back to the aligned address and let
- // threads whose index is past the misalignment consume the leading elements one at a time.
- if(shift > 0){
- data -= shift;
- size += shift;
- if(threadIdx.x >= shift){
- threadVal = r(threadVal, data[offset]);
- }
- size -= blockDim.x;
- data += blockDim.x;
- }
- int last = size % (ILP * blockDim.x);
- T v[ILP];
- LoadT* value = reinterpret_cast<LoadT*>(&v);
- for (; offset * ILP < (size - last); offset += blockDim.x) {
- *value = reinterpret_cast<LoadT*>(data)[offset];
- for (int j = 0; j < ILP; ++j) {
- threadVal = r(threadVal, v[j]);
- }
- }
- offset = size - last + threadIdx.x;
- // Epilogue
- for (; offset < size; offset += blockDim.x)
- threadVal = r(threadVal, data[offset]);
- return threadVal;
- }
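- // ilpReduce walks one row with ILP-wide vector loads: LoadT aliases ILP elements, so each
- // reinterpret_cast load moves ILP * sizeof(T) bytes (16 bytes for the float4-sized ILP used
- // at this file's call sites), while the scalar prologue handles a misaligned head and the
- // scalar epilogue handles the size % (ILP * blockDim.x) tail.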
- template <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2, int ILP, typename T, typename AccumT>
- __device__ __forceinline__ void
- ilpReduce(int shift,
- T* data,
- int size,
- AccumT* reducVal1,
- const Reduction1<T, AccumT>& r1,
- AccumT defaultVal1,
- AccumT* reducVal2,
- const Reduction2<T, AccumT>& r2,
- AccumT defaultVal2)
- {
- typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
- AccumT threadVal1 = defaultVal1;
- AccumT threadVal2 = defaultVal2;
- int offset = threadIdx.x;
- // Handle a misaligned start (same scheme as the single-reduction ilpReduce above).
- if(shift > 0){
- data -= shift;
- size += shift;
- if(threadIdx.x >= shift){
- threadVal1 = r1(threadVal1, data[offset]);
- threadVal2 = r2(threadVal2, data[offset]);
- }
- size -= blockDim.x;
- data += blockDim.x;
- }
- int last = size % (ILP * blockDim.x);
- T v[ILP];
- LoadT* value = reinterpret_cast<LoadT*>(&v);
- for (; offset * ILP < (size - last); offset += blockDim.x) {
- *value = reinterpret_cast<LoadT*>(data)[offset];
- for (int j = 0; j < ILP; ++j) {
- threadVal1 = r1(threadVal1, v[j]);
- threadVal2 = r2(threadVal2, v[j]);
- }
- }
- offset = size - last + threadIdx.x;
- // Epilogue
- for (; offset < size; offset += blockDim.x) {
- threadVal1 = r1(threadVal1, data[offset]);
- threadVal2 = r2(threadVal2, data[offset]);
- }
- *reducVal1 = threadVal1;
- *reducVal2 = threadVal2;
- }
- template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
- __global__ void
- cunn_SoftMaxXEntropyForward(
- accscalar_t *losses,
- outscalar_t *max_log_sum_exp,
- scalar_t *input,
- int64_t *labels,
- int64_t classes,
- const float smoothing,
- const int total_classes)
- {
- extern __shared__ unsigned char smem[];
- auto sdata = reinterpret_cast<accscalar_t*>(smem);
- // forward pointers to batch[blockIdx.x]
- // each block handles a sample in the mini-batch
- input += blockIdx.x * classes;
- //output += blockIdx.x * classes;
- const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);
- int64_t label = labels[blockIdx.x];
- // find the max and sum
- accscalar_t threadMax, threadSum, max_k, sum_k;
- ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
- shift, input, classes,
- &threadMax, MaxFloat<scalar_t, accscalar_t>(),
- -at::numeric_limits<accscalar_t>::max(),
- &threadSum, AddFloat<scalar_t, accscalar_t>(),
- static_cast<accscalar_t>(0));
- blockReduce<Max, Add, accscalar_t>(
- sdata,
- &max_k, threadMax, Max<accscalar_t>(),
- -at::numeric_limits<accscalar_t>::max(),
- &sum_k, threadSum, Add<accscalar_t>(),
- static_cast<accscalar_t>(0));
- accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
- shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k),
- static_cast<accscalar_t>(0));
- accscalar_t sumAll = blockReduce<Add, accscalar_t>(
- sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));
- Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);
- // calculate per element loss with label smoothing
- // reserve max + log_sum_exp for bprop
- if (threadIdx.x == 0) {
- accscalar_t lse = max_k + std::log(sumAll);
- accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast<accscalar_t>(input[label])) : 0.f;
- losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing);
- max_log_sum_exp[blockIdx.x] = lse;
- }
- }
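- // With lse = max_k + log(sumAll) = logsumexp(x), eps = smoothing and K = total_classes, the
- // value stored above is eps * (lse - sum_j x_j / K) - (1 - eps) * log_prob, which for
- // K == classes is the label-smoothed cross entropy
- //   -(1 - eps) * log p_label - (eps / K) * sum_j log p_j.
- // Labels outside [0, classes) (as with a sharded vocabulary) contribute log_prob = 0, and
- // lse is written out so the backward pass can reconstruct softmax(x) without another
- // reduction over the row.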
- template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
- __device__ __forceinline__ void
- apply(scalar_t *gradInput,
- scalar_t *logits,
- outscalar_t *max_log_sum_exp,
- outscalar_t *gradOutput,
- int64_t *labels,
- const float smoothing,
- int classes,
- const int total_classes)
- {
- accscalar_t smooth_positives = 1.0 - smoothing;
- accscalar_t smooth_negatives = smoothing / total_classes;
- accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
- int64_t label = labels[blockIdx.x];
- accscalar_t coeff = max_log_sum_exp[blockIdx.x];
- int offset = threadIdx.x;
- int last = classes % (ILP * blockDim.x);
- for (; offset < classes - last; offset += blockDim.x * ILP) {
- accscalar_t tmpLogits[ILP];
- #pragma unroll
- for (int j = 0; j < ILP; ++j) {
- tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);
- }
- #pragma unroll
- for (int j = 0; j < ILP; ++j)
- gradInput[offset + j * blockDim.x] = tmpGradOutput * (
- std::exp(tmpLogits[j] - coeff) - static_cast<accscalar_t>(
- (offset + j * blockDim.x == label) ? 1 : 0) *
- smooth_positives - smooth_negatives);
- }
- for (; offset < classes; offset += blockDim.x)
- gradInput[offset] = tmpGradOutput * (std::exp(
- static_cast<accscalar_t>(logits[offset]) - coeff) -
- static_cast<accscalar_t>((offset == label) ? 1 : 0) *
- smooth_positives - smooth_negatives);
- }
- template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
- __device__ __forceinline__ void
- aligned_apply(int shift,
- scalar_t *gradInput,
- scalar_t *logits,
- outscalar_t *max_log_sum_exp,
- outscalar_t *gradOutput,
- int64_t *labels,
- const float smoothing,
- int classes,
- const int total_classes)
- {
- accscalar_t smooth_positives = 1.0 - smoothing;
- accscalar_t smooth_negatives = smoothing / total_classes;
- accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
- int64_t label = labels[blockIdx.x];
- accscalar_t coeff = max_log_sum_exp[blockIdx.x];
- int offset = threadIdx.x;
- // Handle a misaligned start: step logits and gradInput back to the aligned address and
- // process the leading unaligned elements one scalar per thread before the vectorized loop.
- if(shift > 0){
- logits -= shift;
- gradInput -= shift;
- classes += shift;
- if(threadIdx.x >= shift){
- gradInput[offset] = tmpGradOutput * (std::exp(
- static_cast<accscalar_t>(logits[offset]) - coeff) -
- static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
- smooth_positives - smooth_negatives);
- }
- classes -= blockDim.x;
- gradInput += blockDim.x;
- logits += blockDim.x;
- shift -= blockDim.x;
- }
- int last = classes % (ILP * blockDim.x);
- typedef typename std::aligned_storage<ILP*sizeof(scalar_t), ILP*alignof(scalar_t)>::type LoadT;
- // input
- scalar_t v[ILP];
- LoadT* value = reinterpret_cast<LoadT*>(&v);
- // output
- scalar_t r[ILP];
- LoadT* result = reinterpret_cast<LoadT*>(&r);
- for (; offset * ILP < (classes - last); offset += blockDim.x) {
- *value = reinterpret_cast<LoadT*>(logits)[offset];
- #pragma unroll
- for (int j = 0; j < ILP; ++j) {
- r[j] = tmpGradOutput * (std::exp(
- static_cast<accscalar_t>(v[j]) - coeff) -
- static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) *
- smooth_positives - smooth_negatives);
- }
- reinterpret_cast<LoadT*>(gradInput)[offset] = *result;
- }
- offset = classes - last + threadIdx.x;
- for (; offset < classes; offset += blockDim.x)
- gradInput[offset] = tmpGradOutput * (std::exp(
- static_cast<accscalar_t>(logits[offset]) - coeff) -
- static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
- smooth_positives - smooth_negatives);
- }
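- // Both apply() and aligned_apply() above produce, for column j of a row,
- //   gradInput[j] = gradOutput * (exp(x_j - lse) - (1 - eps) * 1[j == label] - eps / K)
- // with lse read back from max_log_sum_exp, eps = smoothing and K = total_classes, i.e. the
- // softmax cross-entropy gradient adjusted for label smoothing; aligned_apply() additionally
- // uses 16-byte vector loads/stores when gradInput and logits share the same misalignment.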
- template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
- __global__ void
- cunn_SoftMaxXEntropyBackward(
- scalar_t *gradInput,
- scalar_t *logits,
- outscalar_t *max_log_sum_exp,
- outscalar_t *gradOutput,
- int64_t *labels,
- const float smoothing,
- int classes,
- const int total_classes)
- {
- gradInput += blockIdx.x * classes;
- logits += blockIdx.x * classes;
- // Do vectorized load/store when input/output have same alignment
- const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
- const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
- if (shift == shift_){
- aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes);
- }
- else {
- apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes);
- }
- }
- template<template<typename, typename, typename> class Epilogue>
- std::vector<Tensor> host_softmax_xentropy(
- const Tensor & input_,
- const Tensor & labels_,
- const float smoothing,
- const int total_classes) {
- // For tensor parallel cross entropy with smoothing, we want to pass in the total number
- // of classes so that smoothing can be applied correctly. If total_classes=-1, use the
- // last dimension of the input tensor.
- AT_ASSERTM(labels_.scalar_type() == ScalarType::Long, "Label type should be CUDA Long");
- // Otherwise the kernel will be launched from cuda:0 device
- at::cuda::CUDAGuard device_guard{input_.device()};
- auto input = input_.contiguous();
- Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float));
- Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));
- static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
- std::is_same<acc_type<at::Half, true>, double>::value,
- "accscalar_t for half should be float or double");
- AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported");
- AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional");
- AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples");
- AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0");
- const int64_t dim = 1;
- int64_t outer_size = 1;
- int64_t dim_size = input.size(dim);
- int64_t inner_size = 1;
- cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- for (int64_t i = 0; i < dim; ++i)
- outer_size *= input.size(i);
- for (int64_t i = dim + 1; i < input.dim(); ++i)
- inner_size *= input.size(i);
- // This kernel spawns a block per each element in the batch.
- // XXX: it assumes that inner_size == 1
- TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");
- dim3 grid(outer_size);
- using namespace at;
- DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy",
- using accscalar_t = at::acc_type<scalar_t_0, true>;
- const int ILP = sizeof(float4)/sizeof(scalar_t_0);
- dim3 block = SoftMax_getBlockSize(ILP, dim_size);
- cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
- <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
- losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<accscalar_t>(),
- input.data_ptr<scalar_t_0>(), labels_.data_ptr<int64_t>(),
- dim_size, smoothing, total_classes <= 0 ? dim_size : total_classes
- );
- );
- C10_CUDA_CHECK(cudaGetLastError());
- std::vector<at::Tensor> ret = {losses, max_log_sum_exp};
- return ret;
- }
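- // The pair returned above is {per-example losses, per-example logsumexp}; callers are
- // expected to keep max_log_sum_exp and pass it to host_softmax_xentropy_backward so the
- // backward kernel can skip recomputing the row reduction.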
- template<template<typename, typename, typename> class Epilogue>
- Tensor host_softmax_xentropy_backward(
- const at::Tensor &grad_loss,
- at::Tensor &logits_,
- const at::Tensor &max_log_sum_exp,
- const at::Tensor &labels,
- const float smoothing,
- bool inplace,
- const int total_classes) {
- // Otherwise the kernel will be launched from cuda:0 device
- at::cuda::CUDAGuard device_guard{grad_loss.device()};
- const int64_t dim = 1;
- Tensor gI = inplace ? logits_ : at::empty_like(logits_);
- if (grad_loss.numel() == 0) {
- return gI;
- }
- auto grad = grad_loss.contiguous();
- auto logits = logits_.contiguous();
- static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
- std::is_same<acc_type<at::Half, true>, double>::value,
- "accscalar_t for half should be float or double");
- if (grad.dim() == 0) grad = grad.view(1);
- AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported");
- AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional");
- AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0");
- AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples");
- AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples");
- int64_t outer_size = 1;
- int64_t dim_size = logits.size(dim);
- int64_t inner_size = 1;
- for (int64_t i = 0; i < dim; ++i)
- outer_size *= logits.size(i);
- for (int64_t i = dim + 1; i < logits.dim(); ++i)
- inner_size *= logits.size(i);
- // See descriptions of kernels above.
- cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");
- dim3 grid(outer_size);
- DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
- using accscalar_t = acc_type<scalar_t_0, true>;
- const int ILP = sizeof(float4)/sizeof(scalar_t_0);
- dim3 block = SoftMax_getBlockSize(ILP, dim_size);
- cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
- <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
- gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(),
- max_log_sum_exp.data_ptr<accscalar_t>(),
- grad.data_ptr<accscalar_t>(), labels.data_ptr<int64_t>(),
- smoothing, dim_size, total_classes
- );
- );
- C10_CUDA_CHECK(cudaGetLastError());
- return gI;
- }
- std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes){
- return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, total_classes);
- }
- at::Tensor softmax_xentropy_backward_cuda(
- const at::Tensor &grad_loss,
- at::Tensor &logits,
- const at::Tensor &max_log_sum_exp,
- const at::Tensor &labels,
- const float smoothing,
- const bool inplace,
- const int total_classes) {
- AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float");
- return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes);
- }
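- // Minimal binding sketch (an assumption, not part of this file: in a typical PyTorch C++
- // extension these bindings live in a separate translation unit that includes
- // <torch/extension.h>). Shown only to illustrate how the two entry points above might be
- // exposed to Python:
- //
- //   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- //       m.def("softmax_xentropy_forward", &softmax_xentropy_cuda,
- //             "Label-smoothed softmax cross entropy forward (CUDA)");
- //       m.def("softmax_xentropy_backward", &softmax_xentropy_backward_cuda,
- //             "Label-smoothed softmax cross entropy backward (CUDA)");
- //   }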