// common.hpp (807 B)
  #pragma once

  #include <climits>
  #include <cstdint>

  #include "cutlass/cutlass.h"
  /**
   * Helper macro for checking CUTLASS errors.
   *
   * `status` is evaluated exactly once and stored in a local, so passing a
   * call expression (e.g. `CUTLASS_CHECK(gemm_op.run())`) does not execute the
   * call a second time while the error message is being built. The stored
   * value is compared against cutlass::Status::kSuccess inside TORCH_CHECK,
   * with the text from cutlassGetStatusString() passed as the message.
   *
   * The expansion is a brace-enclosed block, so it can be written with or
   * without a trailing semicolon at the call site. TORCH_CHECK is not declared
   * by this header; it must be visible wherever the macro is expanded.
   */
  #define CUTLASS_CHECK(status)                                       \
    {                                                                 \
      cutlass::Status const cutlass_check_status_ = (status);         \
      TORCH_CHECK(cutlass_check_status_ == cutlass::Status::kSuccess, \
                  cutlassGetStatusString(cutlass_check_status_));     \
    }
  /**
   * Rounds `num` up to the nearest power of two.
   *
   * Inputs 0 and 1 are returned unchanged. When `num` is greater than 2^31 the
   * next power of two (2^32) does not fit in 32 bits and 0 is returned.
   */
  inline uint32_t next_pow_2(uint32_t const num) {
    if (num <= 1) return num;
    // num >= 2 here, so (num - 1) is non-zero and __builtin_clz is well defined.
    unsigned const shift = static_cast<unsigned>(CHAR_BIT * sizeof(num)) -
                           static_cast<unsigned>(__builtin_clz(num - 1));
    // shift lies in [1, 32]. Shifting a 64-bit unsigned one keeps shift == 32
    // well defined (a 32-bit shift by 32 is undefined behaviour) and avoids
    // shifting a signed 1 into the sign bit; the cast truncates 2^32 to 0.
    return static_cast<uint32_t>(uint64_t{1} << shift);
  }
  16. inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
  17. int max_shared_mem_per_block_opt_in = 0;
  18. cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
  19. cudaDevAttrMaxSharedMemoryPerBlockOptin,
  20. device);
  21. return max_shared_mem_per_block_opt_in;
  22. }