quip_utils.py

import math
from pathlib import Path

import scipy.stats
import torch
from safetensors.torch import load_file

try:
    # Compiled fast Hadamard transform kernel; optional at import time.
    import aphrodite._hadamard_C as hadamard_C
except ImportError:
    pass

HADA_TENSORS = load_file(
    Path(__file__).resolve().parent / "hadamard.safetensors")


class HadamardTransformFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, scale=1.0):
        ctx._hadamard_transform_scale = scale  # pylint: disable=protected-access
        return hadamard_C.hadamard_transform(x, scale=scale)


def hadamard_transform(x, scale=1.0):
    return HadamardTransformFn.apply(x, scale)


def int2mask(i, int_map):
    # Expand packed integers into a 0/1 bit mask according to int_map.
    return ((i & int_map) > 0).int()


def mask2int(mask, int_map):
    # Inverse of int2mask: collapse a 0/1 mask back into packed integers.
    return (int_map.unsqueeze(0) * mask.int()).sum(dim=-1)


def get_norm12():
    # 29 elements of norm 12 in E8 + 1/4
    return torch.tensor([
        [3, 1, 1, 1, 3, 3, 3, 3],
        [1, 3, 1, 1, 3, 3, 3, 3],
        [1, 1, 3, 1, 3, 3, 3, 3],
        [1, 1, 1, 3, 3, 3, 3, 3],
        [3, 3, 3, 1, 3, 3, 1, 1],
        [3, 3, 3, 1, 3, 1, 3, 1],
        [3, 3, 3, 1, 1, 3, 3, 1],
        [3, 3, 3, 1, 3, 1, 1, 3],
        [3, 3, 3, 1, 1, 3, 1, 3],
        [3, 3, 3, 1, 1, 1, 3, 3],
        [3, 3, 1, 3, 3, 3, 1, 1],
        [3, 3, 1, 3, 3, 1, 3, 1],
        [3, 3, 1, 3, 1, 3, 3, 1],
        [3, 3, 1, 3, 3, 1, 1, 3],
        [3, 3, 1, 3, 1, 3, 1, 3],
        [3, 3, 1, 3, 1, 1, 3, 3],
        [3, 1, 3, 3, 3, 3, 1, 1],
        [3, 1, 3, 3, 3, 1, 3, 1],
        [3, 1, 3, 3, 1, 3, 3, 1],
        [3, 1, 3, 3, 3, 1, 1, 3],
        [3, 1, 3, 3, 1, 3, 1, 3],
        [1, 3, 3, 3, 1, 1, 3, 3],
        [1, 3, 3, 3, 3, 3, 1, 1],
        [1, 3, 3, 3, 3, 1, 3, 1],
        [1, 3, 3, 3, 1, 3, 3, 1],
        [1, 3, 3, 3, 3, 1, 1, 3],
        [1, 3, 3, 3, 1, 3, 1, 3],
        [1, 1, 3, 3, 1, 3, 3, 3],
        [3, 3, 1, 1, 3, 3, 3, 1],
    ]) / 2


def get_packed_abs_grid():
    # Absolute values of the codebook points: half-integer points with even
    # coordinate sum and squared norm <= 10, plus the 29 norm-12 points above.
    intr = torch.arange(-4, 4)
    d8 = torch.cartesian_prod(*[intr] * 8).float() + 1 / 2
    d8m2 = d8.sum(dim=-1) % 2 == 0
    d8n = d8.norm(dim=-1)**2 <= 10
    d8abs = torch.unique(d8[sorted(torch.where(d8m2 * d8n)[0])].abs(), dim=0)
    norm12 = get_norm12()
    cba = torch.concat([d8abs, norm12], dim=0)
    cba = cba[:, [0, 2, 1, 3, 4, 6, 5, 7]]
    # Fold the parity of the coordinate sum into the sign of the last entry,
    # scale by 4 so each coordinate is a small integer, then pack the 8
    # coordinates of each point into a single int64 (one byte per coordinate).
    cba[:, 7] *= (1 - 2 * (cba.sum(1) % 2))
    cba = cba * 4
    cba = cba.to(torch.int64)
    acc = cba[:, 0]
    for i in range(7):
        acc = acc | (cba[:, (i + 1)] << ((i + 1) * 8))
    return acc
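

# Illustrative helper, not part of the original module: unpack one value
# produced by get_packed_abs_grid() back into its 8 codebook coordinates,
# documenting the byte layout (one coordinate * 4 per byte, sign carried by
# the top byte). The name `unpack_abs_grid` is hypothetical.
def unpack_abs_grid(packed):
    coord_bytes = [((packed >> (8 * i)) & 0xFF).to(torch.int8)
                   for i in range(8)]
    # Going through int8 restores the sign of the last coordinate; dividing
    # by 4 recovers the original half-integer values.
    return torch.stack(coord_bytes, dim=-1).to(torch.float32) / 4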


def next_power_of_2(n):
    if n == 0:
        return 1
    return 2**math.ceil(math.log(n, 2))


def get_power_of_2(n):
    """Return (k, m) such that n == m * 2**k with m odd."""
    k = 0
    while n % 2 == 0:
        n //= 2
        k += 1
    return k, n


def get_hadK(n, use_rand=True):
    # Factor n as base * 2**exp and choose the K x K mixing matrix used by
    # matmul_hadU_cuda. Returns (hadK or None, K, possibly padded length).
    exp, base = get_power_of_2(n)
    if base == 1:
        # n is already a power of 2: the fast Hadamard transform alone suffices.
        return None, 1, n
    if use_rand:
        rand_mat = torch.tensor(scipy.stats.special_ortho_group.rvs(base)).to(
            torch.float32)
        return rand_mat, base, n
    # No random matrix: try a stored Hadamard matrix of order base * 4; if
    # none is available, fall back to K == 1 and pad n to the next power of 2.
    pad_n = next_power_of_2(n)
    if exp < 2 or str(base * 4) not in HADA_TENSORS:
        return None, 1, pad_n
    base_mat = HADA_TENSORS[str(base * 4)] / math.sqrt(base * 4)
    return base_mat, base * 4, n


def matmul_hadU_cuda(X, hadK, K, n, scale=None, transpose=False):
    # Pad the last dimension up to n, apply the fast Hadamard transform to
    # each of the K contiguous blocks of size n // K, then mix the blocks
    # with the K x K matrix hadK.
    if n != X.shape[-1]:
        X = torch.nn.functional.pad(X, (0, n - X.shape[-1]))
    had_scale = 1 / math.sqrt(n // K) if scale is None else scale / math.sqrt(
        n // K)
    if K == 1:
        return hadamard_transform(X.contiguous(), scale=had_scale)
    if transpose:
        # Only hadK needs transposing; the Hadamard blocks are symmetric.
        hadK = hadK.T.contiguous()
    input = X.view(-1, K, n // K)  # pylint: disable=redefined-builtin
    input = hadamard_transform(input.contiguous(), scale=had_scale)
    input = hadK @ input
    return input.reshape(X.shape)


def matmul_hadUt_cuda(X, hadK, K, n, scale=None):
    return matmul_hadU_cuda(X, hadK, K, n, scale=scale, transpose=True)