// philox.cuh
  // PyTorch also has an implementation of the Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
  #pragma once
  // Philox counter-based random number generator for CUDA device code.
  namespace flash {

  // Pair of unsigned long long values.
  // Not referenced by the functions in this namespace block; retained because
  // code outside this file may still use the type.
  struct ull2 {
      unsigned long long x;
      unsigned long long y;
  };

  // Full-width unsigned multiply of two 32-bit operands.
  //
  // Returns the 64-bit product a * b split into 32-bit halves:
  //   .x = low 32 bits, .y = high 32 bits.
  __forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
      unsigned long long product;
      asm ("mul.wide.u32 %0, %1, %2;\n\t"
           : "=l"(product)
           : "r"(a), "r"(b));
      // Split with shifts and casts instead of reading the 64-bit value through a
      // uint2*: the pointer cast is a strict-aliasing violation and bakes in the
      // byte order, while this form yields the same words explicitly.
      uint2 halves = {static_cast<unsigned int>(product),
                      static_cast<unsigned int>(product >> 32)};
      return halves;
  }

  // One Philox round on a 4x32-bit counter with a 2x32-bit key.
  //
  // ctr: current counter state (x, y, z, w).
  // key: round key (x, y).
  // Returns the counter state after the round.
  __forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
      // The multipliers are 32-bit values, so give them a 32-bit type. The previous
      // `unsigned long` changes width between platforms and was implicitly narrowed
      // when passed to mulhilo32.
      constexpr unsigned int kPhiloxSA = 0xD2511F53u;
      constexpr unsigned int kPhiloxSB = 0xCD9E8D57u;
      const uint2 prod0 = mulhilo32(kPhiloxSA, ctr.x);
      const uint2 prod1 = mulhilo32(kPhiloxSB, ctr.z);
      // High halves are mixed with the untouched counter words and the key;
      // low halves are passed through into the y and w slots.
      uint4 next = {prod1.y ^ ctr.y ^ key.x, prod1.x, prod0.y ^ ctr.w ^ key.y, prod0.x};
      return next;
  }

  // Generates four 32-bit pseudo-random words from (seed, subsequence, offset).
  //
  // seed:        64-bit value used as the initial key (low word -> key.x, high word -> key.y).
  // subsequence: 64-bit value placed in counter words z (low) and w (high).
  // offset:      64-bit value placed in counter words x (low) and y (high).
  //
  // Applies seven rounds in total: six inside the loop, each followed by a key
  // increment, then one final round with the key as left by the sixth increment.
  // NOTE(review): the constants are named kPhilox10*, but only seven rounds are
  // executed here. Changing the round count would change every generated value,
  // so it is deliberately left untouched.
  __forceinline__ __device__ uint4 philox(unsigned long long seed,
                                          unsigned long long subsequence,
                                          unsigned long long offset) {
      // Per-round key increments, typed as 32-bit to match the key words they are added to.
      constexpr unsigned int kPhilox10A = 0x9E3779B9u;
      constexpr unsigned int kPhilox10B = 0xBB67AE85u;
      // Build key and counter from explicit 32-bit halves rather than reinterpreting
      // the 64-bit inputs through uint2& / ull2* (strict-aliasing violation). The
      // word order matches what the reinterpretation produced on a little-endian target.
      uint2 key = {static_cast<unsigned int>(seed),
                   static_cast<unsigned int>(seed >> 32)};
      uint4 counter = {static_cast<unsigned int>(offset),
                       static_cast<unsigned int>(offset >> 32),
                       static_cast<unsigned int>(subsequence),
                       static_cast<unsigned int>(subsequence >> 32)};
      #pragma unroll
      for (int i = 0; i < 6; i++) {
          counter = philox_single_round(counter, key);
          key.x += kPhilox10A;  // unsigned 32-bit add, wraps modulo 2^32
          key.y += kPhilox10B;
      }
      return philox_single_round(counter, key);
  }

  }  // namespace flash