test_awq_triton.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. """Tests for the AWQ Triton kernel.
  2. Run `pytest tests/kernels/test_awq_triton.py`.
  3. """
  4. import pytest
  5. import torch
  6. from aphrodite.quantization.awq_triton import (
  7. AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
# Device on which every tensor in these tests is created; the tests below
# pass it as ``device=`` to torch.randint / torch.rand.
device = "cuda"
  9. def reverse_awq_order(t: torch.Tensor):
  10. bits = 4
  11. AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
  12. reverse_order_tensor = torch.arange(
  13. t.shape[-1],
  14. dtype=torch.int32,
  15. device=t.device,
  16. )
  17. reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
  18. reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
  19. reverse_order_tensor = reverse_order_tensor.view(-1)
  20. t = t[:, reverse_order_tensor] & 0xF
  21. return t
  22. # qweights - [R , C // 8], int32
  23. # scales - [R // G, C ], float16
  24. # zeros - [R // G, C // 8], int32
  25. def awq_dequantize_torch(
  26. qweight: torch.Tensor,
  27. scales: torch.Tensor,
  28. qzeros: torch.Tensor,
  29. group_size: int,
  30. ) -> torch.Tensor:
  31. if group_size == -1:
  32. group_size = qweight.shape[0]
  33. bits = 4
  34. shifts = torch.arange(0, 32, bits, device=qzeros.device)
  35. iweights = torch.bitwise_right_shift(
  36. qweight[:, :, None], shifts[None, None, :]
  37. ).to(torch.int8)
  38. iweights = iweights.view(iweights.shape[0], -1)
  39. zeros = torch.bitwise_right_shift(
  40. qzeros[:, :, None], shifts[None, None, :]
  41. ).to(torch.int8)
  42. zeros = zeros.view(qzeros.shape[0], -1)
  43. zeros = reverse_awq_order(zeros)
  44. iweights = reverse_awq_order(iweights)
  45. iweights = torch.bitwise_and(iweights, (2**bits) - 1)
  46. zeros = torch.bitwise_and(zeros, (2**bits) - 1)
  47. scales = scales.repeat_interleave(group_size, dim=0)
  48. zeros = zeros.repeat_interleave(group_size, dim=0)
  49. return (iweights - zeros) * scales
  50. # qweights - [R , C // 8], int32
  51. # scales - [R // G, C ], float16
  52. # zeros - [R // G, C // 8], int32
  53. @pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024])
  54. @pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
  55. @pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
  56. def test_dequantize(qweight_rows, qweight_cols, group_size):
  57. if group_size == -1:
  58. group_size = qweight_rows
  59. qweight_dtype = torch.int32
  60. scales_rows = qweight_rows // group_size
  61. scales_cols = qweight_cols * 8
  62. scales_dtype = torch.float16
  63. zeros_rows = scales_rows
  64. zeros_cols = qweight_cols
  65. zeros_dtype = torch.int32
  66. torch.manual_seed(0)
  67. qweight = torch.randint(
  68. 0,
  69. torch.iinfo(torch.int32).max,
  70. (qweight_rows, qweight_cols),
  71. dtype=qweight_dtype,
  72. device=device,
  73. )
  74. scales = torch.rand(
  75. scales_rows, scales_cols, dtype=scales_dtype, device=device
  76. )
  77. zeros = torch.randint(
  78. 0,
  79. torch.iinfo(torch.int32).max,
  80. (zeros_rows, zeros_cols),
  81. dtype=zeros_dtype,
  82. device=device,
  83. )
  84. iweights_triton = awq_dequantize_triton(qweight, scales, zeros)
  85. assert not torch.any(torch.isinf(iweights_triton)) and not torch.any(
  86. torch.isnan(iweights_triton)
  87. )
  88. iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)
  89. torch.testing.assert_close(iweights_triton, iweights_torch)
  90. # input - [N, K]
  91. # qweight - [K, M // 8]
  92. # qzeros - [K // G, M // 8]
  93. # scales - [K // G, M]
  94. @pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32])
  95. @pytest.mark.parametrize("K", [128])
  96. @pytest.mark.parametrize("M", [16, 24, 32])
  97. @pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
  98. @pytest.mark.parametrize("splitK", [1, 8])
  99. def test_gemm(N, K, M, splitK, group_size):
  100. if group_size == -1:
  101. group_size = K
  102. split_k_iters = splitK
  103. input_rows = N
  104. input_cols = K
  105. input_dtype = torch.float32
  106. qweight_rows = input_cols
  107. qweight_cols = M // 8
  108. scales_rows = qweight_rows // group_size
  109. scales_cols = M
  110. scales_dtype = torch.float32
  111. qzeros_rows = scales_rows
  112. qzeros_cols = qweight_cols
  113. torch.manual_seed(0)
  114. input = torch.rand(
  115. (input_rows, input_cols), dtype=input_dtype, device=device
  116. )
  117. qweight = torch.randint(
  118. 0,
  119. torch.iinfo(torch.int32).max,
  120. (qweight_rows, qweight_cols),
  121. device=device,
  122. )
  123. qzeros = torch.randint(
  124. 0,
  125. torch.iinfo(torch.int32).max,
  126. (qzeros_rows, qzeros_cols),
  127. device=device,
  128. )
  129. scales = torch.rand(
  130. (scales_rows, scales_cols), dtype=scales_dtype, device=device
  131. )
  132. output_triton = awq_gemm_triton(
  133. input, qweight, scales, qzeros, split_k_iters
  134. )
  135. assert not torch.any(torch.isinf(output_triton)) and not torch.any(
  136. torch.isnan(output_triton)
  137. )
  138. dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)
  139. output_torch = torch.matmul(input, dequantized_weights)
  140. assert not torch.any(torch.isinf(output_torch)) and not torch.any(
  141. torch.isnan(output_torch)
  142. )
  143. torch.testing.assert_close(
  144. output_triton.cpu(), output_torch.cpu(), atol=1e-1, rtol=1e-1
  145. )