// bgmv_config.h

  1. #pragma once
  2. template <int feat_in, int feat_out, typename in_T, typename out_T,
  3. typename W_T>
  4. void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
  5. const W_T *__restrict__ W,
  6. const int64_t *__restrict__ indicies, int64_t y_offset,
  7. int64_t full_y_size, int64_t batch_size, int64_t num_layers,
  8. int64_t layer_idx, float scale);
  9. // clang-format off
  10. #define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
  11. f(in_T, out_T, W_T, narrow, 128) \
  12. f(in_T, out_T, W_T, narrow, 256) \
  13. f(in_T, out_T, W_T, narrow, 512) \
  14. f(in_T, out_T, W_T, narrow, 1024) \
  15. f(in_T, out_T, W_T, narrow, 1280) \
  16. f(in_T, out_T, W_T, narrow, 1728) \
  17. f(in_T, out_T, W_T, narrow, 1792) \
  18. f(in_T, out_T, W_T, narrow, 2048) \
  19. f(in_T, out_T, W_T, narrow, 2560) \
  20. f(in_T, out_T, W_T, narrow, 2752) \
  21. f(in_T, out_T, W_T, narrow, 3072) \
  22. f(in_T, out_T, W_T, narrow, 3456) \
  23. f(in_T, out_T, W_T, narrow, 3584) \
  24. f(in_T, out_T, W_T, narrow, 4096) \
  25. f(in_T, out_T, W_T, narrow, 5120) \
  26. f(in_T, out_T, W_T, narrow, 5504) \
  27. f(in_T, out_T, W_T, narrow, 5632) \
  28. f(in_T, out_T, W_T, narrow, 6912) \
  29. f(in_T, out_T, W_T, narrow, 7168) \
  30. f(in_T, out_T, W_T, narrow, 8192) \
  31. f(in_T, out_T, W_T, narrow, 9216) \
  32. f(in_T, out_T, W_T, narrow, 10240) \
  33. f(in_T, out_T, W_T, narrow, 11008) \
  34. f(in_T, out_T, W_T, narrow, 12288) \
  35. f(in_T, out_T, W_T, narrow, 13824) \
  36. f(in_T, out_T, W_T, narrow, 14336) \
  37. f(in_T, out_T, W_T, narrow, 16384) \
  38. f(in_T, out_T, W_T, narrow, 20480) \
  39. f(in_T, out_T, W_T, narrow, 28672) \
  40. f(in_T, out_T, W_T, narrow, 32000) \
  41. f(in_T, out_T, W_T, narrow, 32256) \
  42. f(in_T, out_T, W_T, narrow, 32512) \
  43. f(in_T, out_T, W_T, narrow, 32768) \
  44. f(in_T, out_T, W_T, narrow, 33024) \
  45. f(in_T, out_T, W_T, narrow, 36864) \
  46. f(in_T, out_T, W_T, narrow, 49152) \
  47. // Keep above in sync with aphrodite/lora/layers::SamplerWithLoRA
  48. // Keep this in sync with aphrodite/common/config::LoRAConfig
  49. #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
  50. FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
  51. FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
  52. FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
  53. FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
  54. // clang-format on