marlin_utils_fp8.py

from typing import Optional

import torch

import aphrodite._custom_ops as ops
from aphrodite.common.utils import print_warning_once
from aphrodite.platforms import current_platform

from .marlin_utils import marlin_make_workspace, marlin_permute_scales


def is_fp8_marlin_supported():
    # Marlin FP8 kernels require compute capability 8.0+ (Ampere or newer).
    capability = current_platform.get_device_capability()
    return capability[0] >= 8


def apply_fp8_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        workspace: torch.Tensor,
        size_n: int,
        size_k: int,
        bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # For GPUs that lack FP8 hardware support, we can leverage the
    # Marlin kernel for fast weight-only FP8 quantization

    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n, )

    output = ops.fp8_marlin_gemm(
        a=reshaped_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=reshaped_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
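

def _example_fp8_marlin_forward(layer: torch.nn.Module,
                                x: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch, not part of the original module: how a forward
    pass might wire a layer already converted by prepare_fp8_layer_for_marlin
    (defined below) into apply_fp8_marlin_linear. The `layer` and `x`
    arguments are hypothetical; `layer` is assumed to carry the attributes
    set by that function, plus an optional `bias`."""
    return apply_fp8_marlin_linear(
        input=x,
        weight=layer.weight,
        weight_scale=layer.weight_scale,
        workspace=layer.workspace,
        size_n=layer.output_size_per_partition,
        size_k=layer.input_size_per_partition,
        bias=getattr(layer, "bias", None),
    )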


def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
    print_warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition

    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(part_size_n, device)

    # WEIGHT
    # Repack weights to marlin format
    marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32(
        layer.weight),
                                            perm=torch.empty(0,
                                                             dtype=torch.int,
                                                             device=device),
                                            size_k=part_size_k,
                                            size_n=part_size_n,
                                            num_bits=8)
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    # Currently Marlin doesn't support per-tensor scales, so we
    # expand it to channelwise
    is_channelwise = (len(layer.weight_scale.shape) > 0
                      and layer.weight_scale.shape[0] == part_size_n)
    if is_channelwise:
        scales = layer.weight_scale
    else:
        scales = layer.weight_scale.repeat(1, part_size_n)
    scales = scales.to(layer.orig_dtype).to(device)

    # Permute scales
    marlin_scales = marlin_permute_scales(s=scales,
                                          size_k=part_size_k,
                                          size_n=part_size_n,
                                          group_size=-1)
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
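
# Scale-expansion note (an illustrative sketch, not part of the original
# module): Marlin expects one scale per output channel, so a per-tensor scale
# is broadcast across part_size_n before being permuted above, e.g. with
# part_size_n = 4:
#
#   >>> torch.tensor([0.5]).repeat(1, 4)
#   tensor([[0.5000, 0.5000, 0.5000, 0.5000]])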


def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements)
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0

    # Reshape to prepare for packing
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

    # Convert fp8 to uint8 (byte) representation
    byte_tensor = reshaped.view(torch.uint8)

    # Pack 4 uint8 values into one int32
    packed = (byte_tensor[:, 0].to(torch.int32) |
              (byte_tensor[:, 1].to(torch.int32) << 8) |
              (byte_tensor[:, 2].to(torch.int32) << 16) |
              (byte_tensor[:, 3].to(torch.int32) << 24))

    return packed.view(fp8_tensor.shape[0] // 4,
                       *fp8_tensor.shape[1:]).contiguous()
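

def _example_pack_fp8_to_int32_roundtrip() -> None:
    """Illustrative sanity check, not part of the original module: shows that
    pack_fp8_to_int32 packs each group of four consecutive rows into the bytes
    of one int32, lowest byte first. Assumes a PyTorch build with
    float8_e4m3fn support; the int32-to-uint8 view is assumed to follow
    little-endian byte order, as on all common platforms."""
    w = torch.randn(8, 16, dtype=torch.float16).to(torch.float8_e4m3fn)
    packed = pack_fp8_to_int32(w)
    assert packed.shape == (2, 16)

    # Unpack: view each int32 as 4 bytes, then undo the row grouping.
    unpacked = (packed.view(torch.uint8)      # (2, 64)
                .view(2, 16, 4)               # (row group, column, byte)
                .permute(0, 2, 1)             # (row group, byte, column)
                .reshape(8, 16))
    assert torch.equal(unpacked, w.view(torch.uint8))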