// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from
// https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_cp.async.cuh
/***************************************************************************
 * Copyright 2023 The FLash-LLM Authors. All rights reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ***************************************************************************/
// Extended from CUTLASS's source code
- #ifndef PTX_CP_ASYNC_CUH
- #define PTX_CP_ASYNC_CUH
- #include <cuda.h>
- #include <cuda_fp16.h>
- #include <cuda_runtime.h>
- template <int SizeInBytes>
- __device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr,
- bool pred_guard = true) {
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- static_assert(SizeInBytes == 16, "Size is not supported");
- unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr);
- asm volatile(
- "{ \n"
- " .reg .pred p;\n"
- " setp.ne.b32 p, %0, 0;\n"
- " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
- "}\n" ::"r"((int)pred_guard),
- "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes));
- #endif
- }
- /// Establishes an ordering w.r.t previously issued cp.async instructions. Does
- /// not block.
- __device__ __forceinline__ void cp_async_group_commit() {
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- asm volatile("cp.async.commit_group;\n" ::);
- #endif
- }
- /// Blocks until all but <N> previous cp.async.commit_group operations have
- /// committed.
- template <int N>
- __device__ __forceinline__ void cp_async_wait_group() {
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
- #endif
- }
- /// Blocks until all previous cp.async.commit_group operations have committed.
- // cp.async.wait_all is equivalent to :
- // cp.async.commit_group;
- // cp.async.wait_group 0;
- __device__ __forceinline__ void cp_async_wait_all() {
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- asm volatile("cp.async.wait_all;\n" ::);
- #endif
- }
- #endif