dtype_complex64.cuh 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. #pragma once
  2. #include "attention_generic.cuh"
  3. #include <cuComplex.h>
  4. namespace aphrodite {
  5. // Define custom complex64 vector data types.
  6. struct Complex4_ {
  7. cuFloatComplex x;
  8. cuFloatComplex y;
  9. };
  10. struct Complex8_ {
  11. cuFloatComplex x;
  12. cuFloatComplex y;
  13. cuFloatComplex z;
  14. cuFloatComplex w;
  15. };
  16. // Complex64 vector types for Q, K, V.
  17. template<>
  18. struct Vec<cuFloatComplex, 1> {
  19. using Type = cuFloatComplex;
  20. };
  21. template<>
  22. struct Vec<cuFloatComplex, 2> {
  23. using Type = Complex4_;
  24. };
  25. template<>
  26. struct Vec<cuFloatComplex, 4> {
  27. using Type = Complex8_;
  28. };
  29. // Complex64 accumulator vector types corresponding to Vec.
  30. template<>
  31. struct FloatVec<cuFloatComplex> {
  32. using Type = cuFloatComplex;
  33. };
  34. template<>
  35. struct FloatVec<Complex4_> {
  36. using Type = Complex4_;
  37. };
  38. template<>
  39. struct FloatVec<Complex8_> {
  40. using Type = Complex8_;
  41. };
  42. // Vector addition.
  43. inline __device__ cuFloatComplex add(cuFloatComplex a, cuFloatComplex b) {
  44. return cuCaddf(a, b);
  45. }
  46. inline __device__ Complex4_ add(Complex4_ a, Complex4_ b) {
  47. Complex4_ c;
  48. c.x = cuCaddf(a.x, b.x);
  49. c.y = cuCaddf(a.y, b.y);
  50. return c;
  51. }
  52. inline __device__ Complex8_ add(Complex8_ a, Complex8_ b) {
  53. Complex8_ c;
  54. c.x = cuCaddf(a.x, b.x);
  55. c.y = cuCaddf(a.y, b.y);
  56. c.z = cuCaddf(a.z, b.z);
  57. c.w = cuCaddf(a.w, b.w);
  58. return c;
  59. }
  60. // Vector multiplication.
  61. template<>
  62. inline __device__ cuFloatComplex mul(cuFloatComplex a, cuFloatComplex b) {
  63. return cuCmulf(a, b);
  64. }
  65. template<>
  66. inline __device__ Complex4_ mul(Complex4_ a, Complex4_ b) {
  67. Complex4_ c;
  68. c.x = cuCmulf(a.x, b.x);
  69. c.y = cuCmulf(a.y, b.y);
  70. return c;
  71. }
  72. template<>
  73. inline __device__ Complex4_ mul(cuFloatComplex a, Complex4_ b) {
  74. Complex4_ c;
  75. c.x = cuCmulf(a, b.x);
  76. c.y = cuCmulf(a, b.y);
  77. return c;
  78. }
  79. template<>
  80. inline __device__ Complex8_ mul(Complex8_ a, Complex8_ b) {
  81. Complex8_ c;
  82. c.x = cuCmulf(a.x, b.x);
  83. c.y = cuCmulf(a.y, b.y);
  84. c.z = cuCmulf(a.z, b.z);
  85. c.w = cuCmulf(a.w, b.w);
  86. return c;
  87. }
  88. template<>
  89. inline __device__ Complex8_ mul(cuFloatComplex a, Complex8_ b) {
  90. Complex8_ c;
  91. c.x = cuCmulf(a, b.x);
  92. c.y = cuCmulf(a, b.y);
  93. c.z = cuCmulf(a, b.z);
  94. c.w = cuCmulf(a, b.w);
  95. return c;
  96. }
  97. // Vector fused multiply-add.
  98. inline __device__ cuFloatComplex fma(cuFloatComplex a, cuFloatComplex b, cuFloatComplex c) {
  99. return cuCfmaf(a, b, c);
  100. }
  101. inline __device__ Complex4_ fma(Complex4_ a, Complex4_ b, Complex4_ c) {
  102. Complex4_ d;
  103. d.x = cuCfmaf(a.x, b.x, c.x);
  104. d.y = cuCfmaf(a.y, b.y, c.y);
  105. return d;
  106. }
  107. inline __device__ Complex4_ fma(cuFloatComplex a, Complex4_ b, Complex4_ c) {
  108. Complex4_ d;
  109. d.x = cuCfmaf(a, b.x, c.x);
  110. d.y = cuCfmaf(a, b.y, c.y);
  111. return d;
  112. }
  113. inline __device__ Complex8_ fma(Complex8_ a, Complex8_ b, Complex8_ c) {
  114. Complex8_ d;
  115. d.x = cuCfmaf(a.x, b.x, c.x);
  116. d.y = cuCfmaf(a.y, b.y, c.y);
  117. d.z = cuCfmaf(a.z, b.z, c.z);
  118. d.w = cuCfmaf(a.w, b.w, c.w);
  119. return d;
  120. }
  121. inline __device__ Complex8_ fma(cuFloatComplex a, Complex8_ b, Complex8_ c) {
  122. Complex8_ d;
  123. d.x = cuCfmaf(a, b.x, c.x);
  124. d.y = cuCfmaf(a, b.y, c.y);
  125. d.z = cuCfmaf(a, b.z, c.z);
  126. d.w = cuCfmaf(a, b.w, c.w);
  127. return d;
  128. }
  129. template<>
  130. inline __device__ cuFloatComplex sum(cuFloatComplex v) {
  131. return v;
  132. }
  133. template<>
  134. inline __device__ Complex4_ sum(Complex4_ v) {
  135. Complex4_ acc;
  136. acc.x = cuCaddf(v.x, v.y);
  137. acc.y = make_cuFloatComplex(0.f, 0.f);
  138. return acc;
  139. }
  140. template<>
  141. inline __device__ Complex8_ sum(Complex8_ v) {
  142. Complex4_ acc1;
  143. Complex4_ acc2;
  144. acc1.x = cuCaddf(v.x, v.y);
  145. acc1.y = cuCaddf(v.z, v.w);
  146. acc2.x = make_cuFloatComplex(0.f, 0.f);
  147. acc2.y = make_cuFloatComplex(0.f, 0.f);
  148. return add(acc1, acc2);
  149. }
  150. inline __device__ cuFloatComplex dot(cuFloatComplex a, cuFloatComplex b) {
  151. return cuCmulf(a, b);
  152. }
  153. inline __device__ Complex4_ dot(Complex4_ a, Complex4_ b) {
  154. Complex4_ c;
  155. c.x = cuCmulf(a.x, b.x);
  156. c.y = cuCmulf(a.y, b.y);
  157. return c;
  158. }
  159. inline __device__ Complex8_ dot(Complex8_ a, Complex8_ b) {
  160. Complex8_ c;
  161. c.x = cuCmulf(a.x, b.x);
  162. c.y = cuCmulf(a.y, b.y);
  163. c.z = cuCmulf(a.z, b.z);
  164. c.w = cuCmulf(a.w, b.w);
  165. return c;
  166. }
  167. inline __device__ void from_float(cuFloatComplex& dst, cuFloatComplex src) {
  168. dst = src;
  169. }
  170. inline __device__ void from_float(Complex4_& dst, Complex4_ src) {
  171. dst = src;
  172. }
  173. inline __device__ void from_float(Complex8_& dst, Complex8_ src) {
  174. dst = src;
  175. }
  176. inline __device__ cuFloatComplex to_float(cuFloatComplex u) {
  177. return u;
  178. }
  179. inline __device__ Complex4_ to_float(Complex4_ u) {
  180. return u;
  181. }
  182. inline __device__ Complex8_ to_float(Complex8_ u) {
  183. return u;
  184. }
  185. } // namespace aphrodite