machete_utils.py 1.0 KB

123456789101112131415161718192021222324252627282930
  1. from typing import List, Optional, Tuple
  2. import torch
  3. from aphrodite.scalar_type import ScalarType, scalar_types
  4. MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
  5. MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]
  6. def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]:
  7. if zero_points:
  8. return [scalar_types.uint4, scalar_types.uint8]
  9. else:
  10. return [scalar_types.uint4b8, scalar_types.uint8b128]
  11. def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]:
  12. return [torch.float16, torch.bfloat16]
  13. def check_machete_supports_shape(in_features: int, out_featrues: int) \
  14. -> Tuple[bool, Optional[str]]:
  15. if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
  16. return False, "Input features size must be divisible by "\
  17. f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}"
  18. if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0:
  19. return False, "Output features size must be divisible by "\
  20. f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}"
  21. return True, None