Browse Source

Fix random state for dropout_layer_norm (#315)

Joel Lamy-Poirier 1 year ago
parent
commit
767b71ccf0
1 changed file with 10 additions and 10 deletions
  1. 10 10
      csrc/layer_norm/ln_api.cpp

+ 10 - 10
csrc/layer_norm/ln_api.cpp

@@ -229,11 +229,6 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     // Request the kernel launcher.
     auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
 
-    // Query the kernel-specific launch parameters.
-    launcher(launch_params, true);
-
-    at::Tensor workspace, barrier;
-
     // Set the kernel runtime parameters.
     layer_norm::FwdParams &params = launch_params.params;
     params.rows = rows;
@@ -252,6 +247,11 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     params.rowscale_const = rowscale_const;
     params.is_rms_norm = is_rms_norm;
 
+    // Query the kernel-specific launch parameters.
+    launcher(launch_params, true);
+
+    at::Tensor workspace, barrier;
+
     if (dropout_p > 0.f) {
         // number of times random will be generated per thread, to offset philox counter in thc random
         // state
@@ -594,11 +594,6 @@ std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
     // Request the kernel launcher.
     auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
 
-    // Query the kernel-specific launch parameters.
-    launcher(launch_params, true);
-
-    at::Tensor workspace, barrier;
-
     // Set the kernel runtime parameters.
     layer_norm::FwdParams &params = launch_params.params;
     params.rows = rows;
@@ -621,6 +616,11 @@ std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
     params.inverse_cols = 1.f / float(params.cols);
     params.is_rms_norm = is_rms_norm;
 
+    // Query the kernel-specific launch parameters.
+    launcher(launch_params, true);
+
+    at::Tensor workspace, barrier;
+
     if (dropout_p > 0.f) {
         // number of times random will be generated per thread, to offset philox counter in thc random
         // state