# aqlm.py 9.1 KB


  1. import os
  2. import sys
  3. from typing import Optional
  4. import torch
  5. import torch.nn.functional as F
  6. from aphrodite import _custom_ops as ops
  7. from aphrodite.common.utils import FlexibleArgumentParser
  8. from aphrodite.quantization.aqlm import (dequantize_weight,
  9. generic_dequantize_gemm,
  10. get_int_dtype,
  11. optimized_dequantize_gemm)
  12. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  13. def torch_mult(
  14. input: torch.Tensor, # [..., in_features]
  15. weights: torch.Tensor,
  16. scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
  17. ) -> torch.Tensor:
  18. output = F.linear(input, weights)
  19. return output
  20. def dequant_out_scale(
  21. input: torch.Tensor, # [..., in_features]
  22. codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
  23. codebooks: torch.
  24. Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
  25. scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
  26. output_partition_sizes: torch.IntTensor,
  27. bias: Optional[torch.Tensor],
  28. ) -> torch.Tensor:
  29. weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
  30. if bias is None:
  31. output = F.linear(input, weights, bias)
  32. orig_shape = output.shape
  33. flattened_output = output.view(-1, output.size(-1))
  34. f_scales = scales.view(-1, scales.shape[0])
  35. b_scales = f_scales.expand(flattened_output.shape[0], -1)
  36. flattened_output *= b_scales
  37. return flattened_output.view(orig_shape)
  38. else:
  39. b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
  40. -1, weights.shape[1])
  41. weights *= b_scales
  42. return F.linear(input, weights, bias)
  43. def dequant_weight_scale(
  44. input: torch.Tensor, # [..., in_features]
  45. codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
  46. codebooks: torch.
  47. Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
  48. scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
  49. output_partition_sizes: torch.IntTensor,
  50. bias: Optional[torch.Tensor],
  51. ) -> torch.Tensor:
  52. weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
  53. b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
  54. -1, weights.shape[1])
  55. weights *= b_scales
  56. return F.linear(input, weights, bias)
  57. def dequant_no_scale(
  58. input: torch.Tensor, # [..., in_features]
  59. codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
  60. codebooks: torch.
  61. Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
  62. scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
  63. output_partition_sizes: torch.IntTensor,
  64. bias: Optional[torch.Tensor],
  65. ) -> torch.Tensor:
  66. weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
  67. return F.linear(input, weights, bias)
  68. # Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
  69. # the generic pytorch version.
  70. # Just visual comparison.
  71. def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
  72. n = int(parts.sum().item())
  73. device = torch.device('cuda:0')
  74. code_range = (1 << bits) // 2
  75. ingroups = 8
  76. codes = torch.randint(-code_range,
  77. code_range,
  78. size=(n, k // ingroups, nbooks),
  79. dtype=get_int_dtype(bits),
  80. device=device)
  81. codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
  82. dtype=torch.float16,
  83. device=device)
  84. count = 0
  85. for index in range(16):
  86. for i in range(8):
  87. for book in range(nbooks):
  88. codebooks[book, index, 0, i] = count * (10**book)
  89. count += 1
  90. print("codes shape", codes.shape)
  91. for i in range(16):
  92. for book in range(nbooks):
  93. codes[0, i, book] = i
  94. codes[0, -i, book] = i
  95. weights = dequantize_weight(codes, codebooks, None)
  96. weights2 = ops.aqlm_dequant(codes, codebooks, parts)
  97. print("weights shape:", weights.shape)
  98. print("weights2 shape:", weights2.shape)
  99. print("weights are:", weights)
  100. print("weights2 are:", weights2)
  101. print("first 128 weights are", weights[0, 0:128].to(torch.int32))
  102. print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))
  103. print("last 128 weights are", weights[0, -128:])
  104. print("last 128 weights2 are:", weights2[0, -128:])
  105. def main():
  106. parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")
  107. # Add arguments
  108. parser.add_argument("--nbooks",
  109. type=int,
  110. default=1,
  111. help="Number of codebooks (default: 1)")
  112. parser.add_argument("--bits",
  113. type=int,
  114. default=16,
  115. help="Number of bits per code element (default: 16)")
  116. parser.add_argument(
  117. "--test",
  118. type=bool,
  119. default=False,
  120. help="Run the decompression/dequant tester rather than benchmarking "
  121. "(default: False)")
  122. # Parse the arguments
  123. args = parser.parse_args()
  124. # Extract values
  125. nbooks = args.nbooks
  126. bits = args.bits
  127. if args.test:
  128. dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
  129. return
  130. # Otherwise, benchmark.
  131. methods = [
  132. ops.aqlm_gemm,
  133. dequant_out_scale,
  134. generic_dequantize_gemm,
  135. optimized_dequantize_gemm,
  136. dequant_weight_scale,
  137. torch_mult,
  138. dequant_no_scale,
  139. ]
  140. filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
  141. print(f"writing benchmarks to file {filename}")
  142. with open(filename, "w") as f:
  143. sys.stdout = f
  144. print('m | k | n | n parts', end='')
  145. for method in methods:
  146. print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
  147. print('')
  148. # These are reasonable prefill sizes.
  149. ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
  150. (4096, (11008, 11008)), (11008, (4096, )))
  151. # reasonable ranges for m.
  152. for m in [
  153. 1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
  154. 128, 256, 512, 1024, 1536, 2048, 3072, 4096
  155. ]:
  156. print(f'{m}', file=sys.__stdout__)
  157. for ksp in ksandpartions:
  158. run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
  159. methods)
  160. sys.stdout = sys.__stdout__
  161. def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
  162. methods):
  163. # I didn't see visible improvements from increasing these, but feel free :)
  164. num_warmup_trials = 1
  165. num_trials = 1
  166. num_calls = 100
  167. # warmup.
  168. for method in methods:
  169. for _ in range(num_warmup_trials):
  170. run_timing(
  171. num_calls=num_calls,
  172. m=m,
  173. k=k,
  174. parts=parts,
  175. nbooks=nbooks,
  176. bits=bits,
  177. method=method,
  178. )
  179. n = parts.sum().item()
  180. print(f'{m} | {k} | {n} | {parts.tolist()}', end='')
  181. for method in methods:
  182. best_time_us = 1e20
  183. for _ in range(num_trials):
  184. kernel_dur_ms = run_timing(
  185. num_calls=num_calls,
  186. m=m,
  187. k=k,
  188. parts=parts,
  189. nbooks=nbooks,
  190. bits=bits,
  191. method=method,
  192. )
  193. kernel_dur_us = 1000 * kernel_dur_ms
  194. if kernel_dur_us < best_time_us:
  195. best_time_us = kernel_dur_us
  196. print(f' | {kernel_dur_us:.0f}', end='')
  197. print('')
  198. def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
  199. nbooks: int, bits: int, method) -> float:
  200. n = int(parts.sum().item())
  201. device = torch.device('cuda:0')
  202. input = torch.randn((1, m, k), dtype=torch.float16, device=device)
  203. code_range = (1 << bits) // 2
  204. ingroups = 8
  205. codes = torch.randint(-code_range,
  206. code_range,
  207. size=(n, k // ingroups, nbooks),
  208. dtype=get_int_dtype(bits),
  209. device=device)
  210. codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
  211. dtype=torch.float16,
  212. device=device)
  213. scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
  214. # for comparison to just a pytorch mult.
  215. weights = torch.randn((n, k), dtype=torch.float16, device=device)
  216. start_event = torch.cuda.Event(enable_timing=True)
  217. end_event = torch.cuda.Event(enable_timing=True)
  218. start_event.record()
  219. if method is torch_mult:
  220. for i in range(num_calls):
  221. torch_mult(input, weights, scales)
  222. else:
  223. for i in range(num_calls):
  224. method(input, codes, codebooks, scales, parts, None)
  225. end_event.record()
  226. end_event.synchronize()
  227. dur_ms = start_event.elapsed_time(end_event) / num_calls
  228. return dur_ms
  229. if __name__ == "__main__":
  230. sys.exit(main())