fp6_utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. # ruff: noqa
  2. # Copyright (c) Meta Platforms, Inc. and affiliates.
  3. # All rights reserved.
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. # This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3).
  7. # It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain:
  8. # 1. No encodings are reserved for special values (+/-inf, NaN).
  9. # 2. When downcasting from FP32 to FPx,
  10. # - Rounding mode is round to nearest, ties to even.
  11. # - Values outside the representable range of FPx after rounding are clamped to the maximum FPx
  12. # magnitude (sign is preserved).
  13. from functools import reduce
  14. from typing import Tuple
  15. import torch
  16. from torch import Tensor
  17. def _n_ones(n: int) -> int:
  18. return (1 << n) - 1
  19. EBITS_F32, MBITS_F32 = 8, 23
  20. F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
  21. # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py
  22. _SPLIT_K_MAP = [
  23. { # tokens: [1, 64]
  24. 3072: 18,
  25. 4096: 13,
  26. 5120: 10,
  27. 6144: 9,
  28. 8192: 6,
  29. 10240: 5,
  30. 14336: 7,
  31. 28672: 7,
  32. 57344: 7
  33. },
  34. { # tokens: [65:128]
  35. 3072: 9,
  36. 4096: 6,
  37. 5120: 5,
  38. 6144: 9,
  39. 8192: 3,
  40. 10240: 5,
  41. 14336: 7,
  42. 28672: 7,
  43. 57344: 6
  44. },
  45. { # tokens: [129:192]
  46. 3072: 6,
  47. 4096: 4,
  48. 5120: 7,
  49. 6144: 3,
  50. 8192: 2,
  51. 10240: 5,
  52. 14336: 5,
  53. 28672: 5,
  54. 57344: 4
  55. },
  56. { # tokens: [193:256]
  57. 3072: 9,
  58. 4096: 3,
  59. 5120: 5,
  60. 6144: 2,
  61. 8192: 5,
  62. 10240: 4,
  63. 14336: 8,
  64. 28672: 6,
  65. 57344: 4
  66. },
  67. { # tokens: [257:320]
  68. 3072: 7,
  69. 4096: 5,
  70. 5120: 2,
  71. 6144: 5,
  72. 8192: 4,
  73. 10240: 1,
  74. 14336: 3,
  75. 28672: 3,
  76. 57344: 4
  77. },
  78. { # tokens: [321:384]
  79. 3072: 3,
  80. 4096: 2,
  81. 5120: 5,
  82. 6144: 3,
  83. 8192: 1,
  84. 10240: 8,
  85. 14336: 3,
  86. 28672: 4,
  87. 57344: 3
  88. },
  89. { # tokens: [385:448]
  90. 3072: 5,
  91. 4096: 7,
  92. 5120: 3,
  93. 6144: 5,
  94. 8192: 7,
  95. 10240: 3,
  96. 14336: 1,
  97. 28672: 1,
  98. 57344: 3
  99. },
  100. { # tokens: [449:512]
  101. 3072: 2,
  102. 4096: 5,
  103. 5120: 4,
  104. 6144: 1,
  105. 8192: 5,
  106. 10240: 2,
  107. 14336: 6,
  108. 28672: 4,
  109. 57344: 1
  110. },
  111. { # tokens: [513:576]
  112. 3072: 2,
  113. 4096: 3,
  114. 5120: 1,
  115. 6144: 1,
  116. 8192: 3,
  117. 10240: 3,
  118. 14336: 3,
  119. 28672: 1,
  120. 57344: 1
  121. },
  122. { # tokens: [577:640]
  123. 3072: 5,
  124. 4096: 4,
  125. 5120: 1,
  126. 6144: 4,
  127. 8192: 2,
  128. 10240: 1,
  129. 14336: 1,
  130. 28672: 1,
  131. 57344: 1
  132. },
  133. { # tokens: [641:704]
  134. 3072: 3,
  135. 4096: 1,
  136. 5120: 2,
  137. 6144: 2,
  138. 8192: 1,
  139. 10240: 2,
  140. 14336: 1,
  141. 28672: 1,
  142. 57344: 1
  143. },
  144. { # tokens: [705:768]
  145. 3072: 3,
  146. 4096: 1,
  147. 5120: 3,
  148. 6144: 2,
  149. 8192: 1,
  150. 10240: 1,
  151. 14336: 1,
  152. 28672: 1,
  153. 57344: 1
  154. }
  155. ]
  156. def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
  157. """Convert FP32 numbers to sub-byte floating point numbers with the given
  158. number of exponent and mantissa bits.
  159. Input: torch.Tensor of dtype torch.float
  160. Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
  161. in the least significant bits. e.g.
  162. fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
  163. fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
  164. Note: there are no special values (NaN, inf) support in this code. Values
  165. outside the representable range of FPx after rounding are clamped to the
  166. maximum FPx magnitude (sign is preserved).
  167. Code below is an adaptation of https://fburl.com/code/ciwofcg4
  168. Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers # noqa: E501
  169. Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
  170. """
  171. assert x.dtype == torch.float
  172. assert 1 + ebits + mbits <= 8
  173. # calculate constants
  174. exp_bias = _n_ones(ebits - 1)
  175. max_int = _n_ones(ebits + mbits)
  176. sign_mask = 1 << (ebits + mbits)
  177. # TODO document this better
  178. magic_adder = _n_ones(MBITS_F32 - mbits - 1)
  179. # all E bits and M bits are 1s
  180. max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
  181. # E bits = 1, M bits = 0
  182. min_normal = 2 ** (1 - exp_bias)
  183. denorm_exp = (
  184. # exp bias conversion between formats
  185. (F32_EXP_BIAS - exp_bias)
  186. # mantissa length difference between formats
  187. + (MBITS_F32 - mbits)
  188. # add one to encoded exponent for denormalized numbers
  189. + 1
  190. )
  191. denorm_mask_int = denorm_exp << MBITS_F32
  192. # reinterpret int32 as float32
  193. denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)
  194. # save the sign
  195. # Note that we have torch.uint32, but some ops like cpu bit shifts
  196. # do not work on it. So, we stay in int32.
  197. x = x.view(torch.int32)
  198. sign = x & 0x80000000
  199. # set everything to positive, will add sign back at the end
  200. x = x ^ sign
  201. # TODO: can the branch floating point comparisons below be done without
  202. # converting to float? probably but need to verify
  203. x = x.view(torch.float)
  204. # rewrite saturate/denorm/norm branches without explicit data dependent
  205. # control flow, to be more compiler friendly
  206. saturate_mask = x >= max_normal
  207. denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
  208. normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
  209. #
  210. # branch 1: saturate to max val - handled later in the code which combines
  211. # the branches
  212. #
  213. #
  214. # branch 2: to conversion to denormal as well as rounding up to normal
  215. #
  216. denormal_x = x + denorm_mask_float
  217. denormal_x = denormal_x.view(torch.int32)
  218. denormal_x -= denorm_mask_int
  219. denormal_x = denormal_x.to(torch.uint8)
  220. #
  221. # branch 3: stay in normal range, adjust the exponent and round
  222. #
  223. normal_x = x.view(torch.int32)
  224. # resulting mantissa is odd
  225. mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
  226. # update exponent, rounding bias part 1
  227. val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
  228. normal_x += val_to_add
  229. # rounding bias part 2
  230. normal_x += mant_odd
  231. # take the bits!
  232. normal_x = normal_x >> (MBITS_F32 - mbits)
  233. normal_x = normal_x.to(torch.uint8)
  234. #
  235. # combine the branches
  236. #
  237. x = torch.full_like(x, max_int, dtype=torch.uint8)
  238. x = torch.where(denormal_mask, denormal_x, x)
  239. x = torch.where(normal_mask, normal_x, x)
  240. # add sign back
  241. sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
  242. sign_lp = sign_lp.to(torch.uint8)
  243. # Right shift of a negative signed integer can fill the least significant
  244. # bits with either 1s or 0s, depending on the implementation. Since PyTorch
  245. # doesn't have an uint32 dtype, we mask out these bits to get just the
  246. # f4 sign bit
  247. sign_lp = sign_lp & sign_mask
  248. x = x | sign_lp
  249. return x.to(torch.uint8)
  250. # TODO(future): check if LUT for everything is faster than bit shifting,
  251. # especially for fp4 (only 2^4=16 unique values).
  252. def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
  253. """Convert sub-byte floating point numbers with the given number of exponent
  254. and mantissa bits to FP32.
  255. Input: torch.Tensor of dtype uint8, where the bit encoding is stored
  256. in the least significant bits. e.g.
  257. fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
  258. fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
  259. Output: torch.Tensor of dtype fp32 with the dequantized value
  260. """
  261. assert x.dtype == torch.uint8
  262. assert 1 + ebits + mbits <= 8
  263. sign_mask = 1 << (ebits + mbits)
  264. exp_bias = _n_ones(ebits - 1)
  265. mantissa_mask = _n_ones(mbits)
  266. # save the sign
  267. sign_lp = x & sign_mask
  268. # set everything to positive, will add sign back at the end
  269. x_pos = x ^ sign_lp
  270. #
  271. # 1. Calculate zero mask
  272. #
  273. zero_mask = x_pos == 0
  274. #
  275. # 2. Calculate the denormal path mask
  276. #
  277. denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))
  278. #
  279. # 3. Calculate the normal path
  280. #
  281. # calculate the new exponent and shift it to bits 2:9 of the result
  282. exp_biased_lp = x_pos >> mbits
  283. exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
  284. exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32
  285. # shift the mantissa to bits 10:32 of the result
  286. mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
  287. mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
  288. result = exp_biased_f32 | mantissa_f32
  289. #
  290. # 4. Add the zero and denormal casts to the already casted normal path
  291. #
  292. result[zero_mask] = 0
  293. denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS
  294. # fast path.
  295. # without this, performance for FP4_E2M1 is slower by 2x
  296. if mbits == 1:
  297. result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32
  298. else:
  299. # iterate over all possible values of mantissa
  300. # i=0, j=1
  301. # i=1, j=10,11
  302. # i=2, j=100,101,110,111
  303. # and so on
  304. for i in range(mbits):
  305. for mantissa_cmp in range(1 << i, 1 << (i+1)):
  306. # left shift mantissa until it overflows (create an implicit 1)
  307. # subtract exponent by the same amount
  308. left_shift = mbits - i
  309. mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits)
  310. exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32
  311. # we can update this in-place since the values won't overlap
  312. # torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int'
  313. # thus we use + instead of | here
  314. mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 + mantissa_f32
  315. result = torch.where(denormal_mask, mantissa_lp_int32, result)
  316. # add sign back
  317. sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
  318. result = result | sign_f32
  319. return result.view(torch.float)
  320. def quant_llm_linear(
  321. EXPONENT: int,
  322. MANTISSA: int,
  323. _in_feats: Tensor,
  324. _weights: Tensor,
  325. _scales: Tensor,
  326. splitK: int = 1,
  327. ) -> Tensor:
  328. """
  329. Quant-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details.
  330. Arguments
  331. EXPONENT: number of exponent bits
  332. MANTISSA: number of mantissa bits
  333. _in_feats: input activations in FP16
  334. _weights: packed FPx weights
  335. _scales: scale
  336. splitK: split K
  337. Returns
  338. output of linear layer
  339. """
  340. return torch.ops.torchao.quant_llm_linear.default(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK)
  341. _ONES_TABLE = [_n_ones(i) for i in range(8)]
  342. def _pack(x: Tensor, n_bits: int) -> Tensor:
  343. return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)])
  344. def _unpack(x: Tensor, n_bits: int) -> Tensor:
  345. return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2)
  346. # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116
  347. def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:
  348. # the original code unpacks/packs the values from/to uint32 while we unpack/pack the values from/to uint8
  349. # thus, we need to reverse byte order within a uint32 word.
  350. x = x.reshape(-1, 4).flip(1)
  351. x = _unpack(x, n_bits)
  352. x = x.view(-1, 4 * (8 // n_bits))
  353. if not undo:
  354. bit_order = {
  355. 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31,
  356. 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30],
  357. 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14],
  358. 4: [1, 5, 3, 7, 0, 4, 2, 6],
  359. }[n_bits]
  360. else:
  361. # this is inverse of the above, obtained by running
  362. # [v.index(i) for i in range(len(v))]
  363. bit_order = {
  364. 1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11,
  365. 20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15],
  366. 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7],
  367. 4: [4, 0, 6, 2, 5, 1, 7, 3],
  368. }[n_bits]
  369. x = x[:, bit_order]
  370. x = _pack(x, n_bits)
  371. # reverse byte order within a uint32 word again.
  372. x = x.reshape(-1, 4).flip(1)
  373. return x.flatten()
  374. # this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing
  375. # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h
  376. def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
  377. assert tensor.ndim == 2, tensor.dtype == torch.uint8
  378. M, N = tensor.shape
  379. assert (M % 64 == 0) and (N % 64 == 0)
  380. # Pass 1 from original code
  381. tensor = tensor.view(M // 64, 4, 2, 8, N // 16, 2, 8)
  382. tensor = tensor.permute(0, 4, 1, 5, 2, 3, 6)
  383. tensor = tensor.reshape(-1, 32, 2)
  384. tensor = tensor.permute(1, 0, 2)
  385. tensor = tensor.flatten()
  386. used_bits = 0
  387. fragments = []
  388. for y in [1, 2, 4]:
  389. if nbits & y:
  390. mask = (1 << y) - 1
  391. tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask
  392. tensor_ybit = _pack(tensor_ybit, y)
  393. tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2)
  394. tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y)
  395. fragments.append(tensor_ybit)
  396. used_bits += y
  397. return torch.cat(fragments, dim=0).view(M, -1)
  398. # more optimized version of _pack_tc_fpx() for FP6 by merging ops
  399. def _pack_tc_fp6(tensor: Tensor) -> Tensor:
  400. assert tensor.ndim == 2, tensor.dtype == torch.uint8
  401. M, N = tensor.shape
  402. assert (M % 64 == 0) and (N % 64 == 0)
  403. tensor = tensor.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8)
  404. tensor = tensor.flip(3)
  405. tensor_2bit = (tensor >> 4) & 0b11
  406. tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6)
  407. tensor_2bit = _pack(tensor_2bit.flatten(), 2)
  408. tensor_4bit = tensor & 0b1111
  409. tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6)
  410. tensor_4bit = _pack(tensor_4bit.flatten(), 4)
  411. return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1)
  412. # currently only optimize for TC-FP6 packing
  413. def pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
  414. if nbits == 6:
  415. return _pack_tc_fp6(tensor)
  416. return _pack_tc_fpx(tensor, nbits)
  417. def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]:
  418. # _n_ones() is not compatible with torch.compile() due to << operator
  419. # https://github.com/pytorch/pytorch/issues/119152
  420. # exp_bias = _n_ones(ebits - 1)
  421. # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
  422. # workaround: global lookup table
  423. exp_bias = _ONES_TABLE[ebits - 1]
  424. max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))
  425. tensor = tensor.float()
  426. scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
  427. tensor_fpx = _f32_to_fpx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
  428. tensor_tc_fpx = pack_tc_fpx(tensor_fpx, 1 + ebits + mbits)
  429. return tensor_tc_fpx, scale.half()
  430. # inverse of _pack_tc_fpx()
  431. def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
  432. assert tensor.ndim == 2 and tensor.dtype == torch.uint8
  433. M = tensor.shape[0]
  434. size = tensor.numel()
  435. tensor = tensor.flatten()
  436. offset = 0
  437. used_bits = 0
  438. tensor_fpx = None
  439. for y in [1, 2, 4]:
  440. if nbits & y:
  441. size_ybit = size // nbits * y
  442. tensor_ybit = tensor[offset : offset + size_ybit]
  443. offset += size_ybit
  444. tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3
  445. tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2
  446. tensor_ybit = _unpack(tensor_ybit.flatten(), y)
  447. tensor_ybit = tensor_ybit << (nbits - used_bits - y)
  448. used_bits += y
  449. if tensor_fpx is None:
  450. tensor_fpx = tensor_ybit
  451. else:
  452. tensor_fpx |= tensor_ybit
  453. # undo Pass 1
  454. tensor_fpx = tensor_fpx.view(32, -1, 2).permute(1, 0, 2)
  455. tensor_fpx = tensor_fpx.reshape(M // 64, -1, 4, 2, 2, 8, 8)
  456. tensor_fpx = tensor_fpx.permute(0, 2, 4, 5, 1, 3, 6)
  457. tensor_fpx = tensor_fpx.reshape(M, -1)
  458. return tensor_fpx
# more optimized version of _unpack_tc_fpx() for FP6 by merging ops
# inverse of _pack_tc_fp6()
  461. def _unpack_tc_fp6(tensor: Tensor) -> Tensor:
  462. assert tensor.ndim == 2 and tensor.dtype == torch.uint8
  463. M = tensor.shape[0]
  464. N = tensor.shape[1] // 3 * 4
  465. assert (M % 64 == 0) and (N % 64 == 0)
  466. size_2bit = M * N // 4
  467. size_4bit = M * N // 2
  468. tensor = tensor.view(-1)
  469. assert tensor.numel() == size_2bit + size_4bit
  470. tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit])
  471. tensor_2bit = _unpack(tensor_2bit, 2)
  472. tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2)
  473. tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4)
  474. tensor_4bit = _unpack(tensor_4bit, 4)
  475. tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2)
  476. tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5)
  477. tensor_fp6 = (tensor_2bit << 4) | tensor_4bit
  478. tensor_fp6 = tensor_fp6.flip(3).reshape(M, N)
  479. return tensor_fp6
  480. def unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
  481. if nbits == 6:
  482. return _unpack_tc_fp6(tensor)
  483. return _unpack_tc_fpx(tensor, nbits)
  484. def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Tensor:
  485. fpx_unpacked = unpack_tc_fpx(tensor, 1 + ebits + mbits)
  486. tensor = _fpx_unpacked_to_f32(fpx_unpacked, ebits, mbits)
  487. if scale is not None:
  488. tensor = tensor * scale.float().view(-1, 1)
  489. return tensor