column_remap.cu 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. // Adapted from turboderp exllama: https://github.com/turboderp/exllama
  2. #include "column_remap.cuh"
  3. #include "../util.cuh"
  // Tile shape handled by one thread block of the column-remap kernel.
  // X: threads per block, i.e. the number of adjacent columns one block covers.
  constexpr int SHUF_BLOCKSIZE_X = 256;
  // Y: number of consecutive rows each thread copies for its column.
  constexpr int SHUF_BLOCKSIZE_Y = 16;
  // Gathers columns of a row-major matrix into permuted order:
  //
  //     x_new[row * x_width + col] = x[row * x_width + x_map[col]]
  //
  // Work layout: each thread owns one destination column
  // (SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x) and copies up to
  // SHUF_BLOCKSIZE_Y consecutive rows starting at SHUF_BLOCKSIZE_Y * blockIdx.y.
  // Threads whose column falls past x_width exit immediately, and the row range
  // is clamped to x_height, so the grid may overshoot the matrix in both
  // dimensions.
  //
  // x        source matrix, x_height rows by x_width columns
  // x_new    destination matrix, same shape; declared __restrict__, so it must
  //          not alias x
  // x_width  number of columns (also used as the row stride)
  // x_height number of rows
  // x_map    x_width entries; x_map[col] is the source column for destination
  //          column col. Entries are not range-checked here.
  __global__ void column_remap_kernel
  (
      const half* __restrict__ x,
      half* __restrict__ x_new,
      const int x_width,
      const int x_height,
      const uint32_t* x_map
  )
  {
      const int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
      if (x_column >= x_width) return;

      const int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
      const int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);

      // Element offsets are computed in size_t: row * width exceeds the range
      // of a 32-bit int once the matrix holds 2^31 or more elements.
      const size_t x_stride = (size_t) x_width;
      const size_t s_column = x_map[x_column];

      size_t x_idx = (size_t) x_row * x_stride + (size_t) x_column;
      size_t s_idx = (size_t) x_row * x_stride + s_column;

      // Walk down the column one row at a time; source and destination share
      // the same row stride.
      for (int row = x_row; row < x_row_end; ++row)
      {
          x_new[x_idx] = x[s_idx];
          x_idx += x_stride;
          s_idx += x_stride;
      }
  }
  // Remap columns in x to correspond to sequential group index before matmul
  //
  // perform x -> seq_x such that seq_x @ seq_w == x @ w
  //
  // Launches column_remap_kernel with SHUF_BLOCKSIZE_X threads per block on a
  // 2D grid of ceil(x_width / SHUF_BLOCKSIZE_X) by
  // ceil(x_height / SHUF_BLOCKSIZE_Y) blocks. No stream is passed, so the
  // kernel goes to the default stream; this function neither synchronizes nor
  // checks the launch result.
  //
  // x        source matrix, x_height rows by x_width columns
  //          (presumably device memory -- confirm against callers)
  // x_new    destination matrix of the same shape; must not alias x
  // x_height number of rows
  // x_width  number of columns
  // x_map    x_width source-column indices, one per destination column
  //
  // NOTE(review): the row-block count goes into the grid's y dimension, which
  // CUDA caps at 65535 blocks; very tall inputs would need a different launch
  // scheme.
  void column_remap_cuda
  (
      const half* x,
      half* x_new,
      const int x_height,
      const int x_width,
      const uint32_t* x_map
  )
  {
      // Nothing to copy for an empty matrix. Launching anyway would request a
      // grid with a zero dimension, which is an invalid configuration and
      // leaves an error pending in the runtime's last-error state.
      if (x_height <= 0 || x_width <= 0) return;

      const unsigned int column_blocks = (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X;
      const unsigned int row_blocks = (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y;

      const dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
      const dim3 blocks(column_blocks, row_blocks, 1);

      // The kernel takes width before height, the reverse of this function's
      // parameter order.
      column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
  }
  52. }