cutlass_extensions_bf16.h 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. #pragma once
  2. #include "cutlass/block_striped.h"
  3. namespace cutlass {
  4. /// Utility for performing block-striped access (load, store, reduce) of trivially-copyable,
  5. /// statically-sized array types to global memory.
  6. /// (Specialization for bfloat16_t. Uses nv_bfloat162 vectorized-reduction.)
  7. template <
  8. int BlockThreads,
  9. typename ArrayT>
  10. struct BlockStripedReduce<BlockThreads, ArrayT, bfloat16_t> :
  11. BlockStriped<
  12. BlockThreads,
  13. ArrayT,
  14. nv_bfloat162>
  15. {
  16. static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length");
  17. /// Reduce
  18. CUTLASS_DEVICE
  19. static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
  20. {
  21. // This operation is natively supported by devices of compute
  22. // capability 9.x and higher, older devices use emulation path
  23. cutlass::atomic_add<nv_bfloat162> reduce;
  24. nv_bfloat162 *access_output = reinterpret_cast<nv_bfloat162*>(ptr);
  25. const nv_bfloat162 *access_data = reinterpret_cast<const nv_bfloat162*>(&data);
  26. CUTLASS_PRAGMA_UNROLL
  27. for (int i = 0; i < BlockStripedReduce::kStripes; ++i)
  28. {
  29. reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]);
  30. }
  31. }
  32. };
  33. } // namespace cutlass