dropout.h
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "philox.cuh"
  6. #include "utils.h"
  7. namespace flash {
// Per-thread dropout state for the attention kernel.
//
// Random bits come from the counter-based Philox generator (philox.cuh): the
// mask for an element is a pure function of (seed, offset, global row, global
// col), so the same mask can be regenerated from the same seed/offset —
// NOTE(review): presumably this is how the backward pass reproduces the
// forward mask; confirm against the callers.
struct Dropout {

    // Philox seed (shared by all threads) and a per-thread offset, made unique
    // per (batch, head, warp lane) in the constructor.
    const unsigned long long seed, offset;
    // Keep-threshold quantized to 8 bits: a random byte r keeps its element
    // iff r <= p_dropout_in_uint8_t (see the comparisons below).
    const uint8_t p_dropout_in_uint8_t;

    // Parameters:
    //   seed, offset           - base Philox seed/offset for the whole launch
    //   p_dropout_in_uint8_t   - keep probability quantized into [0, 255]
    //   bid, hid               - batch index and head index
    //   tid                    - thread index within the block
    //   nheads                 - number of attention heads
    // The stored offset advances by (bid * nheads + hid) * 32 + lane, where
    // lane = tid % 32, giving every (batch, head, lane) its own Philox
    // subsequence.
    __forceinline__ __device__ Dropout(const unsigned long long seed,
                                       const unsigned long long offset,
                                       const uint8_t p_dropout_in_uint8_t,
                                       const int bid, const int hid,
                                       const int tid, const int nheads)
        : seed(seed),
          offset(offset + (bid * nheads + hid) * 32 + tid % 32),
          p_dropout_in_uint8_t(p_dropout_in_uint8_t) {}

    // Apply dropout in-place to a per-thread MMA accumulator fragment.
    //
    //   tensor_           - accumulator fragment with layout (4, MMA_M, MMA_N)
    //   block_row_start   - global row coordinate of this thread's fragment
    //   block_col_start   - global column coordinate of this thread's fragment
    //   block_row_stride  - row advance between successive m-tiles
    //
    // If encode_dropout_in_sign_bit is true, dropped elements are negated
    // instead of zeroed (the sign bit carries the mask); otherwise they are
    // set to T(0).
    template <bool encode_dropout_in_sign_bit = false, typename Engine,
              typename Layout>
    __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout>& tensor_,
                                                  int block_row_start,
                                                  int block_col_start,
                                                  int block_row_stride) {
        // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2):
        // one Philox call yields 16 random bytes, exactly enough for the 8
        // elements of two adjacent column tiles (the j = 0, 1 loops below).
        Tensor tensor = make_tensor(
            tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
        using T = typename Engine::value_type;
        // keep -> val unchanged; drop -> -val (sign-bit encoding) or 0.
        auto encode_dropout = [](bool keep, T val) {
            return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
        };
        // The n-loop consumes columns two at a time, so MMA_N must be even.
        static_assert(decltype(size<2>(tensor))::value % 2 == 0);
        // Duplicate the 8-bit threshold into both 16-bit halves of a 32-bit
        // word, for the f16x2 packed comparison in the fast path below.
        const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
        const uint32_t p_dropout_8bit_in_uint32_t =
            (uint32_t(p_dropout_8bit_in_uint16_t) << 16) |
            uint32_t(p_dropout_8bit_in_uint16_t);
        // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
        #pragma unroll
        for (int m = 0; m < size<1>(tensor);
             ++m, block_row_start += block_row_stride) {
            // Pack (row, col) into a uint2 that is reinterpreted as the 64-bit
            // Philox counter, so the mask depends only on the global position.
            uint2 rowcol = make_uint2(block_row_start, block_col_start);
            #pragma unroll
            for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
                // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
                uint4 random_uint4 = flash::philox(
                    seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
                // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
                // View the 128 random bits as 16 independent random bytes.
                uint8_t(&rnd_8)[16] = reinterpret_cast<uint8_t(&)[16]>(random_uint4);
                // Special implementation for 16-bit types: we duplicate the threshold
                // to the low and high 16 bits of a 32-bit value, then use the f16x2
                // comparison instruction to get a mask. The low 16 bits of the mask
                // will be either 0xffff or 0x0000, and the high 16 bits will be either
                // 0xffff or 0x0000, depending on whether the random value is less than
                // the threshold. We then do a bit-wise AND between the mask and the
                // original value (in 32-bit). We're exploiting the fact that floating
                // point comparison is equivalent to integer comparison, since we're
                // comparing unsigned integers whose top 8-bits are zero.
                if (!encode_dropout_in_sign_bit &&
                    (std::is_same<T, cutlass::half_t>::value ||
                     std::is_same<T, cutlass::bfloat16_t>::value)) {
                    // Zero-extend each random byte into a 16-bit lane, so both
                    // operands of the f16x2 compare have zero top bits.
                    uint16_t rnd_16[16];
                    #pragma unroll
                    for (int i = 0; i < 16; i++) {
                        rnd_16[i] = uint16_t(rnd_8[i]);
                    }
                    // Reinterpret as 8 packed words: two 16-bit lanes per word.
                    uint32_t(&rnd_32)[8] = reinterpret_cast<uint32_t(&)[8]>(rnd_16);
                    #pragma unroll
                    for (int j = 0; j < 2; j++) {
                        // Two 16-bit elements per 32-bit word.
                        Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                        // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                        #pragma unroll
                        for (int i = 0; i < 4; i++) {
                            // mask lanes are 0xffff where rnd <= threshold
                            // (keep), 0x0000 where dropped; AND zeroes dropped
                            // elements in place.
                            uint32_t mask;
                            asm volatile("set.le.u32.f16x2 %0, %1, %2;\n"
                                         : "=r"(mask)
                                         : "r"(rnd_32[j * 4 + i]),
                                           "r"(p_dropout_8bit_in_uint32_t));
                            tensor_uint32(i) &= mask;
                        }
                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                    }
                } else {
                    // Generic path: scalar compare/encode per element.
                    #pragma unroll
                    for (int j = 0; j < 2; j++) {
                        #pragma unroll
                        for (int i = 0; i < 8; i++) {
                            tensor(i, m, n * 2 + j) =
                                encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t,
                                               tensor(i, m, n * 2 + j));
                        }
                        // Only used by the commented-out debug print below.
                        Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                    }
                }
                // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
                // //     printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
                // // }
            }
        }
    }
};
  114. } // namespace flash