1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- /*
- * Modified by HandH1998
- * Modified by Neural Magic
- * Copyright (C) Marlin.2024 Elias Frantar
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #pragma once
// Predicated asynchronous 16-byte copy from global to shared memory
// (cp.async.cg). Used for the A inputs so that batch sizes which are not a
// multiple of 16 can be handled: when `pred` is false the copy instruction is
// skipped via a PTX predicate and the destination is left untouched.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  // Transfer size in bytes for a single cp.async instruction.
  constexpr int kCopyBytes = 16;
  // The PTX predicate is built from a 32-bit register, so widen the bool.
  const int pred_flag = static_cast<int>(pred);
  // cp.async expects a shared-state-space address, not a generic pointer.
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n"
      :
      : "r"(pred_flag), "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
}
// Unpredicated asynchronous 16-byte copy from global to shared memory
// (cp.async.cg). This only issues the copy; completion is awaited separately
// via cp_async_fence() / cp_async_wait<n>().
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  // Transfer size in bytes for a single cp.async instruction.
  constexpr int kCopyBytes = 16;
  // cp.async expects a shared-state-space address, not a generic pointer.
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n"
      :
      : "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
}
// Async copy fence: issues cp.async.commit_group, closing the batch of
// previously issued cp.async copies into a group that cp_async_wait<n>() can
// later wait on.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" : :);
}
// Wait until at most `n` committed async-copy groups (stages) are still
// pending, via cp.async.wait_group. With n == 0 all committed groups of the
// calling thread have completed.
//
// The "memory" clobber makes this a compiler-level memory barrier. The shared
// memory written by cp.async is invisible to the compiler, so without the
// clobber it may legally hoist or cache ordinary shared-memory reads across
// the wait. This only waits on the calling thread's own copies; data copied
// by other threads additionally needs a block barrier (e.g. __syncthreads()).
template <int n>
__device__ inline void cp_async_wait() {
  static_assert(n >= 0, "cp_async_wait: pending group count must be >= 0");
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n) : "memory");
}
// Wait until the global counter at `lock` equals `count`, then let the whole
// threadblock proceed.
//
// Only the thread with threadIdx.x == 0 spins; every other thread waits at the
// trailing __syncthreads(). The poll is a GPU-scope acquire load, intended so
// that reads issued after this barrier observe the writes another block made
// before releasing the counter (see barrier_release).
// NOTE(review): only threadIdx.x is tested, so in a 2D/3D block every thread
// with x == 0 polls — presumably blocks are 1D; confirm against the launcher.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int observed;
    for (;;) {
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(observed)
                   : "l"(lock));
      if (observed == count) break;
    }
  }
  __syncthreads();
}
// Release the barrier held by the current threadblock.
//
// After a block-wide __syncthreads(), the thread with threadIdx.x == 0 either
//  * reset == false: increments the visitation counter at `lock` by one. A
//    GPU-scope acq_rel fence first makes the block's earlier writes visible
//    before the increment can be observed by a block spinning in
//    barrier_acquire; or
//  * reset == true: sets the counter back to 0.
//
// The reset is a GPU-scope release store rather than a plain `lock[0] = 0`.
// Other blocks poll this counter with ld.global.acquire.gpu, and a plain weak
// store is neither a morally-strong access for those polls nor ordered after
// this block's earlier writes. The increment path already gets that ordering
// from its fence.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      const int zero = 0;
      asm volatile("st.release.gpu.global.b32 [%0], %1;\n"
                   :
                   : "l"(lock), "r"(zero));
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}
|