// ptx_cp.async.cuh
  1. // Copyright 2024 FP6-LLM authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // This file is copied from
  16. // https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_cp.async.cuh
  17. /***************************************************************************
  18. * Copyright 2023 The FLash-LLM Authors. All rights reserved.
  19. * Licensed under the Apache License, Version 2.0 (the "License");
  20. * you may not use this file except in compliance with the License.
  21. * You may obtain a copy of the License at
  22. * http://www.apache.org/licenses/LICENSE-2.0
  23. * Unless required by applicable law or agreed to in writing, software
  24. * distributed under the License is distributed on an "AS IS" BASIS,
  25. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  26. * See the License for the specific language governing permissions and
  27. * limitations under the License.
  28. ***************************************************************************/
  29. // Extended from CUTLASS's source code
  30. #ifndef PTX_CP_ASYNC_CUH
  31. #define PTX_CP_ASYNC_CUH
  32. #include <cuda.h>
  33. #include <cuda_fp16.h>
  34. #include <cuda_runtime.h>
  35. template <int SizeInBytes>
  36. __device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr,
  37. bool pred_guard = true) {
  38. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  39. static_assert(SizeInBytes == 16, "Size is not supported");
  40. unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr);
  41. asm volatile(
  42. "{ \n"
  43. " .reg .pred p;\n"
  44. " setp.ne.b32 p, %0, 0;\n"
  45. " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
  46. "}\n" ::"r"((int)pred_guard),
  47. "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes));
  48. #endif
  49. }
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does
/// not block: per the PTX ISA, cp.async.commit_group batches all previously
/// issued, as-yet-uncommitted cp.async operations into one group whose
/// completion can later be awaited with cp.async.wait_group / wait_all.
/// Requires SM80+; compiles to a no-op on older architectures.
__device__ __forceinline__ void cp_async_group_commit() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cp.async.commit_group;\n" ::);
#endif
}
/// Blocks until all but <N> previous cp.async.commit_group operations have
/// committed.
///
/// @tparam N  number of most-recently-committed groups that may still be in
///            flight when this call returns; N == 0 waits for every group.
///            Requires SM80+; no-op on older architectures.
template <int N>
__device__ __forceinline__ void cp_async_wait_group() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // "n" constraint: N is emitted as an immediate, so it must be a
  // compile-time constant (hence the template parameter rather than an arg).
  asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#endif
}
/// Blocks until all previous cp.async.commit_group operations have committed.
// cp.async.wait_all is equivalent to :
// cp.async.commit_group;
// cp.async.wait_group 0;
// i.e. it also commits any still-uncommitted cp.async copies before waiting.
// Requires SM80+; compiles to a no-op on older architectures.
__device__ __forceinline__ void cp_async_wait_all() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cp.async.wait_all;\n" ::);
#endif
}
  74. #endif