marlin_utils_test_24.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. """Utility functions used for tests and benchmarks"""
  2. import random
  3. from typing import List
  4. import numpy
  5. import torch
  6. from .marlin_utils_test import marlin_weights
  7. from .quant_utils import quantize_weights
  8. # This is PyTorch implementation of main part of reorder_meta()
  9. # function, from tools/util/include/cutlass/util/host_reorder.h file
  10. # of CUTLASS source tree. Furthermore, CUTLASS template for sparse
  11. # GEMM decides upon layout of this matrix, and at the moment for the
  12. # sparse GEMM executed on tensor cores, this is layout described by
  13. # ColumnMajorInterleaved<2> data structure, in
  14. # include/cutlass/layout/matrix.h of CUTLASS source tree. The
  15. # reordering of meta matrix into meta_reordered matrix calculated
  16. # according to these segments of CUTLASS code is re-implemented here.
  17. # Note that this calculation produces offsets for scattering metadata
  18. # matrix elements into reordered metadata matrix elements (or,
  19. # equivalently, for gathering reordered metadata matrix element back
  20. # into metadata matrix elements).
  21. def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype,
  22. device):
  23. dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
  24. dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
  25. # Reorder the rows, then swizzle the 2x2 blocks.
  26. group_x = 64
  27. group_y = 32 if meta_dtype.itemsize == 2 else 16
  28. dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 +
  29. (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 +
  30. ((dst_rows % group_x) // 8) * 4)
  31. topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
  32. bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
  33. dst_rows += topright - bottomleft
  34. dst_cols -= topright - bottomleft
  35. # Assumed that meta tensor is to be stored in CUTLASS
  36. # InterleavedColumnMajor layout, and reverse engineered
  37. # corresponding code to store values into this tensor.
  38. interleave = 2
  39. cols_maj = dst_cols // interleave
  40. cols_min = dst_cols % interleave
  41. return (cols_maj * m * interleave + dst_rows * interleave +
  42. cols_min).view(-1)
# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
def sparse_semi_structured_from_dense_cutlass(dense):
    """Compress a semi-structured sparse dense matrix for the CUTLASS backend.

    Non-float32 inputs are treated as 2:4 sparse: two values are kept out of
    every 4 consecutive elements of a row. torch.float inputs are treated
    as 1:2 sparse: one value is kept out of every 2.

    Args:
        dense: 2-D tensor of shape (m, k) with dtype int8, half, bfloat16,
            float or int32. NOTE(review): ``dense.view(...)`` below needs a
            view-compatible (e.g. contiguous) tensor.

    Returns:
        Tuple ``(sparse, meta_reordered)``.
        - ``sparse`` has shape (m, k // 2) and holds the kept values.
        - ``meta_reordered`` has shape
          (m, k // (ksparse * quadbits_per_meta_elem)). Its dtype is int32
          for int8 input and int16 otherwise. It holds the packed 2-bit
          indices of the kept values, reordered into the CUTLASS layout.

    Raises:
        RuntimeError: if ``dense`` is not 2-D, has an unsupported dtype, or
            its row/column counts fail the divisibility checks.
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = dense.shape
    device = dense.device

    # Metadata element type: int32 for int8 input, int16 for the other
    # supported dtypes; everything else is rejected.
    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    # Number of 4-bit fields packed into one metadata element
    # (4 for int16, 8 for int32).
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError(
            "Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16")
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32")
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
        )

    if dense.dtype != torch.float:
        # 2:4 case: view each row as groups of 4 and take per-position
        # nonzero flags.
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        # 1:2 case for float32: groups of 2. The chained assignment sets
        # m0 = m1 = flag of element 0 and m2 = m3 = flag of element 1, so
        # the 4-element encoding below can be reused unchanged.
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, the lower two bits of the encoding are the index of the True
    # value at the lowest index in the quadruple, and the higher two bits
    # are the index of the other True value in the quadruple.
    # If there are fewer than two True values, then False value(s) at some
    # index or indices are treated as True for the encoding. If there are
    # more than two True values, then the excess True value(s) at some
    # indices are treated as False for the encoding. The exact encodings
    # used for these cases are as follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These encodings were chosen with the help of the Espresso logic
    # minimizer, to minimize the Boolean functions that translate nonzero
    # flags into encoding bits. The possible choices for the first and last
    # of these encodings were limited to (0b0100, 0b1110), so that valid
    # encodings are produced for the 1:2 sparsity case.

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    # idxs0 / idxs1: 2-bit in-group indices of the first / second kept value.
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(
            -1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        # In the float case idxs0 is 0 or 2 (see the duplicated flags
        # above); // 2 maps it onto the 2-element group.
        sparse = dense_2.gather(-1,
                                idxs0.unsqueeze(-1) // 2).view(
                                    m,
                                    k // 2)  # type: ignore[possibly-undefined]

    # One 4-bit field per group: low 2 bits = idxs0, high 2 bits = idxs1.
    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view(
        (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    # Pack quadbits_per_meta_elem fields per element, lowest field first.
    # The shifts happen in meta_dtype, so the top field may set the sign
    # bit; only the bit pattern is used downstream.
    if quadbits_per_meta_elem == 4:
        meta = (meta_n[:, :, 0]
                | (meta_n[:, :, 1] << 4)
                | (meta_n[:, :, 2] << 8)
                | (meta_n[:, :, 3] << 12))
    elif quadbits_per_meta_elem == 8:
        meta = (meta_n[:, :, 0]
                | (meta_n[:, :, 1] << 4)
                | (meta_n[:, :, 2] << 8)
                | (meta_n[:, :, 3] << 12)
                | (meta_n[:, :, 4] << 16)
                | (meta_n[:, :, 5] << 20)
                | (meta_n[:, :, 6] << 24)
                | (meta_n[:, :, 7] << 28))

    # Reorder meta tensor elements.
    meta_reordered = meta.new_empty(
        (m * meta_ncols, ))  # type: ignore[possibly-undefined]
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device)
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))
  161. # This function performs reverse of the function above - it
  162. # reconstructs dense matrix from a pair of "compressed" matrix, given
  163. # in the layout used by CUTLASS backend, and accompanying metadata
  164. # matrix.
  165. def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
  166. if sparse.dim() != 2:
  167. raise RuntimeError(
  168. f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501
  169. )
  170. m, k = sparse.shape
  171. device = sparse.device
  172. if meta_reordered.dim() != 2:
  173. raise RuntimeError(
  174. f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501
  175. )
  176. if meta_reordered.device != device:
  177. raise RuntimeError(
  178. f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501
  179. )
  180. meta_dtype = meta_reordered.dtype
  181. if meta_dtype not in (torch.int16, torch.int32):
  182. raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
  183. quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
  184. ksparse = 4 if sparse.dtype != torch.float else 2
  185. meta_nrows, meta_ncols = meta_reordered.shape
  186. if meta_nrows != m:
  187. raise RuntimeError(
  188. f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501
  189. )
  190. if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
  191. raise RuntimeError(
  192. f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501
  193. "expected according to the number of columns of meta matrix")
  194. # Undo meta tensor elements reordering.
  195. meta_offsets = _calculate_meta_reordering_scatter_offsets(
  196. m, meta_ncols, meta_dtype, device)
  197. meta = torch.gather(meta_reordered.view(-1), 0,
  198. meta_offsets).view(m, meta_ncols)
  199. # Unpack sparse tensor back to original dense tensor, using
  200. # information provided by meta tensor. Note that torch.float
  201. # datatype is handled pretty much the same as
  202. # torch.half/torch.bfloat16, as metadata for a pair of torch.float
  203. # value is encoded as if underlying 8 bytes contain four
  204. # torch.half/torch.bfloat16 values, where either first two or last
  205. # two are zeros.
  206. meta_2 = torch.empty(
  207. (m, meta_ncols, 2 * quadbits_per_meta_elem),
  208. dtype=meta_dtype,
  209. device=device,
  210. )
  211. if quadbits_per_meta_elem == 4:
  212. meta_2[:, :, 0] = meta & 0b11
  213. meta_2[:, :, 1] = (meta >> 2) & 0b11
  214. meta_2[:, :, 2] = (meta >> 4) & 0b11
  215. meta_2[:, :, 3] = (meta >> 6) & 0b11
  216. meta_2[:, :, 4] = (meta >> 8) & 0b11
  217. meta_2[:, :, 5] = (meta >> 10) & 0b11
  218. meta_2[:, :, 6] = (meta >> 12) & 0b11
  219. meta_2[:, :, 7] = (meta >> 14) & 0b11
  220. elif quadbits_per_meta_elem == 8:
  221. meta_2[:, :, 0] = meta & 0b11
  222. meta_2[:, :, 1] = (meta >> 2) & 0b11
  223. meta_2[:, :, 2] = (meta >> 4) & 0b11
  224. meta_2[:, :, 3] = (meta >> 6) & 0b11
  225. meta_2[:, :, 4] = (meta >> 8) & 0b11
  226. meta_2[:, :, 5] = (meta >> 10) & 0b11
  227. meta_2[:, :, 6] = (meta >> 12) & 0b11
  228. meta_2[:, :, 7] = (meta >> 14) & 0b11
  229. meta_2[:, :, 8] = (meta >> 16) & 0b11
  230. meta_2[:, :, 9] = (meta >> 18) & 0b11
  231. meta_2[:, :, 10] = (meta >> 20) & 0b11
  232. meta_2[:, :, 11] = (meta >> 22) & 0b11
  233. meta_2[:, :, 12] = (meta >> 24) & 0b11
  234. meta_2[:, :, 13] = (meta >> 26) & 0b11
  235. meta_2[:, :, 14] = (meta >> 28) & 0b11
  236. meta_2[:, :, 15] = (meta >> 30) & 0b11
  237. dense_offsets = meta_2.view(-1) + (
  238. torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view(
  239. -1, 1).repeat(1, 2).view(-1)
  240. dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device)
  241. if sparse.dtype != torch.float:
  242. # dense.scatter_(0, dense_offsets, sparse.view(-1))
  243. dense.scatter_(0, dense_offsets, sparse.reshape(-1))
  244. else:
  245. dense.view(torch.half).scatter_(0, dense_offsets,
  246. sparse.view(torch.half).view(-1))
  247. return dense.view(m, 2 * k)
  248. def mask_creator(tensor):
  249. """
  250. Class for creating N:M sparsity masks.
  251. Masks will be created using the N:M ratio, where for every block of
  252. M weights, N will be pruned based on ranked weight value. Each mask
  253. will correspond to the given tensor.
  254. :param N: The number of weights in a group to keep
  255. :param M: The size of a weight group
  256. """
  257. N = 2
  258. M = 4
  259. mask = None
  260. # for i, tensor in enumerate(tensors):
  261. if tensor.numel() % M != 0:
  262. raise ValueError(
  263. f"Tensor of size {tensor.shape} can't be evenly divided into "
  264. f"{M} groups")
  265. num_groups = tensor.numel() // M
  266. # N:M sparsity for linear layers
  267. tensor_temp = tensor.detach().abs().reshape(num_groups, M)
  268. index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)]
  269. w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
  270. mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
  271. return mask
  272. def inject_24(w, size_k, size_n):
  273. assert w.shape == (size_k, size_n)
  274. mask = mask_creator(w.t()).t().cuda().bool()
  275. return (mask * w).contiguous(), mask.contiguous()
  276. def check_24(w, num_rows_to_sample=50, _verbose=False):
  277. BLOCK_SIZE = 4
  278. MAX_NON_ZEROS = 2
  279. w = w.t().contiguous()
  280. print("check_24: w.shape = {}".format(w.shape))
  281. num_rows, num_cols = w.shape
  282. sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
  283. if _verbose:
  284. print(f"Sampled row idxs = {sampled_row_idxs}")
  285. total_segments = 0
  286. non_24_segments = 0
  287. for i in sampled_row_idxs:
  288. for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
  289. total_segments += 1
  290. block = w[i, j:j + BLOCK_SIZE]
  291. num_nonzero = torch.count_nonzero(block)
  292. if num_nonzero > MAX_NON_ZEROS:
  293. print("i = {} j = {} block = {}".format(i, j, block))
  294. non_24_segments += 1
  295. print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
  296. def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
  297. assert q_24.shape == (size_k, size_n)
  298. # Remove zp to normalize over 0
  299. max_q_val = (1 << num_bits) - 1
  300. zp = (max_q_val + 1) // 2
  301. q_24_no_zp = q_24 - zp
  302. # Compress
  303. q_24_no_zp = q_24_no_zp.t().contiguous()
  304. q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(
  305. q_24_no_zp)
  306. q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
  307. # Restore zp
  308. q_24_comp = q_24_no_zp_comp + zp
  309. # Resize meta to its actual shape (without moving any data)
  310. meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
  311. return q_24_comp, meta
  312. def get_scale_perms_24():
  313. scale_perm: List[int] = []
  314. for i in range(8):
  315. scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
  316. scale_perm_single: List[int] = []
  317. for i in range(8):
  318. scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
  319. return scale_perm, scale_perm_single
  320. def get_weight_perm_24(num_bits: int):
  321. perm_list: List[int] = []
  322. for i in range(32):
  323. perm1: List[int] = []
  324. col = i // 4
  325. col_o = col // 2
  326. for block in [0, 1]:
  327. for row in [
  328. 2 * (i % 4),
  329. 2 * (i % 4) + 1,
  330. 2 * (i % 4 + 4),
  331. 2 * (i % 4 + 4) + 1,
  332. ]:
  333. perm1.append(16 * row + col_o * 256 + 8 * (col % 2) +
  334. 4 * block)
  335. for j in range(4):
  336. perm_list.extend([p + 1 * j for p in perm1])
  337. perm = numpy.array(perm_list)
  338. if num_bits == 4:
  339. interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
  340. elif num_bits == 8:
  341. interleave = numpy.array([0, 2, 1, 3])
  342. else:
  343. raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
  344. perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
  345. perm = torch.from_numpy(perm)
  346. return perm
  347. def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,
  348. group_size: int) -> torch.Tensor:
  349. scale_perm, scale_perm_single = get_scale_perms_24()
  350. if group_size < size_k and group_size != -1:
  351. s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
  352. else:
  353. s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
  354. s = s.reshape((-1, size_n)).contiguous()
  355. return s
  356. def marlin_24_quantize(
  357. w: torch.Tensor,
  358. num_bits: int,
  359. group_size: int,
  360. ):
  361. size_k, size_n = w.shape
  362. # Normalize group_size
  363. if group_size == -1:
  364. group_size = size_k
  365. assert group_size <= size_k
  366. # Inject 2:4 sparsity
  367. w_24, mask_24 = inject_24(w, size_k, size_n)
  368. # Quantize
  369. w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,
  370. num_bits,
  371. group_size,
  372. act_order=False)
  373. # Compress quantized weight
  374. q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
  375. num_bits)
  376. size_k_comp = size_k // 2
  377. # Reformat to marlin
  378. weight_perm = get_weight_perm_24(num_bits)
  379. marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
  380. num_bits, weight_perm)
  381. marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
  382. # Create result
  383. res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
  384. for i in range(len(res_list)):
  385. res_list[i] = res_list[i].to(w.device)
  386. return res_list