/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

namespace flash {

// Copied from PyTorch
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
    if (arg.captured_) {
        // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
        // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
        // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
        return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr),
                               static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
    } else {
        return std::make_tuple(arg.seed_.val, arg.offset_.val);
    }
}

}  // namespace flash
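
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original header): one way a
// host-side launcher could combine the CHECK_* macros above with
// flash::unpack() to draw a (seed, offset) pair for a dropout kernel.
// The function name `example_prepare_rng` and its parameters are
// hypothetical. The generator calls assume a recent PyTorch where
// at::get_generator_or_default takes a std::optional<at::Generator>, and the
// host-side call to unpack() assumes no CUDA graph capture is in progress
// (the captured_ branch reads through device pointers and is intended to run
// inside the kernel).
// ---------------------------------------------------------------------------
#include <mutex>
#include <tuple>

namespace flash {

inline std::tuple<uint64_t, uint64_t> example_prepare_rng(
        const at::Tensor &q,
        const int64_t batch_size, const int64_t seqlen_q,
        const int64_t num_heads, const int64_t head_size,
        const uint64_t counter_offset) {
    // Validate the input tensor with the macros defined above.
    CHECK_DEVICE(q);
    CHECK_CONTIGUOUS(q);
    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);

    // Grab the default CUDA generator and reserve `counter_offset` Philox
    // counter increments for this launch; philox_cuda_state() must be called
    // while holding the generator's mutex.
    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
    at::PhiloxCudaState philox_state;
    {
        std::lock_guard<std::mutex> lock(gen->mutex_);
        philox_state = gen->philox_cuda_state(counter_offset);
    }

    // Outside CUDA graph capture this simply returns the recorded
    // (seed, offset) values; under capture, unpack() dereferences the
    // device-side pointers and adds the intragraph offset.
    return unpack(philox_state);
}

}  // namespace flash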