flash_common.hpp

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
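
// Illustrative usage of the checks above (a sketch only; the tensor name `q`
// and the dimension names are placeholders, not part of this header):
//   CHECK_DEVICE(q); CHECK_CONTIGUOUS(q);
//   CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);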

namespace flash {

// Copy from PyTorch
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
    if (arg.captured_) {
        // static_cast avoids "warning: invalid narrowing conversion from 'long' to 'unsigned long'".
        // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
        // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
        return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
    } else {
        return std::make_tuple(arg.seed_.val, arg.offset_.val);
    }
}

} // namespace flash
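
// Usage sketch (illustrative, not part of the original header; `gen_` and
// `counter_offset` are assumed host-side values): the host code typically
// obtains an at::PhiloxCudaState from the CUDA generator and passes it to the
// kernel, where flash::unpack recovers the (seed, offset) pair.
//
//   auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
//       gen_, at::cuda::detail::getDefaultCUDAGenerator());
//   at::PhiloxCudaState philox_state;
//   {
//     // See Note [Acquire lock when using random generators]
//     std::lock_guard<std::mutex> lock(gen->mutex_);
//     philox_state = gen->philox_cuda_state(counter_offset);
//   }
//   // Later, on device or in a host-side test:
//   auto [seed, offset] = flash::unpack(philox_state);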