Bläddra i källkod

fix: illegal mem access crash for marlin

AlpinDale 8 månader sedan
förälder
incheckning
1225c4dfd6
1 ändrade filer med 10 tillägg och 12 borttagningar
  1. 10 12
      kernels/quantization/marlin/marlin_cuda_kernel.cu

+ 10 - 12
kernels/quantization/marlin/marlin_cuda_kernel.cu

@@ -71,17 +71,15 @@
  // may be evicted immediately; used for quantized weights B, which are only
  // accessed precisely once and should thus not pollute the L2 cache which we
  // need for inputs A and outputs C.
- __device__ inline void cp_async4_stream(void *smem_ptr, const void *glob_ptr) {
+
+ // Async global->shared copy
+ __device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
    const int BYTES = 16;
    uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-   asm volatile(
-       "{\n"
-       "   .reg .b64 p;\n"
-       "   createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
-       "   cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
-       "}\n" ::"r"(smem),
-       "l"(glob_ptr), "n"(BYTES));
- }
+      asm volatile("{\n"
+                   "   cp.async.cg.shared.global [%0], [%1], %2;\n"
+                   "}\n" ::"r"(smem), "l"(glob_ptr), "n"(BYTES));
+  }
  
  // Async copy fence.
  __device__ inline void cp_async_fence() {
@@ -448,14 +446,14 @@
        int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
  #pragma unroll
        for (int i = 0; i < b_sh_wr_iters; i++) {
-         cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
+         cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
          B_ptr[i] += b_gl_rd_delta_o;
        }
        // Only fetch scales if this tile starts a new group
        if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
          int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
          if (s_sh_wr_pred)
-           cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+           cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
          s_gl_rd += s_gl_rd_delta;
        }
      }
@@ -750,7 +748,7 @@
        // write-out
        if (group_blocks == -1 && last) {
          if (s_sh_wr_pred)
-           cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
+           cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
          cp_async_fence();
        }
        thread_block_reduce();