torchmx.mx_quantization_utils

unpack_bfloat16

def unpack_bfloat16(
    x: torch.Tensor,
    dtype: torch.dtype = torch.uint8
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Extract the sign, exponent, and mantissa from a bfloat16 tensor.

Arguments:

  • x - torch.Tensor, the input bfloat16 tensor
  • dtype - torch.dtype, the dtype to cast the output tensors to. Default is torch.uint8

Returns:

  • sign - torch.Tensor, the sign of the tensor, cast to dtype (default torch.uint8)
  • exponent - torch.Tensor, the exponent of the tensor, cast to dtype (default torch.uint8)
  • mantissa - torch.Tensor, the mantissa of the tensor, cast to dtype (default torch.uint8)
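
For reference, these fields can be recovered with plain tensor ops. A minimal sketch (not the library implementation), assuming the standard bfloat16 layout of 1 sign, 8 exponent, and 7 mantissa bits:

import torch

x = torch.tensor([1.5, -0.25], dtype=torch.bfloat16)
bits = x.view(torch.int16).to(torch.int32) & 0xFFFF  # raw 16-bit pattern
sign = ((bits >> 15) & 0x1).to(torch.uint8)          # 1 sign bit
exponent = ((bits >> 7) & 0xFF).to(torch.uint8)      # 8 exponent bits, biased by 127
mantissa = (bits & 0x7F).to(torch.uint8)             # 7 mantissa bits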

unpack_fp32

def unpack_fp32(
        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Unpacks the given FP32 tensor to its components.

Arguments:

  • x torch.Tensor - The packed FP32 tensor.

Returns:

  • sign torch.Tensor - The sign bit tensor.
  • exponent torch.Tensor - The exponent tensor.
  • mantissa torch.Tensor - The mantissa tensor.

unpack_fp64

def unpack_fp64(
        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Unpacks the given FP64 tensor to its components.

Arguments:

  • x torch.Tensor - The packed FP64 tensor.

Returns:

  • sign torch.Tensor - The sign bit tensor.
  • exponent torch.Tensor - The exponent tensor.
  • mantissa torch.Tensor - The mantissa tensor.

dequantize_to_dtype

def dequantize_to_dtype(data_lp: torch.Tensor,
                        elem_dtype: dtypes.DType,
                        target_dtype: torch.dtype,
                        packing_dim: int = -1,
                        is_packed_fp4: bool = True) -> torch.Tensor

Dequantizes an elem_dtype tensor packed as uint8 to the target_dtype by using an intermediate bfloat16 representation.

Arguments:

  • data_lp torch.Tensor - The input tensor in low-precision format (must be of dtype torch.uint8).
  • elem_dtype dtypes.DType - The input element data type
  • target_dtype torch.dtype - The target data type to which the tensor will be dequantized.
  • packing_dim int - The dimension along which the uint4 data is packed. Defaults to -1.
  • is_packed_fp4 bool, optional - A flag indicating whether the input tensor is packed in FP4 format. Defaults to True.

Returns:

  • torch.Tensor - The dequantized tensor in the specified target data type.

Raises:

  • AssertionError - If the element data type is not supported or if the input tensor is not of dtype torch.uint8.
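
To illustrate the packing convention behind is_packed_fp4 and packing_dim, the sketch below unpacks two 4-bit codes per uint8 byte along the last dimension. The nibble order (low nibble first) is an assumption made for illustration, not taken from the library:

import torch

packed = torch.tensor([[0x21, 0x43]], dtype=torch.uint8)  # 4 FP4 codes stored in 2 bytes
low = packed & 0x0F                                       # first code in each byte (assumed order)
high = (packed >> 4) & 0x0F                               # second code in each byte
unpacked = torch.stack([low, high], dim=-1).flatten(-2)   # tensor([[1, 2, 3, 4]], dtype=torch.uint8)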

round_to_even

def round_to_even(mantissa: torch.Tensor,
                  mantissa_shift: torch.Tensor | int) -> torch.Tensor

Round a mantissa to the nearest even value using a tensor of shift values.

Arguments:

  • mantissa torch.Tensor - A tensor containing the mantissa values to be rounded.
  • mantissa_shift torch.Tensor | int - A tensor containing the shift values to apply to each element of the mantissa tensor; its size should match the size of the mantissa tensor. Alternatively, a single integer can be provided, in which case the same shift value is applied to all elements.

Returns:

  • torch.Tensor - A tensor containing the mantissa values rounded to the nearest even value. The size of the output tensor matches the input mantissa tensor.

Notes


  • The rounding follows the "round half to even" rule, where if the value to be discarded is exactly halfway between two integers, the result is rounded to the nearest even number.
  • This function supports element-wise operations, where the shifting is applied to each element of the mantissa according to the corresponding value in mantissa_shift.
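
A minimal sketch of the round-half-to-even rule applied to a right shift, under assumed semantics (the library's exact convention for mantissa_shift may differ, and every shift value is assumed to be at least 1):

import torch

def round_half_to_even_shift(mantissa: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    truncated = mantissa >> shift                                  # drop the low `shift` bits
    dropped = mantissa & (torch.bitwise_left_shift(1, shift) - 1)  # the discarded low bits
    half = torch.bitwise_left_shift(1, shift - 1)                  # exactly 0.5 ULP of the result
    round_up = (dropped > half) | ((dropped == half) & ((truncated & 1) == 1))
    return truncated + round_up.to(truncated.dtype)

# Ties go to the even result:
# round_half_to_even_shift(torch.tensor([0b1010, 0b1110]), torch.tensor([2, 2])) -> tensor([2, 4])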

Examples


>>> mantissa = torch.tensor([0b1010011, 0b1101101], dtype=torch.int32)
>>> mantissa_shift = torch.tensor([2, 3], dtype=torch.int32)
>>> round_to_even(mantissa, mantissa_shift)
tensor([41, 27])

n_ones

def n_ones(n: int) -> int

Returns a number with n ones in its binary representation. For example: n_ones(3) = 0b111 = 7.
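
An equivalent one-line expression (assumed, not necessarily the library's implementation):

def n_ones_sketch(n: int) -> int:
    return (1 << n) - 1  # e.g. n = 3 -> 0b111 == 7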

leading_one_position

def leading_one_position(mantissa: torch.Tensor, mantissa_size: int = 7)

Determines the position of the leading one bit in each element of the input tensor, with the LSB at position 0. If there is no 1 in the mantissa, the function returns -1.

Arguments:

  • mantissa torch.Tensor - A tensor containing the mantissa values to be analyzed. Each element should be an integer.
  • mantissa_size int, optional - The number of mantissa bits to consider. Defaults to 7.

Returns:

  • torch.Tensor - the position of the leading one bit in each element of the input tensor.
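
A small sketch of the described behavior (assumed semantics, not the library code): scan the mantissa_size candidate bit positions and keep the highest set one, defaulting to -1 when no bit is set:

import torch

def leading_one_position_sketch(mantissa: torch.Tensor, mantissa_size: int = 7) -> torch.Tensor:
    pos = torch.full_like(mantissa, -1)                  # -1 where the mantissa is all zeros
    for bit in range(mantissa_size):
        is_set = ((mantissa >> bit) & 1) == 1
        pos = torch.where(is_set, torch.full_like(mantissa, bit), pos)
    return pos                                           # LSB is position 0

# leading_one_position_sketch(torch.tensor([0b1010011, 0])) -> tensor([6, -1])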

quantize_mx_with_e8m0_shared_exponent_hw_exact

def quantize_mx_with_e8m0_shared_exponent_hw_exact(
        data_hp: torch.Tensor,
        elem_dtype: dtypes.DType,
        shared_exponent: torch.Tensor,
        orig_shape: Optional[torch.Size] = None) -> torch.Tensor

A hardware-exact MX quantization function that handles the division and the conversion to the target element data type explicitly.

Arguments:

  • data_hp torch.Tensor - The high-precision input tensor (dtype=torch.bfloat16).
  • elem_dtype dtypes.DType - The target element data type for quantization.
  • shared_exponent torch.Tensor - The E8M0 scale shared exponent (dtype=torch.uint8).
  • orig_shape torch.Size - The original shape of the input tensor, used to reshape the output tensor. Optional, defaults to None.

Returns:

  • torch.Tensor dtype=torch.uint8 - The quantized tensor in the target lower precision format.

Raises:

  • AssertionError - If the provided elem_dtype is not supported.

get_fp_scale

def get_fp_scale(shared_exp_e8m0: torch.Tensor) -> torch.Tensor

Takes the shared exponent of the MX scale, FP8(0-8-0), as a biased uint8 exponent and converts it to an FP32 scale factor.

Arguments:

  • shared_exp_e8m0 torch.Tensor - the shared exponent of the FP8(0-8-0) scale

Returns:

  • torch.Tensor - The FP32 scale, 2**(shared_exponent - 127), with NaN handling
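
A minimal sketch of that conversion (omitting the NaN handling mentioned above):

import torch

shared_exp_e8m0 = torch.tensor([126, 127, 130], dtype=torch.uint8)
scale = torch.exp2(shared_exp_e8m0.to(torch.float32) - 127.0)  # tensor([0.5, 1.0, 8.0])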

quantize_mx_with_e8m0_shared_exponent_simulated

def quantize_mx_with_e8m0_shared_exponent_simulated(
        data_hp: torch.Tensor,
        elem_dtype: dtypes.DType,
        shared_exponent: torch.Tensor,
        orig_shape: Optional[torch.Size] = None) -> torch.Tensor

Simulated MX quantization function inspired by torchao. It accepts a high-precision input tensor (data_hp) and the MX scale shared exponent (shared_exponent), and returns the quantized tensor in the elem_dtype. The steps are:

1. Normalize data_hp by performing a single-precision division with the MX scale in torch.float32.
2. Quantize the normalized data_hp to the target elem_dtype using the native torchao function.

We call this implementation simulated because it is not an efficient hardware implementation.

Arguments:

  • data_hp torch.Tensor - The high precision input tensor, dtype is either torch.bfloat16 or torch.float
  • elem_dtype dtypes.DType - The target element data type for quantization.
  • shared_exponent torch.Tensor - The E8M0 scale shared exponent (dtype=torch.uint8).
  • orig_shape torch.Size - The original shape of the input tensor, used to reshape the output tensor. Optional, defaults to None.

Returns:

  • torch.Tensor dtype=torch.uint8 - The quantized tensor in the target lower precision format.

Raises:

  • AssertionError - If the provided elem_dtype is not supported.
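
A conceptual sketch of the two steps described above, assuming the block is the last dimension of data_hp and shared_exponent holds one E8M0 value per block; quantize_to_elem_dtype is a hypothetical stand-in for the torchao element quantizer, not a real function in either library:

import torch

def quantize_mx_simulated_sketch(data_hp: torch.Tensor, shared_exponent: torch.Tensor) -> torch.Tensor:
    scale = torch.exp2(shared_exponent.to(torch.float32) - 127.0)  # E8M0 -> FP32 scale
    normalized = data_hp.to(torch.float32) / scale.unsqueeze(-1)   # step 1: divide by the MX scale
    return quantize_to_elem_dtype(normalized)                      # step 2: hypothetical element quantizer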

get_e8m0_shared_exponent

def get_e8m0_shared_exponent(data_hp: torch.Tensor,
                             elem_dtype: dtypes.DType) -> torch.Tensor

Computes the shared exponent for a given high-precision tensor.

Arguments:

  • data_hp torch.Tensor - High-precision input tensor, with block size as the last dimension. dtype must be torch.bfloat16 or torch.float.
  • elem_dtype dtypes.DType - target element dtype

Returns:

  • torch.Tensor - MX-scale exponent tensor as torch.uint8
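
For context, the OCP MX convention derives the shared exponent from the per-block maximum magnitude. The sketch below follows that convention and is not necessarily this function's exact algorithm; ELEM_EMAX stands for the largest power-of-two exponent representable by the element dtype (for example 2 for an FP4 E2M1 element):

import torch

ELEM_EMAX = 2  # example value for an FP4 E2M1 element dtype (assumption)

def shared_exponent_sketch(data_hp: torch.Tensor) -> torch.Tensor:
    max_abs = data_hp.abs().amax(dim=-1)                       # per-block maximum magnitude
    shared_exp = torch.floor(torch.log2(max_abs)) - ELEM_EMAX  # unbiased shared exponent
    return (shared_exp + 127).clamp(0, 254).to(torch.uint8)    # store with the E8M0 bias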