@@ -434,7 +434,8 @@ cunn_SoftMaxXEntropyForward(
     scalar_t *input,
     int64_t *labels,
     int64_t classes,
-    const float smoothing)
+    const float smoothing,
+    const int total_classes)
 {
   extern __shared__ unsigned char smem[];
   auto sdata = reinterpret_cast<accscalar_t*>(smem);
@@ -472,12 +473,8 @@ cunn_SoftMaxXEntropyForward(
   // reserve max + log_sum_exp for bprop
   if (threadIdx.x == 0) {
     accscalar_t lse = max_k + std::log(sumAll);
-    if ((label >= 0) && (label < classes)) {
-      accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));
-      losses[blockIdx.x] = (lse - sum_k / classes) * smoothing - log_prob * (1 - smoothing);
-    } else {
-      losses[blockIdx.x] = outscalar_t(0.f);
-    }
+    accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast<accscalar_t>(input[label])) : 0.f;
+    losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing);
     max_log_sum_exp[blockIdx.x] = lse;
   }
 }
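The rewritten branch above keeps the out-of-range-label handling (a zero log-probability) but normalizes the smoothing term by total_classes instead of the local classes. A minimal host-side reference of the same per-row quantity, assuming a single, fully gathered logits vector; the function name smoothed_xent_reference and the double-precision types are illustrative, not part of this diff:

#include <algorithm>
#include <cmath>
#include <vector>

// Host-side reference for the per-row loss the forward kernel stores,
// assuming the full vocabulary is visible (no tensor parallelism).
double smoothed_xent_reference(const std::vector<double>& logits,
                               long label, float smoothing, int total_classes) {
    double max_k = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0, sum_k = 0.0;
    for (double x : logits) { sum_exp += std::exp(x - max_k); sum_k += x; }
    double lse = max_k + std::log(sum_exp);        // max + log(sumAll) in the kernel
    double log_prob = (label >= 0 && label < (long)logits.size())
                          ? logits[label] - lse    // log-softmax at the target
                          : 0.0;
    // Same expression as losses[blockIdx.x] above.
    return (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing);
}

How per-rank partial results are combined under tensor parallelism is outside this diff; the reference only restates the single-device formula.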
@@ -490,10 +487,11 @@ apply(scalar_t *gradInput,
       outscalar_t *gradOutput,
       int64_t *labels,
       const float smoothing,
-      int classes)
+      int classes,
+      const int total_classes)
 {
   accscalar_t smooth_positives = 1.0 - smoothing;
-  accscalar_t smooth_negatives = smoothing / classes;
+  accscalar_t smooth_negatives = smoothing / total_classes;
   accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
   int64_t label = labels[blockIdx.x];
   accscalar_t coeff = max_log_sum_exp[blockIdx.x];
@@ -534,10 +532,11 @@ aligned_apply(int shift,
              outscalar_t *gradOutput,
              int64_t *labels,
              const float smoothing,
-              int classes)
+              int classes,
+              const int total_classes)
 {
   accscalar_t smooth_positives = 1.0 - smoothing;
-  accscalar_t smooth_negatives = smoothing / classes;
+  accscalar_t smooth_negatives = smoothing / total_classes;
   accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
   int64_t label = labels[blockIdx.x];
   accscalar_t coeff = max_log_sum_exp[blockIdx.x];
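Both apply() and aligned_apply() now divide the smoothing mass by total_classes, i.e. the smoothed target puts epsilon/total_classes on every class of the full vocabulary rather than epsilon/classes on the local shard. A host-side sketch of the gradient this corresponds to; the expression is the standard label-smoothing one and the function name is illustrative (the kernel bodies that consume smooth_positives/smooth_negatives are not shown in these hunks):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// d(loss)/d(logit_j) for label-smoothed cross entropy: softmax_j minus the
// smoothed target ((1 - eps) + eps/total_classes on the true class,
// eps/total_classes elsewhere), scaled by the incoming gradient.
std::vector<double> smoothed_xent_grad_reference(const std::vector<double>& logits,
                                                 long label, float smoothing,
                                                 int total_classes, double grad_output) {
    double max_k = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0;
    for (double x : logits) sum_exp += std::exp(x - max_k);
    double smooth_positives = 1.0 - smoothing;            // as in apply()
    double smooth_negatives = smoothing / total_classes;  // as in apply()
    std::vector<double> grad(logits.size());
    for (std::size_t j = 0; j < logits.size(); ++j) {
        double softmax_j = std::exp(logits[j] - max_k) / sum_exp;
        double target_j = ((long)j == label ? smooth_positives : 0.0) + smooth_negatives;
        grad[j] = grad_output * (softmax_j - target_j);
    }
    return grad;
}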
@@ -602,7 +601,8 @@ cunn_SoftMaxXEntropyBackward(
     outscalar_t *gradOutput,
     int64_t *labels,
     const float smoothing,
-    int classes)
+    int classes,
+    const int total_classes)
 {
   gradInput += blockIdx.x * classes;
   logits += blockIdx.x * classes;
@@ -611,10 +611,10 @@ cunn_SoftMaxXEntropyBackward(
   const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
   const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
   if (shift == shift_){
-    aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
+    aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes);
   }
   else {
-    apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
+    apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes);
   }
 
 }
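Note the asymmetry in how the -1 default is resolved: the forward launch below folds it into dim_size on the host, while the backward path forwards total_classes unchanged and resolves it here, at the dispatch into aligned_apply/apply. The convention is small enough to state as a hypothetical helper (resolve_total_classes is not part of this diff):

// total_classes <= 0 (e.g. the -1 default) means "no tensor parallelism":
// fall back to the local number of classes in this logits shard.
__host__ __device__ inline int resolve_total_classes(int total_classes, int classes) {
    return total_classes <= 0 ? classes : total_classes;
}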
@@ -623,7 +623,11 @@ template<template<typename, typename, typename> class Epilogue>
 std::vector<Tensor> host_softmax_xentropy(
     const Tensor & input_,
     const Tensor & labels_,
-    const float smoothing){
+    const float smoothing,
+    const int total_classes) {
+    // For tensor parallel cross entropy with smoothing, we want to pass in the total number
+    // of classes so that smoothing can be applied correctly. If total_classes=-1, use the
+    // last dimension of the input tensor.
     AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long");
 
     // Otherwise the kernel will be launched from cuda:0 device
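The new comment is the heart of the change: with label smoothing epsilon the target distribution spreads mass over the entire vocabulary, so when the vocabulary is sharded across tensor parallel ranks the epsilon term must be normalized by the global class count. In my reading of the diff (writing C_total for total_classes), the per-row loss is

q_j = (1 - \varepsilon)\,\mathbf{1}[j = y] + \frac{\varepsilon}{C_{\mathrm{total}}},
\qquad
\mathcal{L} = -\sum_{j=1}^{C_{\mathrm{total}}} q_j \log p_j
            = (1 - \varepsilon)\,(\mathrm{lse} - x_y)
              + \varepsilon\Big(\mathrm{lse} - \tfrac{1}{C_{\mathrm{total}}} \sum_j x_j\Big),

which matches the (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing) expression in the forward kernel; dividing by the local shard size instead would over-weight each rank's logits in the smoothing term.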
@@ -666,7 +670,7 @@ std::vector<Tensor> host_softmax_xentropy(
         <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
             losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<accscalar_t>(),
             input.data_ptr<scalar_t_0>(), labels_.data_ptr<int64_t>(),
-            dim_size, smoothing
+            dim_size, smoothing, total_classes <= 0 ? dim_size : total_classes
         );
     );
 
@@ -683,7 +687,8 @@ Tensor host_softmax_xentropy_backward(
     const at::Tensor &max_log_sum_exp,
     const at::Tensor &labels,
     const float smoothing,
-    bool inplace) {
+    bool inplace,
+    const int total_classes) {
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)grad_loss.get_device()};
@@ -730,7 +735,7 @@ Tensor host_softmax_xentropy_backward(
             gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(),
             max_log_sum_exp.data_ptr<accscalar_t>(),
             grad.data_ptr<accscalar_t>(), labels.data_ptr<int64_t>(),
-            smoothing, dim_size
+            smoothing, dim_size, total_classes
         );
     );
 
@@ -738,8 +743,8 @@ Tensor host_softmax_xentropy_backward(
     return gI;
 }
 
-std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing){
-    return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing);
+std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes){
+    return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, total_classes);
 }
 
 at::Tensor softmax_xentropy_backward_cuda(
@@ -748,7 +753,8 @@ at::Tensor softmax_xentropy_backward_cuda(
     const at::Tensor &max_log_sum_exp,
     const at::Tensor &labels,
     const float smoothing,
-    const bool inplace) {
+    const bool inplace,
+    const int total_classes) {
     AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float");
-    return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace);
+    return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes);
 }
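Finally, a hedged sketch of how the widened entry points might be exercised from C++. The shapes and names (batch, vocab_local, vocab_total) are illustrative, and the first two parameters of softmax_xentropy_backward_cuda are inferred from the call in its body rather than shown in this diff:

#include <torch/extension.h>
#include <vector>

// Declarations from this file (forward returns {losses, max_log_sum_exp}).
std::vector<at::Tensor> softmax_xentropy_cuda(const at::Tensor &input, const at::Tensor &labels,
                                              const float smoothing, const int total_classes);
at::Tensor softmax_xentropy_backward_cuda(const at::Tensor &grad_loss, const at::Tensor &logits,
                                          const at::Tensor &max_log_sum_exp, const at::Tensor &labels,
                                          const float smoothing, const bool inplace,
                                          const int total_classes);

void example() {
    const int64_t batch = 4, vocab_local = 8192;   // this rank's shard of the vocabulary
    const int vocab_total = 8 * 8192;              // full vocabulary across tensor parallel ranks
    auto logits = at::randn({batch, vocab_local}, at::device(at::kCUDA).dtype(at::kHalf));
    auto labels = at::randint(0, vocab_local, {batch}, at::device(at::kCUDA).dtype(at::kLong));

    // Forward: the smoothing mass is normalized by vocab_total, not vocab_local.
    auto out = softmax_xentropy_cuda(logits, labels, /*smoothing=*/0.1f, vocab_total);
    auto losses = out[0], max_log_sum_exp = out[1];

    // Backward: gradient of a mean reduction over the per-row losses.
    auto grad_loss = at::full_like(losses, 1.0f / batch);
    auto grad_logits = softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels,
                                                      0.1f, /*inplace=*/false, vocab_total);

    // total_classes = -1 keeps the previous single-device behaviour
    // (the host code falls back to the last input dimension).
    auto out_single = softmax_xentropy_cuda(logits, labels, 0.1f, /*total_classes=*/-1);
}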