# marlin_utils.py

from typing import List, Optional, Tuple

import torch

from aphrodite import _custom_ops as ops
from aphrodite.platforms import current_platform

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
GPTQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER = [-1]

def check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
                           min_capability: int) -> bool:

    # If the capability of the device is too low, cannot convert.
    major, minor = current_platform.get_device_capability()
    device_capability = major * 10 + minor
    if device_capability < min_capability:
        return False

    return (device_capability >= min_capability
            and num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
            and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
            and is_sym in GPTQ_MARLIN_SUPPORTED_SYM)

def verify_marlin_supported(num_bits: int, group_size: Optional[int],
                            is_sym: bool) -> None:

    if num_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
        raise ValueError(
            f"Marlin does not support weight_bits = {num_bits}. "
            f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
            "are supported.")
    if (group_size is None
            or group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES):
        raise ValueError(
            f"Marlin does not support group_size = {group_size}. "
            f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
            "are supported.")
    if is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
        raise ValueError(
            f"Marlin does not support is_sym = {is_sym}. "
            f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

def verify_marlin_supports_shape(output_size_per_partition: int,
                                 input_size_per_partition: int,
                                 input_size: int, group_size: int) -> None:

    # Validate output_size_per_partition
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(f"Weight output_size_per_partition = "
                         f"{output_size_per_partition} is not divisible by "
                         f"min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    # Validate input_size_per_partition
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(f"Weight input_size_per_partition = "
                         f"{input_size_per_partition} is not divisible "
                         f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    if (group_size < input_size
            and input_size_per_partition % group_size != 0):
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq.")

def marlin_make_workspace(output_size_per_partition: int,
                          device: torch.device) -> torch.Tensor:
    # The kernel needs one zero-initialized int per group of MIN_THREAD_N
    # output columns, for each of the slices it may process in parallel.
    max_workspace_size = (output_size_per_partition //
                          GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL

    return torch.zeros(max_workspace_size,
                       dtype=torch.int,
                       device=device,
                       requires_grad=False)

def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                              requires_grad=False)

def marlin_sort_g_idx(
        g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
    return g_idx[g_idx_sort_indices], g_idx_sort_indices
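
# Example: for a GPTQ act-order checkpoint, g_idx maps each input channel to
# its quantization group, so marlin_sort_g_idx(torch.tensor([2, 0, 1, 0],
# dtype=torch.int)) yields the group ids sorted ascending ([0, 0, 1, 2]) plus
# the int32 permutation that apply_marlin_linear passes to the Marlin GEMM.
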
def get_scale_perms():
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single

def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                          group_size: int) -> torch.Tensor:
    scale_perm, scale_perm_single = get_scale_perms()
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s
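
# Note: marlin_permute_scales preserves the (num_groups, size_n) shape of the
# scale tensor; it only reorders entries within each 64-wide (grouped) or
# 32-wide (channelwise) chunk of a row to match the Marlin tile layout.
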
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(layer: torch.nn.Module, name: str,
                   new_t: torch.Tensor) -> None:
    # It is important to use resize_() here since it ensures
    # the same buffer is reused
    getattr(layer, name).resize_(new_t.shape)
    getattr(layer, name).copy_(new_t)
    del new_t

def apply_marlin_linear(input: torch.Tensor,
                        weight: torch.Tensor,
                        weight_scale: torch.Tensor,
                        g_idx: torch.Tensor,
                        g_idx_sort_indices: torch.Tensor,
                        workspace: torch.Tensor,
                        num_bits: int,
                        output_size_per_partition: int,
                        input_size_per_partition: int,
                        is_k_full: bool,
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition, )

    output = ops.gptq_marlin_gemm(reshaped_x,
                                  weight,
                                  weight_scale,
                                  g_idx,
                                  g_idx_sort_indices,
                                  workspace,
                                  num_bits,
                                  size_m=reshaped_x.shape[0],
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  is_k_full=is_k_full)

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
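
# Minimal CPU-only usage sketch for the pure-tensor helpers above. The full
# apply_marlin_linear path additionally requires a CUDA device and
# Marlin-packed weights (via the aphrodite custom ops), which this sketch
# deliberately does not construct.
if __name__ == "__main__":
    size_k, size_n, group_size = 256, 128, 64
    num_groups = size_k // group_size

    # Sort a synthetic act-order g_idx the way the GPTQ-Marlin path would.
    g_idx = torch.randint(0, num_groups, (size_k, ), dtype=torch.int)
    sorted_g_idx, perm = marlin_sort_g_idx(g_idx)
    assert torch.equal(g_idx[perm], sorted_g_idx)

    # Permute per-group scales into Marlin's layout; the shape is unchanged.
    scales = torch.rand(num_groups, size_n, dtype=torch.half)
    marlin_scales = marlin_permute_scales(scales, size_k, size_n, group_size)
    assert marlin_scales.shape == scales.shape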