// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu
// TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory).
/**
 * From PyTorch:
 *
 * Copyright (c) 2016- Facebook, Inc (Adam Paszke)
 * Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
 * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
 * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
 * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
 * Copyright (c) 2011-2013 NYU (Clement Farabet)
 * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
 * Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
 * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
 *
 * From Caffe2:
 *
 * Copyright (c) 2016-present, Facebook Inc. All rights reserved.
 *
 * All contributions by Facebook:
 * Copyright (c) 2016 Facebook Inc.
 *
 * All contributions by Google:
 * Copyright (c) 2015 Google Inc.
 * All rights reserved.
 *
 * All contributions by Yangqing Jia:
 * Copyright (c) 2015 Yangqing Jia
 * All rights reserved.
 *
 * All contributions from Caffe:
 * Copyright(c) 2013, 2014, 2015, the respective contributors
 * All rights reserved.
 *
 * All other contributions:
 * Copyright(c) 2015, 2016 the respective contributors
 * All rights reserved.
 *
 * Caffe2 uses a copyright model similar to Caffe: each contributor holds
 * copyright over their contributions to Caffe2. The project versioning records
 * all such contribution and copyright details. If a contributor wants to further
 * mark their specific copyright on a particular contribution, they should
 * indicate their copyright solely in the commit message of the change when it is
 * committed.
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
 *    and IDIAP Research Institute nor the names of its contributors may be
 *    used to endorse or promote products derived from this software without
 *    specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/NumericLimits.cuh>

// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
  switch(TYPE) \
  { \
    case at::ScalarType::Float: \
    { \
      using scalar_t_##LEVEL = float; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_##LEVEL = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::BFloat16: \
    { \
      using scalar_t_##LEVEL = at::BFloat16; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
// #else
// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
//   switch(TYPE) \
//   { \
//     case at::ScalarType::Float: \
//     { \
//       using scalar_t_##LEVEL = float; \
//       __VA_ARGS__; \
//       break; \
//     } \
//     case at::ScalarType::Half: \
//     { \
//       using scalar_t_##LEVEL = at::Half; \
//       __VA_ARGS__; \
//       break; \
//     } \
//     default: \
//       AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
//   }
// #endif

#define ALIGN_BYTES 16

using Tensor = at::Tensor;
using TensorList = at::TensorList;
using ScalarType = at::ScalarType;
using at::acc_type;
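
// The forward epilogue turns a raw logit x_i into the log-probability
// log_softmax(x)_i = x_i - logsum, where logsum = max(x) + log(sum_j exp(x_j - max(x)))
// is computed once per row from the block-wide max/sum-exp reductions below.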
template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxForwardEpilogue {
  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
    : logsum(max_input + std::log(sum)) {}

  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)
    : logsum(max_log_sum_exp) {}

  __device__ __forceinline__ OutT operator()(T input) const {
    return static_cast<OutT>(input - logsum);
  }

  const AccumT logsum;
};

template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxBackwardEpilogue {
  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
    : sum(sum) {}

  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
    return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);
  }

  const AccumT sum;
};
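
// Block-size heuristic for the row-wise kernels: the block size is the first power
// of two that reaches max_block_size / 2, where max_block_size = min(dim_size / ILP,
// max_threads), and never less than one warp (32), which the kernels assume.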
const int max_threads = 1024;

inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
  while (block_size < (max_block_size/2)) block_size *= 2;
  // Launch at least a single warp - the kernel assumes that.
  block_size = std::max(block_size, static_cast<uint64_t>(32));
  return dim3(block_size);
}

template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a + b;
  }
};

template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};

////////////////////////////////////////////////////////////////////////////////
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
////////////////////////////////////////////////////////////////////////////////

template <typename T, typename AccumT>
struct MaxFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
    return ::max(max, (AccumT)v);
  }
};

template<typename T, typename AccumT>
struct AddFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    return sum + v;
  }
};

template<typename T, typename AccumT>
struct SumExpFloat
{
  __device__ __forceinline__ SumExpFloat(AccumT v)
    : max_k(v) {}

  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    return sum + std::exp(v - max_k);
  }

  const AccumT max_k;
};
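
// Two-stage block reduction: every thread writes its partial value to shared memory,
// each lane of the first warp then reduces one contiguous 32-element slice, and thread 0
// combines the per-warp results and broadcasts the final value through smem[0].
// The single-output overload needs blockDim.x accscalar_t's of shared memory; the
// two-output overload below needs 2 * blockDim.x.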
template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
            const Reduction<AccumT>& r,
            AccumT defaultVal)
{
  // To avoid RaW races from chaining blockReduce calls together, we need a sync here
  __syncthreads();

  smem[threadIdx.x] = val;

  __syncthreads();

  AccumT warpVal = defaultVal;

  // First warp will perform per-warp reductions for the remaining warps
  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
  if (threadIdx.x < 32) {
    int lane = threadIdx.x % 32;
    if (lane < blockDim.x / 32) {
#pragma unroll
      for (int i = 0; i < 32; ++i) {
        warpVal = r(warpVal, smem[lane * 32 + i]);
      }
      __syncwarp(mask);
      smem[lane] = warpVal;
    }
  }

  __syncthreads();

  // First thread will perform a reduction of the above per-warp reductions
  AccumT blockVal = defaultVal;

  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / 32; ++i) {
      blockVal = r(blockVal, smem[i]);
    }
    smem[0] = blockVal;
  }

  // Sync and broadcast
  __syncthreads();
  return smem[0];
}

template <template<typename> class Reduction1, template<typename> class Reduction2, typename AccumT>
__device__ __forceinline__ void
blockReduce(AccumT* smem,
            AccumT* reducVal1,
            AccumT val1,
            const Reduction1<AccumT>& r1,
            AccumT defaultVal1,
            AccumT* reducVal2,
            AccumT val2,
            const Reduction2<AccumT>& r2,
            AccumT defaultVal2)
{
  // To avoid RaW races from chaining blockReduce calls together, we need a sync here
  __syncthreads();

  smem[threadIdx.x] = val1;
  smem[blockDim.x + threadIdx.x] = val2;

  __syncthreads();

  AccumT warpVal1 = defaultVal1;
  AccumT warpVal2 = defaultVal2;

  // First warp will perform per-warp reductions for the remaining warps
  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
  if (threadIdx.x < 32) {
    int lane = threadIdx.x % 32;
    if (lane < blockDim.x / 32) {
#pragma unroll
      for (int i = 0; i < 32; ++i) {
        warpVal1 = r1(warpVal1, smem[lane * 32 + i]);
        warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]);
      }
      __syncwarp(mask);
      smem[lane] = warpVal1;
      smem[lane + blockDim.x] = warpVal2;
    }
  }

  __syncthreads();

  // First thread will perform a reduction of the above per-warp reductions
  AccumT blockVal1 = defaultVal1;
  AccumT blockVal2 = defaultVal2;

  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / 32; ++i) {
      blockVal1 = r1(blockVal1, smem[i]);
      blockVal2 = r2(blockVal2, smem[i + blockDim.x]);
    }
    smem[0] = blockVal1;
    smem[blockDim.x] = blockVal2;
  }

  // Sync and broadcast
  __syncthreads();
  *reducVal1 = smem[0];
  *reducVal2 = smem[blockDim.x];

  __syncthreads();
}
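
// ilpReduce walks one row of `size` elements with the whole block, loading ILP elements
// per thread through a single vector load (LoadT is ILP * sizeof(T) bytes). `shift` is the
// number of elements by which the row pointer sits past the previous ALIGN_BYTES boundary;
// the prologue consumes a partial first chunk with scalar loads so the main loop's vector
// loads are aligned, and a scalar epilogue handles the tail that does not fill a full
// ILP * blockDim.x chunk.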
template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
__device__ __forceinline__ AccumT
ilpReduce(int shift,
          T* data,
          int size,
          const Reduction<T, AccumT>& r,
          AccumT defaultVal)
{
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
  AccumT threadVal = defaultVal;
  int offset = threadIdx.x;

  // shift and do 1
  if(shift > 0){
    data -= shift;
    size += shift;
    if(threadIdx.x >= shift){
      threadVal = r(threadVal, data[offset]);
    }
    size -= blockDim.x;
    data += blockDim.x;
  }
  int last = size % (ILP * blockDim.x);

  T v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);

  for (; offset * ILP < (size - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(data)[offset];

    for (int j = 0; j < ILP; ++j) {
      threadVal = r(threadVal, v[j]);
    }
  }

  offset = size - last + threadIdx.x;
  // Epilogue
  for (; offset < size; offset += blockDim.x)
    threadVal = r(threadVal, data[offset]);

  return threadVal;
}

template <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2, int ILP, typename T, typename AccumT>
__device__ __forceinline__ void
ilpReduce(int shift,
          T* data,
          int size,
          AccumT* reducVal1,
          const Reduction1<T, AccumT>& r1,
          AccumT defaultVal1,
          AccumT* reducVal2,
          const Reduction2<T, AccumT>& r2,
          AccumT defaultVal2)
{
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;

  AccumT threadVal1 = defaultVal1;
  AccumT threadVal2 = defaultVal2;
  int offset = threadIdx.x;

  // shift and do 1
  if(shift > 0){
    data -= shift;
    size += shift;
    if(threadIdx.x >= shift){
      threadVal1 = r1(threadVal1, data[offset]);
      threadVal2 = r2(threadVal2, data[offset]);
    }
    size -= blockDim.x;
    data += blockDim.x;
  }
  int last = size % (ILP * blockDim.x);

  T v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);

  for (; offset * ILP < (size - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(data)[offset];

    for (int j = 0; j < ILP; ++j) {
      threadVal1 = r1(threadVal1, v[j]);
      threadVal2 = r2(threadVal2, v[j]);
    }
  }

  offset = size - last + threadIdx.x;
  // Epilogue
  for (; offset < size; offset += blockDim.x) {
    threadVal1 = r1(threadVal1, data[offset]);
    threadVal2 = r2(threadVal2, data[offset]);
  }

  *reducVal1 = threadVal1;
  *reducVal2 = threadVal2;
}
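
// Forward pass: one block per row of logits. With lse = max + log(sum_j exp(x_j - max))
// and label smoothing eps, the per-row loss written to `losses` is (for in-range labels)
//   loss = (1 - eps) * (lse - x[label]) + eps * (lse - sum_j x_j / total_classes),
// i.e. the label-smoothed cross entropy. `lse` itself is stored in max_log_sum_exp so
// the backward pass can recompute softmax(x) without redoing the reductions.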
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyForward(
    accscalar_t *losses,
    outscalar_t *max_log_sum_exp,
    scalar_t *input,
    int64_t *labels,
    int64_t classes,
    const float smoothing,
    const int total_classes)
{
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);

  // forward pointers to batch[blockIdx.x]
  // each block handles a sample in the mini-batch
  input += blockIdx.x * classes;
  //output += blockIdx.x * classes;
  const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);

  int64_t label = labels[blockIdx.x];

  // find the max and sum
  accscalar_t threadMax, threadSum, max_k, sum_k;
  ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
      shift, input, classes,
      &threadMax, MaxFloat<scalar_t, accscalar_t>(),
      -at::numeric_limits<accscalar_t>::max(),
      &threadSum, AddFloat<scalar_t, accscalar_t>(),
      static_cast<accscalar_t>(0));

  blockReduce<Max, Add, accscalar_t>(
      sdata,
      &max_k, threadMax, Max<accscalar_t>(),
      -at::numeric_limits<accscalar_t>::max(),
      &sum_k, threadSum, Add<accscalar_t>(),
      static_cast<accscalar_t>(0));

  accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
      shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0));
  accscalar_t sumAll = blockReduce<Add, accscalar_t>(
      sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));

  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);

  // calculate per element loss with label smoothing
  // reserve max + log_sum_exp for bprop
  if (threadIdx.x == 0) {
    accscalar_t lse = max_k + std::log(sumAll);
    accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast<accscalar_t>(input[label])) : 0.f;
    losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing);
    max_log_sum_exp[blockIdx.x] = lse;
  }
}
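
// Backward pass: with p_j = exp(x_j - max_log_sum_exp) = softmax(x)_j, the gradient of
// the label-smoothed loss with respect to logit x_j is
//   dL/dx_j = gradOutput * (p_j - (1 - smoothing) * [j == label] - smoothing / total_classes).
// apply() computes this with plain scalar loads and stores.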
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
apply(scalar_t *gradInput,
      scalar_t *logits,
      outscalar_t *max_log_sum_exp,
      outscalar_t *gradOutput,
      int64_t *labels,
      const float smoothing,
      int classes,
      const int total_classes)
{
  accscalar_t smooth_positives = 1.0 - smoothing;
  accscalar_t smooth_negatives = smoothing / total_classes;
  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
  int64_t label = labels[blockIdx.x];
  accscalar_t coeff = max_log_sum_exp[blockIdx.x];

  int offset = threadIdx.x;
  int last = classes % (ILP * blockDim.x);

  for (; offset < classes - last; offset += blockDim.x * ILP) {
    accscalar_t tmpLogits[ILP];

#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);
    }

#pragma unroll
    for (int j = 0; j < ILP; ++j)
      gradInput[offset + j * blockDim.x] = tmpGradOutput * (
          std::exp(tmpLogits[j] - coeff) - static_cast<accscalar_t>(
          (offset + j * blockDim.x == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
  }

  for (; offset < classes; offset += blockDim.x)
    gradInput[offset] = tmpGradOutput * (std::exp(
        static_cast<accscalar_t>(logits[offset]) - coeff) -
        static_cast<accscalar_t>((offset == label) ? 1 : 0) *
        smooth_positives - smooth_negatives);
}
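
// Vectorized variant of apply(): reads logits and writes gradInput with ILP-wide vector
// loads/stores. The `shift` prologue handles rows whose base pointer is not ALIGN_BYTES
// aligned, and the scalar epilogue handles the leftover tail.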
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
aligned_apply(int shift,
              scalar_t *gradInput,
              scalar_t *logits,
              outscalar_t *max_log_sum_exp,
              outscalar_t *gradOutput,
              int64_t *labels,
              const float smoothing,
              int classes,
              const int total_classes)
{
  accscalar_t smooth_positives = 1.0 - smoothing;
  accscalar_t smooth_negatives = smoothing / total_classes;
  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
  int64_t label = labels[blockIdx.x];
  accscalar_t coeff = max_log_sum_exp[blockIdx.x];

  int offset = threadIdx.x;

  // shift and do 1
  if(shift > 0){
    logits -= shift;
    gradInput -= shift;
    classes += shift;
    if(threadIdx.x >= shift){
      gradInput[offset] = tmpGradOutput * (std::exp(
          static_cast<accscalar_t>(logits[offset]) - coeff) -
          static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
    }
    classes -= blockDim.x;
    gradInput += blockDim.x;
    logits += blockDim.x;
    shift -= blockDim.x;
  }

  int last = classes % (ILP * blockDim.x);

  typedef typename std::aligned_storage<ILP*sizeof(scalar_t), ILP*alignof(scalar_t)>::type LoadT;
  // input
  scalar_t v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);
  // output
  scalar_t r[ILP];
  LoadT* result = reinterpret_cast<LoadT*>(&r);

  for (; offset * ILP < (classes - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(logits)[offset];

#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      r[j] = tmpGradOutput * (std::exp(
          static_cast<accscalar_t>(v[j]) - coeff) -
          static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
    }
    reinterpret_cast<LoadT*>(gradInput)[offset] = *result;
  }

  offset = classes - last + threadIdx.x;
  for (; offset < classes; offset += blockDim.x)
    gradInput[offset] = tmpGradOutput * (std::exp(
        static_cast<accscalar_t>(logits[offset]) - coeff) -
        static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
        smooth_positives - smooth_negatives);
}

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyBackward(
    scalar_t *gradInput,
    scalar_t *logits,
    outscalar_t *max_log_sum_exp,
    outscalar_t *gradOutput,
    int64_t *labels,
    const float smoothing,
    int classes,
    const int total_classes)
{
  gradInput += blockIdx.x * classes;
  logits += blockIdx.x * classes;

  // Do vectorized load/store when input/output have same alignment
  const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
  const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
  if (shift == shift_){
    aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes);
  }
  else {
    apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes);
  }
}
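
// Host-side launcher for the forward pass. Returns {losses, max_log_sum_exp}; the second
// tensor is what the backward launcher consumes instead of re-reducing the logits.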
template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
    const Tensor & input_,
    const Tensor & labels_,
    const float smoothing,
    const int total_classes) {
  // For tensor parallel cross entropy with smoothing, we want to pass in the total number
  // of classes so that smoothing can be applied correctly. If total_classes=-1, use the
  // last dimension of the input tensor.
  AT_ASSERTM(labels_.scalar_type() == ScalarType::Long, "Label type should be CUDA Long");

  // Otherwise the kernel will be launched from cuda:0 device
  at::cuda::CUDAGuard device_guard{input_.device()};

  auto input = input_.contiguous();
  Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float));
  Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
                std::is_same<acc_type<at::Half, true>, double>::value,
                "accscalar_t for half should be float or double");
  AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported");
  AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional");
  AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples");
  AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0");

  const int64_t dim = 1;
  int64_t outer_size = 1;
  int64_t dim_size = input.size(dim);
  int64_t inner_size = 1;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= input.size(i);
  for (int64_t i = dim + 1; i < input.dim(); ++i)
    inner_size *= input.size(i);
  // This kernel spawns a block per each element in the batch.
  // XXX: it assumes that inner_size == 1
  TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  dim3 grid(outer_size);

  using namespace at;
  DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy",
    using accscalar_t = at::acc_type<scalar_t_0, true>;
    const int ILP = sizeof(float4)/sizeof(scalar_t_0);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
      <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
        losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<accscalar_t>(),
        input.data_ptr<scalar_t_0>(), labels_.data_ptr<int64_t>(),
        dim_size, smoothing, total_classes <= 0 ? dim_size : total_classes
    );
  );

  C10_CUDA_CHECK(cudaGetLastError());

  std::vector<at::Tensor> ret = {losses, max_log_sum_exp};
  return ret;
}
template<template<typename, typename, typename> class Epilogue>
Tensor host_softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    at::Tensor &logits_,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    bool inplace,
    const int total_classes) {
  // Otherwise the kernel will be launched from cuda:0 device
  at::cuda::CUDAGuard device_guard{grad_loss.device()};

  const int64_t dim = 1;
  Tensor gI = inplace ? logits_ : at::empty_like(logits_);
  if (grad_loss.numel() == 0) {
    return gI;
  }

  auto grad = grad_loss.contiguous();
  auto logits = logits_.contiguous();

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
                std::is_same<acc_type<at::Half, true>, double>::value,
                "accscalar_t for half should be float or double");

  if (grad.dim() == 0) grad = grad.view(1);

  AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported");
  AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional");
  AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0");
  AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples");
  AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples");

  int64_t outer_size = 1;
  int64_t dim_size = logits.size(dim);
  int64_t inner_size = 1;
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= logits.size(i);
  for (int64_t i = dim + 1; i < logits.dim(); ++i)
    inner_size *= logits.size(i);
  // See descriptions of kernels above.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  dim3 grid(outer_size);

  DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
    using accscalar_t = acc_type<scalar_t_0, true>;
    const int ILP = sizeof(float4)/sizeof(scalar_t_0);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
      <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
        gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(),
        max_log_sum_exp.data_ptr<accscalar_t>(),
        grad.data_ptr<accscalar_t>(), labels.data_ptr<int64_t>(),
        smoothing, dim_size, total_classes
    );
  );

  C10_CUDA_CHECK(cudaGetLastError());
  return gI;
}
std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes){
  return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, total_classes);
}

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace,
    const int total_classes) {
  AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float");
  return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes);
}