// selective_scan_common.h
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <type_traits>  // For std::conditional_t / std::is_same_v

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <c10/util/complex.h>  // For scalar_value_type
// at::Half and at::BFloat16 are expected to come from the ATen headers
// included by the translation units that use this file.

#define MAX_DSTATE 256

using complex_t = c10::complex<float>;
// Elementwise addition for CUDA packed float vectors.
inline __device__ float2 operator+(const float2 &a, const float2 &b) {
    return {a.x + b.x, a.y + b.y};
}

inline __device__ float3 operator+(const float3 &a, const float3 &b) {
    return {a.x + b.x, a.y + b.y, a.z + b.z};
}

inline __device__ float4 operator+(const float4 &a, const float4 &b) {
    return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}
////////////////////////////////////////////////////////////////////////////////////////////////////

// Maps a byte count to an unsigned integer type of exactly that size; used to pick the
// element type for vectorized global-memory accesses.
template<int BYTES> struct BytesToType {};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};
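
// Illustrative usage sketch (not from this header): kernel traits typically pick the widest
// access type that covers one thread's tile of elements, so that global loads and stores
// become single wide transactions. input_t and kNElts below are hypothetical placeholders
// for whatever the kernel traits define.
//
//   // e.g. 8 half elements * 2 bytes = 16 bytes -> uint4 (one 128-bit access)
//   using vec_t = typename BytesToType<sizeof(input_t) * kNElts>::Type;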
////////////////////////////////////////////////////////////////////////////////////////////////////

// Converts a per-thread register array from the storage type to float.
// The generic version converts one element at a time.
template<typename scalar_t, int N>
struct Converter {
    static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
        #pragma unroll
        for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
    }
};

// fp16 specialization: converts two elements at a time with __half22float2.
template<int N>
struct Converter<at::Half, N> {
    static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
        #pragma unroll
        for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
    }
};

// bf16 specialization, compiled only for sm_80 and newer.
#if __CUDA_ARCH__ >= 800
template<int N>
struct Converter<at::BFloat16, N> {
    static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
        #pragma unroll
        for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
    }
};
#endif
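
// Illustrative usage sketch (not from this header): converting one thread's register tile
// from the storage type to float before computing in fp32. input_t and kNItems are
// hypothetical placeholders.
//
//   input_t x_vals[kNItems];     // loaded from global memory
//   float   x_vals_f[kNItems];   // fp32 working copy
//   Converter<input_t, kNItems>::to_float(x_vals, x_vals_f);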
////////////////////////////////////////////////////////////////////////////////////////////////////

// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
__device__ __forceinline__ complex_t cexp2f(complex_t z) {
    float t = exp2f(z.real_);
    float c, s;
    sincosf(z.imag_, &s, &c);
    return complex_t(c * t, s * t);
}

__device__ __forceinline__ complex_t cexpf(complex_t z) {
    float t = expf(z.real_);
    float c, s;
    sincosf(z.imag_, &s, &c);
    return complex_t(c * t, s * t);
}
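
// Both helpers follow Euler's formula: for z = x + i*y,
//     exp(z) = e^x * (cos y + i sin y),
// computed as t = e^x (or t = 2^x in cexp2f) and (c, s) = (cos y, sin y).
// Note that cexp2f only treats the real part as a base-2 exponent; the imaginary part is
// still interpreted as an angle in radians.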
template<typename scalar_t> struct SSMScanOp;

template<>
struct SSMScanOp<float> {
    __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
        return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
    }
};

template<>
struct SSMScanOp<complex_t> {
    __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
        complex_t a0 = complex_t(ab0.x, ab0.y);
        complex_t b0 = complex_t(ab0.z, ab0.w);
        complex_t a1 = complex_t(ab1.x, ab1.y);
        complex_t b1 = complex_t(ab1.z, ab1.w);
        complex_t out_a = a1 * a0;
        complex_t out_b = a1 * b0 + b1;
        return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
    }
};
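
// The scan operator composes two steps of the first-order recurrence
//     h_t = a_t * h_{t-1} + b_t.
// Encoding a step as the pair (a, b), applying (a0, b0) and then (a1, b1) to a state h gives
//     a1 * (a0 * h + b0) + b1 = (a1 * a0) * h + (a1 * b0 + b1),
// i.e. the composite pair (a1 * a0, a1 * b0 + b1) returned above. Because this composition is
// associative, the recurrence can be evaluated with a parallel (block-wide) scan rather than
// a sequential loop over the sequence.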
// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
    using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
    scan_t running_prefix;
    // Constructor
    __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
    // Callback operator to be entered by the first warp of threads in the block.
    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    __device__ scan_t operator()(scan_t block_aggregate) {
        scan_t old_prefix = running_prefix;
        running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
        return old_prefix;
    }
};
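
// Illustrative sketch (not from this header) of how the callback is typically combined with a
// CUB block scan when a long sequence is processed chunk by chunk: thread 0 of the block hands
// the carried prefix to each chunk's scan and accumulates the chunk aggregate into it.
// kNThreads, n_chunks, and thread_data are hypothetical placeholders.
//
//   using BlockScanT = cub::BlockScan<float2, kNThreads>;
//   __shared__ typename BlockScanT::TempStorage temp_storage;
//   SSMScanPrefixCallbackOp<float> prefix_op(make_float2(1.f, 0.f));  // identity: a = 1, b = 0
//   for (int chunk = 0; chunk < n_chunks; ++chunk) {
//       float2 thread_data = /* (a_t, b_t) for this thread's element */;
//       BlockScanT(temp_storage).InclusiveScan(thread_data, thread_data,
//                                              SSMScanOp<float>(), prefix_op);
//       __syncthreads();
//   }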
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Ktraits>
inline __device__ void load_input(typename Ktraits::input_t *u,
                                  typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
                                  typename Ktraits::BlockLoadT::TempStorage &smem_load,
                                  int seqlen) {
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockLoadVecT(smem_load_vec).Load(
            reinterpret_cast<vec_t*>(u),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
        );
    } else {
        Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
    }
}
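
// load_input assumes Ktraits provides: input_t, kNItems (elements per thread), kNLoads and
// vec_t (for the vectorized path), kIsEvenLen, and the CUB BlockLoad types BlockLoadT /
// BlockLoadVecT sharing the same shared-memory TempStorage. When Ktraits::kIsEvenLen is set,
// the unguarded vectorized load is used; otherwise the guarded BlockLoad path pads
// out-of-range elements with 0.f.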
template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
                                   typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
                                   typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
                                   int seqlen) {
    constexpr int kNItems = Ktraits::kNItems;
    if constexpr (!Ktraits::kIsComplex) {
        typename Ktraits::input_t B_vals_load[kNItems];
        if constexpr (Ktraits::kIsEvenLen) {
            auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
            using vec_t = typename Ktraits::vec_t;
            Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
                reinterpret_cast<vec_t*>(Bvar),
                reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
            );
        } else {
            Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
        }
        // #pragma unroll
        // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
        Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
    } else {
        typename Ktraits::input_t B_vals_load[kNItems * 2];
        if constexpr (Ktraits::kIsEvenLen) {
            auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
            using vec_t = typename Ktraits::vec_t;
            Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
                reinterpret_cast<vec_t*>(Bvar),
                reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
            );
        } else {
            Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
        }
        #pragma unroll
        for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
    }
}
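
// In the real-weight case the loaded values are widened to float through Converter; in the
// complex case Bvar is read as interleaved (real, imag) pairs, so each thread loads
// 2 * kNItems scalars and reassembles them into complex_t values.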
template<typename Ktraits>
inline __device__ void store_output(typename Ktraits::input_t *out,
                                    const float (&out_vals)[Ktraits::kNItems],
                                    typename Ktraits::BlockStoreT::TempStorage &smem_store,
                                    int seqlen) {
    typename Ktraits::input_t write_vals[Ktraits::kNItems];
    #pragma unroll
    for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockStoreVecT(smem_store_vec).Store(
            reinterpret_cast<vec_t*>(out),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
        );
    } else {
        Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
    }
}
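
// Illustrative call order (hypothetical names, not from this header): the selective-scan
// kernels typically load a tile, compute in fp32 registers, then cast back to input_t on store:
//
//   load_input<Ktraits>(u, u_vals, smem_load, seqlen_remaining);
//   // ... compute float out_vals[Ktraits::kNItems] ...
//   store_output<Ktraits>(out, out_vals, smem_store, seqlen_remaining);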