- """Tests for the AWQ Triton kernel.
- Run `pytest tests/kernels/test_awq_triton.py`.
- """
- import pytest
- import torch
- from aphrodite.quantization.awq_triton import (
- AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
- device = "cuda"
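

# AWQ stores eight 4-bit values per int32 column in an interleaved order.
# reverse_awq_order builds the inverse permutation [0, 4, 1, 5, 2, 6, 3, 7]
# over each group of 32 // bits columns and applies it (plus a 0xF mask) so
# the unpacked nibbles line up with their logical column positions.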
def reverse_awq_order(t: torch.Tensor):
    bits = 4
    AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
    reverse_order_tensor = torch.arange(
        t.shape[-1],
        dtype=torch.int32,
        device=t.device,
    )
    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
    reverse_order_tensor = reverse_order_tensor.view(-1)

    t = t[:, reverse_order_tensor] & 0xF
    return t


# qweights - [R     , C // 8], int32
# scales   - [R // G, C     ], float16
# zeros    - [R // G, C // 8], int32
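# Reference dequantization in plain PyTorch: unpack the 4-bit weights and
# zero points, undo the AWQ ordering, then apply (w - z) * scale per group.
# test_dequantize below uses it as ground truth for the Triton kernel.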
def awq_dequantize_torch(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    group_size: int,
) -> torch.Tensor:
    if group_size == -1:
        group_size = qweight.shape[0]

    bits = 4
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    iweights = torch.bitwise_right_shift(
        qweight[:, :, None], shifts[None, None, :]
    ).to(torch.int8)
    iweights = iweights.view(iweights.shape[0], -1)

    zeros = torch.bitwise_right_shift(
        qzeros[:, :, None], shifts[None, None, :]
    ).to(torch.int8)
    zeros = zeros.view(qzeros.shape[0], -1)

    zeros = reverse_awq_order(zeros)
    iweights = reverse_awq_order(iweights)

    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
    zeros = torch.bitwise_and(zeros, (2**bits) - 1)

    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - zeros) * scales


# qweights - [R     , C // 8], int32
# scales   - [R // G, C     ], float16
# zeros    - [R // G, C // 8], int32
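# Example shapes for one case: qweight [128, 16] with group_size=-1 (i.e. a
# single group spanning all 128 rows) gives scales [1, 128] and zeros [1, 16],
# since each int32 column packs eight 4-bit weights.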
- @pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024])
- @pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
- @pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
- def test_dequantize(qweight_rows, qweight_cols, group_size):
- if group_size == -1:
- group_size = qweight_rows
- qweight_dtype = torch.int32
- scales_rows = qweight_rows // group_size
- scales_cols = qweight_cols * 8
- scales_dtype = torch.float16
- zeros_rows = scales_rows
- zeros_cols = qweight_cols
- zeros_dtype = torch.int32
- torch.manual_seed(0)
- qweight = torch.randint(
- 0,
- torch.iinfo(torch.int32).max,
- (qweight_rows, qweight_cols),
- dtype=qweight_dtype,
- device=device,
- )
- scales = torch.rand(
- scales_rows, scales_cols, dtype=scales_dtype, device=device
- )
- zeros = torch.randint(
- 0,
- torch.iinfo(torch.int32).max,
- (zeros_rows, zeros_cols),
- dtype=zeros_dtype,
- device=device,
- )
- iweights_triton = awq_dequantize_triton(qweight, scales, zeros)
- assert not torch.any(torch.isinf(iweights_triton)) and not torch.any(
- torch.isnan(iweights_triton)
- )
- iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)
- torch.testing.assert_close(iweights_triton, iweights_torch)


# input   - [N, K]
# qweight - [K, M // 8]
# qzeros  - [K // G, M // 8]
# scales  - [K // G, M]
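# test_gemm checks awq_gemm_triton against a plain matmul on weights that were
# dequantized with awq_dequantize_triton (dequantization itself is covered by
# test_dequantize above). The loose atol/rtol of 1e-1 leaves room for
# differences in float accumulation order, e.g. with split-K. Example shapes
# for N=2, K=128, M=32, group_size=-1: input [2, 128], qweight [128, 4],
# qzeros [1, 4], scales [1, 32], output [2, 32].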
- @pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32])
- @pytest.mark.parametrize("K", [128])
- @pytest.mark.parametrize("M", [16, 24, 32])
- @pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
- @pytest.mark.parametrize("splitK", [1, 8])
- def test_gemm(N, K, M, splitK, group_size):
- if group_size == -1:
- group_size = K
- split_k_iters = splitK
- input_rows = N
- input_cols = K
- input_dtype = torch.float32
- qweight_rows = input_cols
- qweight_cols = M // 8
- scales_rows = qweight_rows // group_size
- scales_cols = M
- scales_dtype = torch.float32
- qzeros_rows = scales_rows
- qzeros_cols = qweight_cols
- torch.manual_seed(0)
- input = torch.rand(
- (input_rows, input_cols), dtype=input_dtype, device=device
- )
- qweight = torch.randint(
- 0,
- torch.iinfo(torch.int32).max,
- (qweight_rows, qweight_cols),
- device=device,
- )
- qzeros = torch.randint(
- 0,
- torch.iinfo(torch.int32).max,
- (qzeros_rows, qzeros_cols),
- device=device,
- )
- scales = torch.rand(
- (scales_rows, scales_cols), dtype=scales_dtype, device=device
- )
- output_triton = awq_gemm_triton(
- input, qweight, scales, qzeros, split_k_iters
- )
- assert not torch.any(torch.isinf(output_triton)) and not torch.any(
- torch.isnan(output_triton)
- )
- dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)
- output_torch = torch.matmul(input, dequantized_weights)
- assert not torch.any(torch.isinf(output_torch)) and not torch.any(
- torch.isnan(output_torch)
- )
- torch.testing.assert_close(
- output_triton.cpu(), output_torch.cpu(), atol=1e-1, rtol=1e-1
- )