flash_common.hpp 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
  4. #pragma once
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
  15. #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
  16. #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
  17. #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  18. namespace flash {
  19. inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* rng_state)
  20. {
  21. // Imitate from PyTorch
  22. // https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
  23. if (arg.captured_) {
  24. rng_state[0] = static_cast<uint64_t>(*arg.seed_.ptr);
  25. rng_state[1] = static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_);
  26. } else {
  27. rng_state[0] = arg.seed_.val;
  28. rng_state[1] = arg.offset_.val;
  29. }
  30. }
  31. inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
  32. // If we have enough to almost fill the SMs, then just use 1 split
  33. if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
  34. max_splits = std::min({max_splits, num_SMs, num_n_blocks});
  35. float max_efficiency = 0.f;
  36. std::vector<float> efficiency;
  37. efficiency.reserve(max_splits);
  38. auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
  39. // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
  40. // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
  41. // (i.e. it's 11 splits anyway).
  42. // So we check if the number of blocks per split is the same as the previous num_splits.
  43. auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
  44. return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
  45. };
  46. for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
  47. if (!is_split_eligible(num_splits)) {
  48. efficiency.push_back(0.f);
  49. } else {
  50. float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
  51. float eff = n_waves / ceil(n_waves);
  52. // printf("num_splits = %d, eff = %f\n", num_splits, eff);
  53. if (eff > max_efficiency) { max_efficiency = eff; }
  54. efficiency.push_back(eff);
  55. }
  56. }
  57. for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
  58. if (!is_split_eligible(num_splits)) { continue; }
  59. if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
  60. // printf("num_splits chosen = %d\n", num_splits);
  61. return num_splits;
  62. }
  63. }
  64. return 1;
  65. }
  66. int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);
  67. } // namespace flash