// bgmv_config.h
  #pragma once
  #include <cstdint>
  /**
   * Prototype of the bgmv_kernel function template. Only the declaration is
   * visible in this header; no definition appears here.
   *
   * Template parameters:
   *   feat_in, feat_out  compile-time integer sizes.
   *   in_T, out_T, W_T   element types of X, Y and W respectively.
   *
   * Parameters:
   *   Y         output buffer (the only non-const pointer).
   *   X         read-only input buffer.
   *   W         read-only weight buffer.
   *   indicies  read-only array of 64-bit integers (identifier spelling kept
   *             exactly as declared).
   *   y_offset, full_y_size, batch_size, num_layers, layer_idx
   *             64-bit integers passed by value.
   *   scale     float passed by value.
   *
   * Every pointer is __restrict__-qualified, so the buffers passed in must
   * not overlap.
   *
   * NOTE(review): presumably the feat_in/feat_out pairs that get instantiated
   * are the ones enumerated by the FOR_* macro lists in this header, and
   * `indicies` selects a weight slice per batch row -- neither is shown in
   * this file; confirm against the kernel definition.
   */
  template <int feat_in, int feat_out, typename in_T, typename out_T,
            typename W_T>
  void bgmv_kernel(out_T* __restrict__ Y,
                   const in_T* __restrict__ X,
                   const W_T* __restrict__ W,
                   const int64_t* __restrict__ indicies,
                   int64_t y_offset,
                   int64_t full_y_size,
                   int64_t batch_size,
                   int64_t num_layers,
                   int64_t layer_idx,
                   float scale);
  // clang-format off
  // X-macro list: expands to one invocation
  //     f(in_T, out_T, W_T, narrow, <size>)
  // per size in the table below (87 sizes, ascending from 128 to 131072).
  // `f`, the three type arguments and `narrow` are forwarded untouched; only
  // the fifth argument varies. Invocations are emitted back to back with no
  // separator, so `f` must supply any terminator it needs.
  //
  // The final entry deliberately has no trailing backslash, so the macro ends
  // with the table and does not depend on the line that follows it.
  //
  // Keep this list in sync with aphrodite/lora/layers::SamplerWithLoRA
  #define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
      f(in_T, out_T, W_T, narrow, 128) \
      f(in_T, out_T, W_T, narrow, 256) \
      f(in_T, out_T, W_T, narrow, 512) \
      f(in_T, out_T, W_T, narrow, 640) \
      f(in_T, out_T, W_T, narrow, 768) \
      f(in_T, out_T, W_T, narrow, 896) \
      f(in_T, out_T, W_T, narrow, 1024) \
      f(in_T, out_T, W_T, narrow, 1152) \
      f(in_T, out_T, W_T, narrow, 1216) \
      f(in_T, out_T, W_T, narrow, 1280) \
      f(in_T, out_T, W_T, narrow, 1536) \
      f(in_T, out_T, W_T, narrow, 1664) \
      f(in_T, out_T, W_T, narrow, 1728) \
      f(in_T, out_T, W_T, narrow, 1792) \
      f(in_T, out_T, W_T, narrow, 2048) \
      f(in_T, out_T, W_T, narrow, 2240) \
      f(in_T, out_T, W_T, narrow, 2304) \
      f(in_T, out_T, W_T, narrow, 2368) \
      f(in_T, out_T, W_T, narrow, 2432) \
      f(in_T, out_T, W_T, narrow, 2560) \
      f(in_T, out_T, W_T, narrow, 2752) \
      f(in_T, out_T, W_T, narrow, 2816) \
      f(in_T, out_T, W_T, narrow, 3072) \
      f(in_T, out_T, W_T, narrow, 3328) \
      f(in_T, out_T, W_T, narrow, 3456) \
      f(in_T, out_T, W_T, narrow, 3584) \
      f(in_T, out_T, W_T, narrow, 3712) \
      f(in_T, out_T, W_T, narrow, 4096) \
      f(in_T, out_T, W_T, narrow, 4480) \
      f(in_T, out_T, W_T, narrow, 4608) \
      f(in_T, out_T, W_T, narrow, 4736) \
      f(in_T, out_T, W_T, narrow, 4864) \
      f(in_T, out_T, W_T, narrow, 5120) \
      f(in_T, out_T, W_T, narrow, 5504) \
      f(in_T, out_T, W_T, narrow, 5632) \
      f(in_T, out_T, W_T, narrow, 5888) \
      f(in_T, out_T, W_T, narrow, 6144) \
      f(in_T, out_T, W_T, narrow, 6400) \
      f(in_T, out_T, W_T, narrow, 6848) \
      f(in_T, out_T, W_T, narrow, 6912) \
      f(in_T, out_T, W_T, narrow, 7168) \
      f(in_T, out_T, W_T, narrow, 7424) \
      f(in_T, out_T, W_T, narrow, 8192) \
      f(in_T, out_T, W_T, narrow, 8960) \
      f(in_T, out_T, W_T, narrow, 9216) \
      f(in_T, out_T, W_T, narrow, 9472) \
      f(in_T, out_T, W_T, narrow, 10240) \
      f(in_T, out_T, W_T, narrow, 11008) \
      f(in_T, out_T, W_T, narrow, 11264) \
      f(in_T, out_T, W_T, narrow, 12288) \
      f(in_T, out_T, W_T, narrow, 13696) \
      f(in_T, out_T, W_T, narrow, 13824) \
      f(in_T, out_T, W_T, narrow, 14336) \
      f(in_T, out_T, W_T, narrow, 14784) \
      f(in_T, out_T, W_T, narrow, 14848) \
      f(in_T, out_T, W_T, narrow, 15360) \
      f(in_T, out_T, W_T, narrow, 16384) \
      f(in_T, out_T, W_T, narrow, 18944) \
      f(in_T, out_T, W_T, narrow, 20480) \
      f(in_T, out_T, W_T, narrow, 22016) \
      f(in_T, out_T, W_T, narrow, 22528) \
      f(in_T, out_T, W_T, narrow, 24576) \
      f(in_T, out_T, W_T, narrow, 27392) \
      f(in_T, out_T, W_T, narrow, 27648) \
      f(in_T, out_T, W_T, narrow, 28672) \
      f(in_T, out_T, W_T, narrow, 29568) \
      f(in_T, out_T, W_T, narrow, 29696) \
      f(in_T, out_T, W_T, narrow, 32000) \
      f(in_T, out_T, W_T, narrow, 32256) \
      f(in_T, out_T, W_T, narrow, 32512) \
      f(in_T, out_T, W_T, narrow, 32768) \
      f(in_T, out_T, W_T, narrow, 33024) \
      f(in_T, out_T, W_T, narrow, 36864) \
      f(in_T, out_T, W_T, narrow, 43264) \
      f(in_T, out_T, W_T, narrow, 49152) \
      f(in_T, out_T, W_T, narrow, 60544) \
      f(in_T, out_T, W_T, narrow, 60672) \
      f(in_T, out_T, W_T, narrow, 64000) \
      f(in_T, out_T, W_T, narrow, 64256) \
      f(in_T, out_T, W_T, narrow, 64512) \
      f(in_T, out_T, W_T, narrow, 102400) \
      f(in_T, out_T, W_T, narrow, 102656) \
      f(in_T, out_T, W_T, narrow, 102912) \
      f(in_T, out_T, W_T, narrow, 128000) \
      f(in_T, out_T, W_T, narrow, 128256) \
      f(in_T, out_T, W_T, narrow, 128512) \
      f(in_T, out_T, W_T, narrow, 131072)
  // X-macro list used to define kernels that go from each of the many input
  // dimensions to a narrow output dimension. It is used for the fully sharded
  // column-parallel LoRA A, which splits the rank dimension.
  //
  // Expands to one invocation
  //     f(in_T, out_T, W_T, <size>, narrow)
  // per size in the table below (87 sizes, ascending from 128 to 131072).
  // Here the size is the FOURTH argument and `narrow` the fifth -- the
  // opposite order from FOR_BGMV_WIDE. Invocations are emitted back to back
  // with no separator, so `f` must supply any terminator it needs.
  //
  // The final entry deliberately has no trailing backslash, so the macro ends
  // with the table and does not depend on the line that follows it.
  //
  // Keep this list in sync with aphrodite/lora/layers::SamplerWithLoRA
  #define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
      f(in_T, out_T, W_T, 128, narrow) \
      f(in_T, out_T, W_T, 256, narrow) \
      f(in_T, out_T, W_T, 512, narrow) \
      f(in_T, out_T, W_T, 640, narrow) \
      f(in_T, out_T, W_T, 768, narrow) \
      f(in_T, out_T, W_T, 896, narrow) \
      f(in_T, out_T, W_T, 1024, narrow) \
      f(in_T, out_T, W_T, 1152, narrow) \
      f(in_T, out_T, W_T, 1216, narrow) \
      f(in_T, out_T, W_T, 1280, narrow) \
      f(in_T, out_T, W_T, 1536, narrow) \
      f(in_T, out_T, W_T, 1664, narrow) \
      f(in_T, out_T, W_T, 1728, narrow) \
      f(in_T, out_T, W_T, 1792, narrow) \
      f(in_T, out_T, W_T, 2048, narrow) \
      f(in_T, out_T, W_T, 2240, narrow) \
      f(in_T, out_T, W_T, 2304, narrow) \
      f(in_T, out_T, W_T, 2368, narrow) \
      f(in_T, out_T, W_T, 2432, narrow) \
      f(in_T, out_T, W_T, 2560, narrow) \
      f(in_T, out_T, W_T, 2752, narrow) \
      f(in_T, out_T, W_T, 2816, narrow) \
      f(in_T, out_T, W_T, 3072, narrow) \
      f(in_T, out_T, W_T, 3328, narrow) \
      f(in_T, out_T, W_T, 3456, narrow) \
      f(in_T, out_T, W_T, 3584, narrow) \
      f(in_T, out_T, W_T, 3712, narrow) \
      f(in_T, out_T, W_T, 4096, narrow) \
      f(in_T, out_T, W_T, 4480, narrow) \
      f(in_T, out_T, W_T, 4608, narrow) \
      f(in_T, out_T, W_T, 4736, narrow) \
      f(in_T, out_T, W_T, 4864, narrow) \
      f(in_T, out_T, W_T, 5120, narrow) \
      f(in_T, out_T, W_T, 5504, narrow) \
      f(in_T, out_T, W_T, 5632, narrow) \
      f(in_T, out_T, W_T, 5888, narrow) \
      f(in_T, out_T, W_T, 6144, narrow) \
      f(in_T, out_T, W_T, 6400, narrow) \
      f(in_T, out_T, W_T, 6848, narrow) \
      f(in_T, out_T, W_T, 6912, narrow) \
      f(in_T, out_T, W_T, 7168, narrow) \
      f(in_T, out_T, W_T, 7424, narrow) \
      f(in_T, out_T, W_T, 8192, narrow) \
      f(in_T, out_T, W_T, 8960, narrow) \
      f(in_T, out_T, W_T, 9216, narrow) \
      f(in_T, out_T, W_T, 9472, narrow) \
      f(in_T, out_T, W_T, 10240, narrow) \
      f(in_T, out_T, W_T, 11008, narrow) \
      f(in_T, out_T, W_T, 11264, narrow) \
      f(in_T, out_T, W_T, 12288, narrow) \
      f(in_T, out_T, W_T, 13696, narrow) \
      f(in_T, out_T, W_T, 13824, narrow) \
      f(in_T, out_T, W_T, 14336, narrow) \
      f(in_T, out_T, W_T, 14784, narrow) \
      f(in_T, out_T, W_T, 14848, narrow) \
      f(in_T, out_T, W_T, 15360, narrow) \
      f(in_T, out_T, W_T, 16384, narrow) \
      f(in_T, out_T, W_T, 18944, narrow) \
      f(in_T, out_T, W_T, 20480, narrow) \
      f(in_T, out_T, W_T, 22016, narrow) \
      f(in_T, out_T, W_T, 22528, narrow) \
      f(in_T, out_T, W_T, 24576, narrow) \
      f(in_T, out_T, W_T, 27392, narrow) \
      f(in_T, out_T, W_T, 27648, narrow) \
      f(in_T, out_T, W_T, 28672, narrow) \
      f(in_T, out_T, W_T, 29568, narrow) \
      f(in_T, out_T, W_T, 29696, narrow) \
      f(in_T, out_T, W_T, 32000, narrow) \
      f(in_T, out_T, W_T, 32256, narrow) \
      f(in_T, out_T, W_T, 32512, narrow) \
      f(in_T, out_T, W_T, 32768, narrow) \
      f(in_T, out_T, W_T, 33024, narrow) \
      f(in_T, out_T, W_T, 36864, narrow) \
      f(in_T, out_T, W_T, 43264, narrow) \
      f(in_T, out_T, W_T, 49152, narrow) \
      f(in_T, out_T, W_T, 60544, narrow) \
      f(in_T, out_T, W_T, 60672, narrow) \
      f(in_T, out_T, W_T, 64000, narrow) \
      f(in_T, out_T, W_T, 64256, narrow) \
      f(in_T, out_T, W_T, 64512, narrow) \
      f(in_T, out_T, W_T, 102400, narrow) \
      f(in_T, out_T, W_T, 102656, narrow) \
      f(in_T, out_T, W_T, 102912, narrow) \
      f(in_T, out_T, W_T, 128000, narrow) \
      f(in_T, out_T, W_T, 128256, narrow) \
      f(in_T, out_T, W_T, 128512, narrow) \
      f(in_T, out_T, W_T, 131072, narrow)
  // Expands FOR_BGMV_WIDE four times, once for each narrow size 8, 16, 32 and
  // 64, forwarding `f` and the three type arguments unchanged. The result is
  // one f(in_T, out_T, W_T, narrow, <size>) invocation for every
  // (narrow, size) combination.
  //
  // Keep this in sync with aphrodite/common/config::LoRAConfig
  #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
      FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
      FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
      FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
      FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
  // Expands FOR_INST_BGMV_NARROW for the narrow sizes 1, 2 and 4, then adds
  // four explicit invocations of `f` whose last two arguments are (8, 64),
  // (16, 64), (32, 64) and (64, 64). `f` and the three type arguments are
  // forwarded unchanged throughout.
  //
  // NOTE(review): presumably the last two arguments are (feat_in, feat_out)
  // and the four explicit pairs cover shapes the size tables do not list --
  // confirm against the code that supplies `f`.
  #define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
      FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
      FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
      FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
      f(in_T, out_T, W_T, 8, 64) \
      f(in_T, out_T, W_T, 16, 64) \
      f(in_T, out_T, W_T, 32, 64) \
      f(in_T, out_T, W_T, 64, 64)
  // clang-format on