// permute_cols.cu
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp16.h>

static constexpr int default_threads = 256;

static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

// For a given "a" of size [M,K], permutes the K columns according to the
// given "perm" indices.
// Currently only supports 16-bit types (since we permute half types).
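// Example (illustrative values): with size_k = 4 and perm = {2, 0, 3, 1},
// every output row m becomes {a[m][2], a[m][0], a[m][3], a[m][1]},
// i.e. out[m][k] = a[m][perm[k]].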
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  // Each block handles a contiguous chunk of block_rows rows, clamped to the
  // matrix bounds.
  int start_row = block_rows * blockIdx.x;
  int finish_row = start_row + block_rows;
  if (finish_row > size_m) {
    finish_row = size_m;
  }
  int cur_block_rows = std::max(finish_row - start_row, 0);

  // Row stride in int4 (16-byte) units: size_k halves per row.
  int row_stride = size_k * sizeof(half) / 16;

  auto permute_row = [&](int row) {
    // The block's threads sweep the row's columns, default_threads at a time.
    int iters = size_k / default_threads;
    int rest = size_k % default_threads;

    int offset = row * row_stride;

    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);

    int base_k = 0;

    for (int i = 0; i < iters; i++) {
      int cur_k = base_k + threadIdx.x;
      int src_pos = perm_int_ptr[cur_k];

      out_half[cur_k] = a_row_half[src_pos];

      base_k += default_threads;
    }

    // Tail: handle the remaining size_k % default_threads columns.
    if (rest) {
      if (threadIdx.x < rest) {
        int cur_k = base_k + threadIdx.x;
        int src_pos = perm_int_ptr[cur_k];

        out_half[cur_k] = a_row_half[src_pos];
      }
    }
  };

  for (int i = 0; i < cur_block_rows; i++) {
    int cur_row = start_row + i;
    if (cur_row < size_m) {
      permute_row(cur_row);
    }
  }
}
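// For reference only: the kernel above computes out[row][k] = a[row][perm[k]].
// A minimal, unoptimized one-thread-per-element sketch of the same operation
// (hypothetical helper; not launched anywhere in this file) would be:
__global__ void permute_cols_naive_kernel(half const* __restrict__ a,
                                          int const* __restrict__ perm,
                                          half* __restrict__ out, int size_m,
                                          int size_k) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < size_m * size_k) {
    int row = idx / size_k;
    int col = idx % size_k;
    out[row * size_k + col] = a[row * size_k + perm[col]];
  }
}
// The real kernel instead assigns a chunk of rows to each block so the grid
// size can match the SM count, and keeps threads striding over columns.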
// More efficient version of A[..., perm]
// (taken from gptq_marlin.cu; see the usage sketch at the end of this file).
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  auto dev = A.get_device();
  auto stream = at::cuda::getCurrentCUDAStream(dev);

  TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
              "Currently only 16bit types are supported");
  TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
  TORCH_CHECK(A.size(-1) % 8 == 0,
              "A columns must be a multiple of 8 (128bits)");
  auto A_2d = A.view({-1, A.size(-1)});

  torch::Tensor D = torch::empty_like(A);
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  // Spread the rows evenly over all SMs: one block per SM.
  int block_rows = div_ceil(A_2d.size(0), sms);
  permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
      reinterpret_cast<int4 const*>(A_2d.const_data_ptr()),
      perm.const_data_ptr<int>(), reinterpret_cast<int4*>(D.mutable_data_ptr()),
      A_2d.size(0), A_2d.size(1), block_rows);
  return D;
}
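
// Illustrative usage sketch (hypothetical snippet, assuming this file is
// built into a torch extension that exposes permute_cols):
//
//   torch::Tensor A = torch::randn(
//       {8, 128}, torch::dtype(torch::kHalf).device(torch::kCUDA));
//   torch::Tensor perm = torch::randperm(
//       128, torch::dtype(torch::kInt).device(torch::kCUDA));
//   torch::Tensor D = permute_cols(A, perm);
//   // Matches plain fancy indexing, i.e. A[..., perm]:
//   TORCH_CHECK(torch::equal(D, A.index_select(-1, perm.to(torch::kLong))));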