- # ruff: noqa
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- # This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3).
- # It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain:
- # 1. No encodings are reserved for special values (+/-inf, NaN).
- # 2. When downcasting from FP32 to FPx,
- # - Rounding mode is round to nearest, ties to even.
- # - Values outside the representable range of FPx after rounding are clamped to the maximum FPx
- # magnitude (sign is preserved).
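- # As a concrete illustration of the clamping rule above: FP6 E3M2 (ebits=3, mbits=2)
- # has exponent bias 3 and a largest finite magnitude of 2 ** (7 - 3) * 1.75 = 28.0,
- # so FP32 inputs such as 100.0 and -100.0 map to +28.0 and -28.0 respectively.
- # Likewise, FP6 E2M3 saturates at 7.5 and FP4 E2M1 at 6.0.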
- from functools import reduce
- from typing import Tuple
- import torch
- from torch import Tensor
- def _n_ones(n: int) -> int:
- return (1 << n) - 1
- EBITS_F32, MBITS_F32 = 8, 23
- F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
- # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py
- _SPLIT_K_MAP = [
- { # tokens: [1, 64]
- 3072: 18,
- 4096: 13,
- 5120: 10,
- 6144: 9,
- 8192: 6,
- 10240: 5,
- 14336: 7,
- 28672: 7,
- 57344: 7
- },
- { # tokens: [65:128]
- 3072: 9,
- 4096: 6,
- 5120: 5,
- 6144: 9,
- 8192: 3,
- 10240: 5,
- 14336: 7,
- 28672: 7,
- 57344: 6
- },
- { # tokens: [129:192]
- 3072: 6,
- 4096: 4,
- 5120: 7,
- 6144: 3,
- 8192: 2,
- 10240: 5,
- 14336: 5,
- 28672: 5,
- 57344: 4
- },
- { # tokens: [193:256]
- 3072: 9,
- 4096: 3,
- 5120: 5,
- 6144: 2,
- 8192: 5,
- 10240: 4,
- 14336: 8,
- 28672: 6,
- 57344: 4
- },
- { # tokens: [257:320]
- 3072: 7,
- 4096: 5,
- 5120: 2,
- 6144: 5,
- 8192: 4,
- 10240: 1,
- 14336: 3,
- 28672: 3,
- 57344: 4
- },
- { # tokens: [321:384]
- 3072: 3,
- 4096: 2,
- 5120: 5,
- 6144: 3,
- 8192: 1,
- 10240: 8,
- 14336: 3,
- 28672: 4,
- 57344: 3
- },
- { # tokens: [385:448]
- 3072: 5,
- 4096: 7,
- 5120: 3,
- 6144: 5,
- 8192: 7,
- 10240: 3,
- 14336: 1,
- 28672: 1,
- 57344: 3
- },
- { # tokens: [449:512]
- 3072: 2,
- 4096: 5,
- 5120: 4,
- 6144: 1,
- 8192: 5,
- 10240: 2,
- 14336: 6,
- 28672: 4,
- 57344: 1
- },
- { # tokens: [513:576]
- 3072: 2,
- 4096: 3,
- 5120: 1,
- 6144: 1,
- 8192: 3,
- 10240: 3,
- 14336: 3,
- 28672: 1,
- 57344: 1
- },
- { # tokens: [577:640]
- 3072: 5,
- 4096: 4,
- 5120: 1,
- 6144: 4,
- 8192: 2,
- 10240: 1,
- 14336: 1,
- 28672: 1,
- 57344: 1
- },
- { # tokens: [641:704]
- 3072: 3,
- 4096: 1,
- 5120: 2,
- 6144: 2,
- 8192: 1,
- 10240: 2,
- 14336: 1,
- 28672: 1,
- 57344: 1
- },
- { # tokens: [705:768]
- 3072: 3,
- 4096: 1,
- 5120: 3,
- 6144: 2,
- 8192: 1,
- 10240: 1,
- 14336: 1,
- 28672: 1,
- 57344: 1
- }
- ]
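- # The outer list above is indexed by token-count bucket (64 tokens per bucket, covering
- # 1-768 tokens); the inner dicts appear to be keyed by the weight's output dimension,
- # following the DeepSpeed code linked above. A minimal lookup sketch (the helper name
- # below is illustrative and not part of this module):
- def _lookup_split_k(tokens: int, out_features: int) -> int:
-     if tokens > 768:
-         # the table only covers up to 768 tokens; fall back to no split-K
-         return 1
-     return _SPLIT_K_MAP[(tokens - 1) // 64].get(out_features, 1)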
- def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
- """Convert FP32 numbers to sub-byte floating point numbers with the given
- number of exponent and mantissa bits.
- Input: torch.Tensor of dtype torch.float
- Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
- in the least significant bits. e.g.
- fp4: the 4 most significant bits are empty and the 4 least significant bits hold the fp4_e2m1 encoding
- fp6: the 2 most significant bits are empty and the 6 least significant bits hold the fp6_e2m3 or fp6_e3m2 encoding
- Note: special values (NaN, inf) are not supported in this code. Values
- outside the representable range of FPx after rounding are clamped to the
- maximum FPx magnitude (sign is preserved).
- Code below is an adaptation of https://fburl.com/code/ciwofcg4
- Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers # noqa: E501
- Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
- """
- assert x.dtype == torch.float
- assert 1 + ebits + mbits <= 8
- # calculate constants
- exp_bias = _n_ones(ebits - 1)
- max_int = _n_ones(ebits + mbits)
- sign_mask = 1 << (ebits + mbits)
- # magic_adder has 1s in every mantissa bit that will be truncated except the topmost
- # one; adding it (together with the "mantissa is odd" bit further below) implements
- # round-to-nearest, ties-to-even on the truncated mantissa
- magic_adder = _n_ones(MBITS_F32 - mbits - 1)
- # all E bits and M bits are 1s
- max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
- # E bits = 1, M bits = 0
- min_normal = 2 ** (1 - exp_bias)
- denorm_exp = (
- # exp bias conversion between formats
- (F32_EXP_BIAS - exp_bias)
- # mantissa length difference between formats
- + (MBITS_F32 - mbits)
- # add one to encoded exponent for denormalized numbers
- + 1
- )
- denorm_mask_int = denorm_exp << MBITS_F32
- # reinterpret int32 as float32
- denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)
- # save the sign
- # Note that we have torch.uint32, but some ops like cpu bit shifts
- # do not work on it. So, we stay in int32.
- x = x.view(torch.int32)
- sign = x & 0x80000000
- # set everything to positive, will add sign back at the end
- x = x ^ sign
- # TODO: can the branch floating point comparisons below be done without
- # converting to float? probably but need to verify
- x = x.view(torch.float)
- # rewrite saturate/denorm/norm branches without explicit data dependent
- # control flow, to be more compiler friendly
- saturate_mask = x >= max_normal
- denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
- normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
- #
- # branch 1: saturate to max val - handled later in the code which combines
- # the branches
- #
- #
- # branch 2: handle conversion to a denormal value, as well as rounding up to the smallest normal value
- #
- denormal_x = x + denorm_mask_float
- denormal_x = denormal_x.view(torch.int32)
- denormal_x -= denorm_mask_int
- denormal_x = denormal_x.to(torch.uint8)
- #
- # branch 3: stay in normal range, adjust the exponent and round
- #
- normal_x = x.view(torch.int32)
- # resulting mantissa is odd
- mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
- # update exponent, rounding bias part 1
- val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
- normal_x += val_to_add
- # rounding bias part 2
- normal_x += mant_odd
- # take the bits!
- normal_x = normal_x >> (MBITS_F32 - mbits)
- normal_x = normal_x.to(torch.uint8)
- #
- # combine the branches
- #
- x = torch.full_like(x, max_int, dtype=torch.uint8)
- x = torch.where(denormal_mask, denormal_x, x)
- x = torch.where(normal_mask, normal_x, x)
- # add sign back
- sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
- sign_lp = sign_lp.to(torch.uint8)
- # Right shift of a negative signed integer can fill the vacated high-order
- # bits with either 1s or 0s, depending on the implementation. Since we stay
- # in int32 rather than uint32 (see the note above where the sign is saved),
- # we mask out these bits to keep just the fpx sign bit
- sign_lp = sign_lp & sign_mask
- x = x | sign_lp
- return x.to(torch.uint8)
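- # Illustrative sketch (not part of the module API): encoding a handful of FP32 values as
- # unpacked FP4 E2M1 (ebits=2, mbits=1). The expected codes follow from the E2M1 value
- # table: codes 0-7 represent 0, 0.5, 1, 1.5, 2, 3, 4, 6, and bit 3 is the sign.
- def _example_fp4_e2m1_encode() -> None:
-     x = torch.tensor([0.0, 0.5, 1.5, 6.0, 100.0, -3.0], dtype=torch.float)
-     codes = _f32_to_fpx_unpacked(x, ebits=2, mbits=1)
-     # 100.0 is out of range and saturates to the maximum magnitude (code 0b0111 = 6.0),
-     # while -3.0 keeps its sign bit (code 0b1101)
-     assert codes.tolist() == [0b0000, 0b0001, 0b0011, 0b0111, 0b0111, 0b1101]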
- # TODO(future): check if LUT for everything is faster than bit shifting,
- # especially for fp4 (only 2^4=16 unique values).
- def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
- """Convert sub-byte floating point numbers with the given number of exponent
- and mantissa bits to FP32.
- Input: torch.Tensor of dtype uint8, where the bit encoding is stored
- in the least significant bits. e.g.
- fp4: the 4 most significant bits are empty and the 4 least significant bits hold the fp4_e2m1 encoding
- fp6: the 2 most significant bits are empty and the 6 least significant bits hold the fp6_e2m3 or fp6_e3m2 encoding
- Output: torch.Tensor of dtype fp32 with the dequantized value
- """
- assert x.dtype == torch.uint8
- assert 1 + ebits + mbits <= 8
- sign_mask = 1 << (ebits + mbits)
- exp_bias = _n_ones(ebits - 1)
- mantissa_mask = _n_ones(mbits)
- # save the sign
- sign_lp = x & sign_mask
- # set everything to positive, will add sign back at the end
- x_pos = x ^ sign_lp
- #
- # 1. Calculate zero mask
- #
- zero_mask = x_pos == 0
- #
- # 2. Calculate the denormal path mask
- #
- denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))
- #
- # 3. Calculate the normal path
- #
- # calculate the new exponent and shift it to bits 2:9 of the result
- exp_biased_lp = x_pos >> mbits
- exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
- exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32
- # shift the mantissa to bits 10:32 of the result
- mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
- mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
- result = exp_biased_f32 | mantissa_f32
- #
- # 4. Add the zero and denormal casts to the already casted normal path
- #
- result[zero_mask] = 0
- denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS
- # fast path.
- # without this, conversion for FP4_E2M1 is ~2x slower
- if mbits == 1:
- result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32
- else:
- # iterate over all possible values of mantissa
- # i=0, j=1
- # i=1, j=10,11
- # i=2, j=100,101,110,111
- # and so on
- for i in range(mbits):
- for mantissa_cmp in range(1 << i, 1 << (i+1)):
- # left shift mantissa until it overflows (create an implicit 1)
- # subtract exponent by the same amount
- left_shift = mbits - i
- mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits)
- exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32
- # we can update this in-place since the values won't overlap
- # torch.compile() may complain about unsupported operand type(s) for |: 'SymInt' and 'int';
- # thus we use + instead of | here
- mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 + mantissa_f32
- result = torch.where(denormal_mask, mantissa_lp_int32, result)
- # add sign back
- sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
- result = result | sign_f32
- return result.view(torch.float)
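- # Illustrative sketch (not part of the module API): decoding all 16 unpacked FP4 E2M1
- # codes back to FP32 and checking that re-encoding the decoded values recovers the
- # original codes, since every decoded value is exactly representable in the format.
- def _example_fp4_e2m1_decode() -> None:
-     codes = torch.arange(16, dtype=torch.uint8)  # all possible E2M1 encodings
-     values = _fpx_unpacked_to_f32(codes, ebits=2, mbits=1)
-     # codes 0-7 are the non-negative values, codes 8-15 their negated counterparts
-     assert values[:8].tolist() == [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
-     assert torch.equal(_f32_to_fpx_unpacked(values, ebits=2, mbits=1), codes)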
- def quant_llm_linear(
- EXPONENT: int,
- MANTISSA: int,
- _in_feats: Tensor,
- _weights: Tensor,
- _scales: Tensor,
- splitK: int = 1,
- ) -> Tensor:
- """
- Quant-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details.
- Arguments
- EXPONENT: number of exponent bits
- MANTISSA: number of mantissa bits
- _in_feats: input activations in FP16
- _weights: packed FPx weights
- _scales: per-output-channel scale factors (as produced by to_scaled_tc_fpx() below)
- splitK: split-K factor for the underlying GEMM kernel (see _SPLIT_K_MAP above)
- Returns
- output of linear layer
- """
- return torch.ops.torchao.quant_llm_linear.default(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK)
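- # A minimal usage sketch, assuming a CUDA build of torchao that registers the
- # torchao.quant_llm_linear custom op, FP16 activations, and an FP6 E3M2 weight prepared
- # with to_scaled_tc_fpx() (defined later in this file). Both weight dimensions must be
- # multiples of 64 for the tensor-core packing. The function name is illustrative only.
- def _example_quant_llm_linear() -> Tensor:
-     fp32_weight = torch.randn(8192, 4096, device="cuda")  # (out_features, in_features)
-     fp16_act = torch.randn(4, 4096, dtype=torch.half, device="cuda")
-     fp6_weight, scale = to_scaled_tc_fpx(fp32_weight, ebits=3, mbits=2)
-     return quant_llm_linear(3, 2, fp16_act, fp6_weight, scale, splitK=1)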
- _ONES_TABLE = [_n_ones(i) for i in range(8)]
- def _pack(x: Tensor, n_bits: int) -> Tensor:
- return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)])
- def _unpack(x: Tensor, n_bits: int) -> Tensor:
- return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2)
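- # A small worked example of the two helpers above: for n_bits=4, _pack() folds pairs of
- # values into the high and low nibbles of one byte and _unpack() reverses it, e.g.
- # [1, 2, 3, 4] packs to the bytes [0x12, 0x34], and unpacking [0x12, 0x34] with
- # n_bits=4 gives back [1, 2, 3, 4].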
- # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116
- def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:
- # the original code unpacks/packs the values from/to uint32, while we unpack/pack the values from/to uint8;
- # thus, we need to reverse byte order within a uint32 word.
- x = x.reshape(-1, 4).flip(1)
- x = _unpack(x, n_bits)
- x = x.view(-1, 4 * (8 // n_bits))
- if not undo:
- bit_order = {
- 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31,
- 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30],
- 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14],
- 4: [1, 5, 3, 7, 0, 4, 2, 6],
- }[n_bits]
- else:
- # this is inverse of the above, obtained by running
- # [v.index(i) for i in range(len(v))]
- bit_order = {
- 1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11,
- 20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15],
- 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7],
- 4: [4, 0, 6, 2, 5, 1, 7, 3],
- }[n_bits]
- x = x[:, bit_order]
- x = _pack(x, n_bits)
- # reverse byte order within a uint32 word again.
- x = x.reshape(-1, 4).flip(1)
- return x.flatten()
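- # Since the undo tables invert the forward permutation, interleaving followed by
- # de-interleaving is the identity, e.g. (sketch):
- #   x = torch.randint(256, (64,), dtype=torch.uint8)
- #   assert torch.equal(_bit_interleave(_bit_interleave(x, 4), 4, undo=True), x)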
- # this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing
- # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h
- def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
- assert tensor.ndim == 2 and tensor.dtype == torch.uint8
- M, N = tensor.shape
- assert (M % 64 == 0) and (N % 64 == 0)
- # Pass 1 from original code
- tensor = tensor.view(M // 64, 4, 2, 8, N // 16, 2, 8)
- tensor = tensor.permute(0, 4, 1, 5, 2, 3, 6)
- tensor = tensor.reshape(-1, 32, 2)
- tensor = tensor.permute(1, 0, 2)
- tensor = tensor.flatten()
- used_bits = 0
- fragments = []
- for y in [1, 2, 4]:
- if nbits & y:
- mask = (1 << y) - 1
- tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask
- tensor_ybit = _pack(tensor_ybit, y)
- tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2)
- tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y)
- fragments.append(tensor_ybit)
- used_bits += y
- return torch.cat(fragments, dim=0).view(M, -1)
- # more optimized version of _pack_tc_fpx() for FP6 by merging ops
- def _pack_tc_fp6(tensor: Tensor) -> Tensor:
- assert tensor.ndim == 2 and tensor.dtype == torch.uint8
- M, N = tensor.shape
- assert (M % 64 == 0) and (N % 64 == 0)
- tensor = tensor.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8)
- tensor = tensor.flip(3)
- tensor_2bit = (tensor >> 4) & 0b11
- tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6)
- tensor_2bit = _pack(tensor_2bit.flatten(), 2)
- tensor_4bit = tensor & 0b1111
- tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6)
- tensor_4bit = _pack(tensor_4bit.flatten(), 4)
- return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1)
- # currently only the TC-FP6 packing path is optimized
- def pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
- if nbits == 6:
- return _pack_tc_fp6(tensor)
- return _pack_tc_fpx(tensor, nbits)
- def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]:
- # _n_ones() is not compatible with torch.compile() due to << operator
- # https://github.com/pytorch/pytorch/issues/119152
- # exp_bias = _n_ones(ebits - 1)
- # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
- # workaround: global lookup table
- exp_bias = _ONES_TABLE[ebits - 1]
- max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))
- tensor = tensor.float()
- scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
- tensor_fpx = _f32_to_fpx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
- tensor_tc_fpx = pack_tc_fpx(tensor_fpx, 1 + ebits + mbits)
- return tensor_tc_fpx, scale.half()
- # inverse of _pack_tc_fpx()
- def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
- assert tensor.ndim == 2 and tensor.dtype == torch.uint8
- M = tensor.shape[0]
- size = tensor.numel()
- tensor = tensor.flatten()
- offset = 0
- used_bits = 0
- tensor_fpx = None
- for y in [1, 2, 4]:
- if nbits & y:
- size_ybit = size // nbits * y
- tensor_ybit = tensor[offset : offset + size_ybit]
- offset += size_ybit
- tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3
- tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2
- tensor_ybit = _unpack(tensor_ybit.flatten(), y)
- tensor_ybit = tensor_ybit << (nbits - used_bits - y)
- used_bits += y
- if tensor_fpx is None:
- tensor_fpx = tensor_ybit
- else:
- tensor_fpx |= tensor_ybit
- # undo Pass 1
- tensor_fpx = tensor_fpx.view(32, -1, 2).permute(1, 0, 2)
- tensor_fpx = tensor_fpx.reshape(M // 64, -1, 4, 2, 2, 8, 8)
- tensor_fpx = tensor_fpx.permute(0, 2, 4, 5, 1, 3, 6)
- tensor_fpx = tensor_fpx.reshape(M, -1)
- return tensor_fpx
- # more optimized version of _unpack_tc_fpx() for FP6 by merging ops
- # inverse of _pack_tc_fp6()
- def _unpack_tc_fp6(tensor: Tensor) -> Tensor:
- assert tensor.ndim == 2 and tensor.dtype == torch.uint8
- M = tensor.shape[0]
- N = tensor.shape[1] // 3 * 4
- assert (M % 64 == 0) and (N % 64 == 0)
- size_2bit = M * N // 4
- size_4bit = M * N // 2
- tensor = tensor.view(-1)
- assert tensor.numel() == size_2bit + size_4bit
- tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit])
- tensor_2bit = _unpack(tensor_2bit, 2)
- tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2)
- tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4)
- tensor_4bit = _unpack(tensor_4bit, 4)
- tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2)
- tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5)
- tensor_fp6 = (tensor_2bit << 4) | tensor_4bit
- tensor_fp6 = tensor_fp6.flip(3).reshape(M, N)
- return tensor_fp6
- def unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
- if nbits == 6:
- return _unpack_tc_fp6(tensor)
- return _unpack_tc_fpx(tensor, nbits)
- def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Tensor:
- fpx_unpacked = unpack_tc_fpx(tensor, 1 + ebits + mbits)
- tensor = _fpx_unpacked_to_f32(fpx_unpacked, ebits, mbits)
- if scale is not None:
- tensor = tensor * scale.float().view(-1, 1)
- return tensor
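- # Illustrative round trip (not part of the module API): quantize a weight to the scaled
- # tensor-core FP6 E3M2 layout and dequantize it back. FP6 E3M2 carries only 2 mantissa
- # bits, so the reconstruction is approximate (roughly 12.5% worst-case relative error
- # for values in the normal range).
- def _example_scaled_tc_fp6_roundtrip() -> None:
-     weight = torch.randn(64, 64)  # both dimensions must be multiples of 64
-     packed, scale = to_scaled_tc_fpx(weight, ebits=3, mbits=2)
-     assert packed.dtype == torch.uint8 and scale.dtype == torch.half
-     recovered = from_scaled_tc_fpx(packed, ebits=3, mbits=2, scale=scale)
-     assert recovered.shape == weight.shape
-     assert torch.allclose(recovered, weight, rtol=0.2, atol=0.05)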