1
0

quip_utils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import math
  2. from pathlib import Path
  3. import scipy
  4. import torch
  5. from safetensors.torch import load_file
  6. from contextlib import suppress
  7. with suppress(ImportError):
  8. import aphrodite._hadamard_C as hadamard_C
  9. HADA_TENSORS = load_file(
  10. Path(__file__).resolve().parent / "hadamard.safetensors")
  11. class HadamardTransformFn(torch.autograd.Function):
  12. @staticmethod
  13. def forward(ctx, x, scale=1.0):
  14. ctx._hadamard_transform_scale = scale # pylint: disable=protected-access
  15. return hadamard_C.fast_hadamard_transform(x, scale)
  16. def hadamard_transform(x, scale=1.0):
  17. return HadamardTransformFn.apply(x, scale)
  18. def int2mask(i, int_map):
  19. return ((i & int_map) > 0).int()
  20. def mask2int(mask, int_map):
  21. return (int_map.unsqueeze(0) * mask.int()).sum(dim=-1)
  22. def get_norm12():
  23. # 29 elements of norm 12 in E8 + 1/4
  24. return torch.tensor([
  25. [3, 1, 1, 1, 3, 3, 3, 3],
  26. [1, 3, 1, 1, 3, 3, 3, 3],
  27. [1, 1, 3, 1, 3, 3, 3, 3],
  28. [1, 1, 1, 3, 3, 3, 3, 3],
  29. [3, 3, 3, 1, 3, 3, 1, 1],
  30. [3, 3, 3, 1, 3, 1, 3, 1],
  31. [3, 3, 3, 1, 1, 3, 3, 1],
  32. [3, 3, 3, 1, 3, 1, 1, 3],
  33. [3, 3, 3, 1, 1, 3, 1, 3],
  34. [3, 3, 3, 1, 1, 1, 3, 3],
  35. [3, 3, 1, 3, 3, 3, 1, 1],
  36. [3, 3, 1, 3, 3, 1, 3, 1],
  37. [3, 3, 1, 3, 1, 3, 3, 1],
  38. [3, 3, 1, 3, 3, 1, 1, 3],
  39. [3, 3, 1, 3, 1, 3, 1, 3],
  40. [3, 3, 1, 3, 1, 1, 3, 3],
  41. [3, 1, 3, 3, 3, 3, 1, 1],
  42. [3, 1, 3, 3, 3, 1, 3, 1],
  43. [3, 1, 3, 3, 1, 3, 3, 1],
  44. [3, 1, 3, 3, 3, 1, 1, 3],
  45. [3, 1, 3, 3, 1, 3, 1, 3],
  46. [1, 3, 3, 3, 1, 1, 3, 3],
  47. [1, 3, 3, 3, 3, 3, 1, 1],
  48. [1, 3, 3, 3, 3, 1, 3, 1],
  49. [1, 3, 3, 3, 1, 3, 3, 1],
  50. [1, 3, 3, 3, 3, 1, 1, 3],
  51. [1, 3, 3, 3, 1, 3, 1, 3],
  52. [1, 1, 3, 3, 1, 3, 3, 3],
  53. [3, 3, 1, 1, 3, 3, 3, 1],
  54. ]) / 2
  55. def get_packed_abs_grid():
  56. intr = torch.arange(-4, 4)
  57. d8 = torch.cartesian_prod(*[intr] * 8).float() + 1 / 2
  58. d8m2 = d8.sum(dim=-1) % 2 == 0
  59. d8n = d8.norm(dim=-1)**2 <= 10
  60. d8abs = torch.unique(d8[sorted(torch.where(d8m2 * d8n)[0])].abs(), dim=0)
  61. norm12 = get_norm12()
  62. cba = torch.concat([d8abs, norm12], dim=0)
  63. cba = cba[:, [0, 2, 1, 3, 4, 6, 5, 7]]
  64. cba[:, 7] *= (1 - 2 * (cba.sum(1) % 2))
  65. cba = cba * 4
  66. cba = cba.to(torch.int64)
  67. acc = cba[:, 0]
  68. for i in range(7):
  69. acc = acc | (cba[:, (i + 1)] << ((i + 1) * 8))
  70. return acc
  71. def next_power_of_2(n):
  72. if n == 0:
  73. return 1
  74. return 2**math.ceil(math.log(n, 2))
  75. def get_power_of_2(n):
  76. """Returns the highest power of 2 that divides n."""
  77. k = 0
  78. while n % 2 == 0:
  79. n //= 2
  80. k += 1
  81. return k, n
  82. def get_hadK(n, use_rand=True):
  83. exp, base = get_power_of_2(n)
  84. if base == 1:
  85. return None, 1, n
  86. if use_rand:
  87. rand_mat = torch.tensor(scipy.stats.special_ortho_group.rvs(base)).to(
  88. torch.float32)
  89. return rand_mat, base, n
  90. # Use hadamad only and add padding if cannot find one
  91. pad_n = next_power_of_2(n)
  92. if exp < 2 or str(base * 4) not in HADA_TENSORS:
  93. return None, 1, pad_n
  94. base_mat = HADA_TENSORS[str(base * 4)] / math.sqrt(base * 4)
  95. return base_mat, base * 4, n
  96. def matmul_hadU_cuda(X, hadK, K, n, scale=None, transpose=False):
  97. if n != X.shape[-1]:
  98. X = torch.nn.functional.pad(X, (0, n - X.shape[-1]))
  99. had_scale = 1 / math.sqrt(n // K) if scale is None else scale / math.sqrt(
  100. n // K)
  101. if K == 1:
  102. return hadamard_transform(X.contiguous(), scale=had_scale)
  103. if transpose:
  104. hadK = hadK.T.contiguous()
  105. input = X.view(-1, K, n // K) # pylint: disable=redefined-builtin
  106. input = hadamard_transform(input.contiguous(), scale=had_scale)
  107. input = hadK @ input
  108. return input.reshape(X.shape)
  109. def matmul_hadUt_cuda(X, hadK, K, n, scale=None):
  110. return matmul_hadU_cuda(X, hadK, K, n, scale=scale, transpose=True)