@@ -434,7 +434,8 @@ cunn_SoftMaxXEntropyForward(
scalar_t *input,
int64_t *labels,
int64_t classes,
- const float smoothing)
+ const float smoothing,
+ const int total_classes)
extern __shared__ unsigned char smem[];
auto sdata = reinterpret_cast<accscalar_t*>(smem);
@@ -472,12 +473,8 @@ cunn_SoftMaxXEntropyForward(
// reserve max + log_sum_exp for bprop
if (threadIdx.x == 0) {
accscalar_t lse = max_k + std::log(sumAll);
- if ((label >= 0) && (label < classes)) {
- accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));
- losses[blockIdx.x] = (lse - sum_k / classes) * smoothing - log_prob * (1 - smoothing);
- } else {
- losses[blockIdx.x] = outscalar_t(0.f);
- }
+ accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast<accscalar_t>(input[label])) : 0.f;
+ losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing);
max_log_sum_exp[blockIdx.x] = lse;
@@ -490,10 +487,11 @@ apply(scalar_t *gradInput,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
- int classes)
+ int classes,
+ const int total_classes)
accscalar_t smooth_positives = 1.0 - smoothing;
- accscalar_t smooth_negatives = smoothing / classes;
+ accscalar_t smooth_negatives = smoothing / total_classes;
accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
int64_t label = labels[blockIdx.x];
accscalar_t coeff = max_log_sum_exp[blockIdx.x];
@@ -534,10 +532,11 @@ aligned_apply(int shift,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
- int classes)
+ int classes,
+ const int total_classes)
accscalar_t smooth_positives = 1.0 - smoothing;
- accscalar_t smooth_negatives = smoothing / classes;
+ accscalar_t smooth_negatives = smoothing / total_classes;
accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
int64_t label = labels[blockIdx.x];
accscalar_t coeff = max_log_sum_exp[blockIdx.x];
@@ -602,7 +601,8 @@ cunn_SoftMaxXEntropyBackward(
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
- int classes)
+ int classes,
+ const int total_classes)
gradInput += blockIdx.x * classes;
logits += blockIdx.x * classes;
@@ -611,10 +611,10 @@ cunn_SoftMaxXEntropyBackward(
const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
if (shift == shift_){
- aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
+ aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes);
else {
- apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
+ apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes);
@@ -623,7 +623,11 @@ template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
const Tensor & input_,
const Tensor & labels_,
- const float smoothing){
+ const float smoothing,
+ const int total_classes) {
+ // For tensor parallel cross entropy with smoothing, we want to pass in the total number
+ // of classes so that smoothing can be applied correctly. If total_classes <= 0 (e.g. -1),
+ // the last dimension of the input tensor is used instead.
AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long");
// Otherwise the kernel will be launched from cuda:0 device
@@ -666,7 +670,7 @@ std::vector<Tensor> host_softmax_xentropy(
<<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<accscalar_t>(),
input.data_ptr<scalar_t_0>(), labels_.data_ptr<int64_t>(),
- dim_size, smoothing
+ dim_size, smoothing, total_classes <= 0 ? dim_size : total_classes
@@ -683,7 +687,8 @@ Tensor host_softmax_xentropy_backward(
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing,
- bool inplace) {
+ bool inplace,
+ const int total_classes) {
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)grad_loss.get_device()};
@@ -730,7 +735,7 @@ Tensor host_softmax_xentropy_backward(
gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(),
grad.data_ptr<accscalar_t>(), labels.data_ptr<int64_t>(),
- smoothing, dim_size
+ smoothing, dim_size, total_classes
@@ -738,8 +743,8 @@ Tensor host_softmax_xentropy_backward(
return gI;
-std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing){
- return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing);
+std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes){
+ return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, total_classes);
at::Tensor softmax_xentropy_backward_cuda(
@@ -748,7 +753,8 @@ at::Tensor softmax_xentropy_backward_cuda(
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing,
- const bool inplace) {
+ const bool inplace,
+ const int total_classes) {
AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float");
- return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace);
+ return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes);