marlin_utils_fp8.py
from typing import Optional

import torch

import aphrodite._custom_ops as ops
from aphrodite.common.utils import print_warning_once
from aphrodite.platforms import current_platform

from .marlin_utils import marlin_make_workspace, marlin_permute_scales


def is_fp8_marlin_supported():
    # Marlin kernels require compute capability 8.0 or newer (Ampere+).
    capability = current_platform.get_device_capability()
    return capability[0] >= 8


def apply_fp8_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        workspace: torch.Tensor,
        size_n: int,
        size_k: int,
        bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # For GPUs that lack FP8 hardware support, we can leverage the
    # Marlin kernel for fast weight-only FP8 quantization.
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n, )

    output = ops.fp8_marlin_gemm(
        a=reshaped_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=reshaped_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
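
# Illustrative call (a sketch, not part of the module): the sizes below are
# hypothetical and assume `weight`, `weight_scale`, and `workspace` came from
# `prepare_fp8_layer_for_marlin`. It only shows the expected shapes: the
# trailing dim of `input` must equal `size_k`, and the output keeps the
# leading dims while the last dim becomes `size_n`.
#
#     x = torch.randn(4, 16, 4096, dtype=torch.half, device="cuda")
#     y = apply_fp8_marlin_linear(x, layer.weight, layer.weight_scale,
#                                 layer.workspace, size_n=11008,
#                                 size_k=4096, bias=None)
#     assert y.shape == (4, 16, 11008)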


def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
                                 strategy: str = "tensor") -> None:
    print_warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(part_size_n, device)

    # WEIGHT
    # Repack weights to marlin format
    marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32(
        layer.weight),
                                            perm=torch.empty(0,
                                                             dtype=torch.int,
                                                             device=device),
                                            size_k=part_size_k,
                                            size_n=part_size_n,
                                            num_bits=8)
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    scales = layer.weight_scale.to(layer.orig_dtype)

    # Permute scales
    marlin_scales = marlin_permute_scales(s=scales,
                                          size_k=part_size_k,
                                          size_n=part_size_n,
                                          group_size=-1)
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
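
# Typical flow (a sketch under assumptions, not part of the module): `layer`
# is assumed to be a linear module that already holds an fp8 `weight`, a
# `weight_scale`, `orig_dtype`, and the per-partition sizes used above. It is
# repacked once at load time, after which its forward pass routes through
# `apply_fp8_marlin_linear`.
#
#     if is_fp8_marlin_supported():
#         prepare_fp8_layer_for_marlin(layer)  # one-time repack
#         out = apply_fp8_marlin_linear(
#             x, layer.weight, layer.weight_scale, layer.workspace,
#             size_n=layer.output_size_per_partition,
#             size_k=layer.input_size_per_partition, bias=None)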


def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements)
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0

    # Reshape to prepare for packing
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

    # Convert fp8 to uint8 (byte) representation
    byte_tensor = reshaped.view(torch.uint8)

    # Pack 4 uint8 values into one int32
    packed = (byte_tensor[:, 0].to(torch.int32) |
              (byte_tensor[:, 1].to(torch.int32) << 8) |
              (byte_tensor[:, 2].to(torch.int32) << 16) |
              (byte_tensor[:, 3].to(torch.int32) << 24))

    return packed.view(fp8_tensor.shape[0] // 4,
                       *fp8_tensor.shape[1:]).contiguous()
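
# Packing layout (worked example, not executed by the module): rows are
# grouped in fours and packed little-endian along dim 0, so byte_tensor[:, 0]
# lands in the low byte of each int32. If the raw fp8 bytes in some column of
# rows 0..3 are 0x11, 0x22, 0x33, 0x44, the packed value for that column is
#     0x11 | (0x22 << 8) | (0x33 << 16) | (0x44 << 24) == 0x44332211
# and the output has shape (fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).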