torchmx.mx_quantization_utils
unpack_bfloat16
Extract the sign, exponent, and mantissa from a bfloat16 tensor.
Arguments:
- x torch.Tensor - The input bfloat16 tensor.
- dtype torch.dtype - The dtype to cast the output tensors to. Defaults to torch.uint8.
Returns:
- sign torch.Tensor - The sign of the tensor in uint8.
- exponent torch.Tensor - The exponent of the tensor in uint8.
- mantissa torch.Tensor - The mantissa of the tensor in uint8.
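For reference, the bfloat16 bit layout is 1 sign bit, 8 exponent bits (bias 127), and 7 mantissa bits. A minimal sketch of the extraction (illustrative only, not the torchmx implementation):

```python
import torch

def unpack_bfloat16_sketch(x: torch.Tensor):
    # reinterpret the bfloat16 bits as 16-bit integers
    bits = x.view(torch.int16)
    sign = (bits >> 15) & 0x1        # bit 15
    exponent = (bits >> 7) & 0xFF    # bits 14..7, biased by 127
    mantissa = bits & 0x7F           # bits 6..0
    return sign.to(torch.uint8), exponent.to(torch.uint8), mantissa.to(torch.uint8)

x = torch.tensor([1.0, -1.5], dtype=torch.bfloat16)
s, e, m = unpack_bfloat16_sketch(x)
# 1.0  -> sign 0, exponent 127, mantissa 0b0000000
# -1.5 -> sign 1, exponent 127, mantissa 0b1000000
```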
unpack_fp32
Unpacks the given FP32 tensor to its components.
Arguments:
- x torch.Tensor - The packed FP32 tensor.
Returns:
- sign torch.Tensor - The sign bit tensor.
- exponent torch.Tensor - The exponent tensor.
- mantissa torch.Tensor - The mantissa tensor.
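The same idea applies to the IEEE 754 single-precision layout (1 sign bit, 8 exponent bits with bias 127, 23 mantissa bits); a sketch, not the library's exact code:

```python
import torch

def unpack_fp32_sketch(x: torch.Tensor):
    # reinterpret the float32 bits as 32-bit integers
    bits = x.view(torch.int32)
    sign = ((bits >> 31) & 0x1).to(torch.uint8)
    exponent = ((bits >> 23) & 0xFF).to(torch.uint8)  # biased by 127
    mantissa = bits & 0x7FFFFF  # 23 fraction bits; kept as int32 (won't fit in uint8)
    return sign, exponent, mantissa

s, e, m = unpack_fp32_sketch(torch.tensor([1.5]))
# 1.5 -> sign 0, exponent 127, mantissa 0x400000 (the top fraction bit set)
```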
unpack_fp64
Unpacks the given FP64 tensor to its components.
Arguments:
- x torch.Tensor - The packed FP64 tensor.
Returns:
- sign torch.Tensor - The sign bit tensor.
- exponent torch.Tensor - The exponent tensor.
- mantissa torch.Tensor - The mantissa tensor.
dequantize_to_dtype
Dequantizes an elem_dtype tensor packed as uint8 to the target_dtype, using an intermediate bfloat16 representation.
Arguments:
- data_lp torch.Tensor - The input tensor in low-precision format (must be of dtype torch.uint8).
- elem_dtype dtypes.DType - The input element data type.
- target_dtype torch.dtype - The target data type to which the tensor will be dequantized.
- is_packed_fp4 bool, optional - A flag indicating whether the input tensor is packed in FP4 format. Defaults to True.
- packing_dim int - The dimension along which the uint4 data is packed. Defaults to -1.
Returns:
torch.Tensor - The dequantized tensor in the specified target data type.
Raises:
AssertionError - If the element data type is not supported or if the input tensor is not of dtype torch.uint8.
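When is_packed_fp4 is set, two uint4 values share each byte. A hedged sketch of unpacking along the last dimension, assuming the low nibble holds the first element (the actual nibble order in torchmx may differ):

```python
import torch

def unpack_uint4_sketch(packed: torch.Tensor) -> torch.Tensor:
    # split each uint8 byte into two uint4 values
    low = packed & 0x0F           # assumed to hold the first element
    high = (packed >> 4) & 0x0F   # assumed to hold the second element
    # interleave the two nibbles so the last dimension doubles in size
    return torch.stack([low, high], dim=-1).flatten(-2)

packed = torch.tensor([0x21, 0x43], dtype=torch.uint8)
unpacked = unpack_uint4_sketch(packed)  # [1, 2, 3, 4] under this nibble order
```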
round_to_even
Round a mantissa to the nearest even value using a tensor of shift values.
Arguments:
- mantissa torch.Tensor - A tensor containing the mantissa values to be rounded.
- mantissa_shift torch.Tensor | int - A tensor containing the shift values to be applied to each element of the mantissa tensor. The size of the mantissa_shift tensor should match the size of the mantissa tensor. Alternatively, a single integer value can be provided, in which case the same shift value is applied to all elements of the mantissa tensor.
Returns:
torch.Tensor - A tensor containing the mantissa values rounded to the nearest even value. The size of the output tensor matches the input mantissa tensor.
Notes
- The rounding follows the "round half to even" rule, where if the value to be discarded is exactly halfway between two integers, the result is rounded to the nearest even number.
- This function supports element-wise operations, where the shifting is applied
to each element of the mantissa according to the corresponding value in
mantissa_shift.
Examples
mantissa = torch.tensor([0b1010011, 0b1101101], dtype=torch.int32)
mantissa_shift = torch.tensor([2, 3], dtype=torch.int32)
round_to_even(mantissa, mantissa_shift)
tensor([41, 27])
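A generic round-half-to-even sketch in plain PyTorch (the exact bit conventions in torchmx's round_to_even may differ, so this is illustrative only; it assumes all shift values are at least 1):

```python
import torch

def round_half_to_even_sketch(mantissa: torch.Tensor,
                              shift: torch.Tensor) -> torch.Tensor:
    one = torch.ones_like(shift)
    discarded_mask = torch.bitwise_left_shift(one, shift) - 1  # low `shift` bits
    half = torch.bitwise_left_shift(one, shift - 1)            # halfway point
    kept = mantissa >> shift
    discarded = mantissa & discarded_mask
    # round up when above the halfway point, or exactly halfway and the
    # kept part is odd ("ties to even")
    round_up = (discarded > half) | ((discarded == half) & ((kept & 1) == 1))
    return kept + round_up.to(kept.dtype)

m = torch.tensor([5, 7, 6], dtype=torch.int32)
s = torch.tensor([1, 1, 2], dtype=torch.int32)
out = round_half_to_even_sketch(m, s)
# 5 >> 1 = 2.5 -> 2 (tie, kept even), 7 >> 1 = 3.5 -> 4, 6 >> 2 = 1.5 -> 2
```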
n_ones
Returns a number with n ones in its binary representation. For example: n_ones(3) = 0b111 = 7.
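This is the standard bit trick; a one-line sketch:

```python
def n_ones(n: int) -> int:
    # a number whose low n bits are all 1
    return (1 << n) - 1

print(n_ones(3))  # 0b111 == 7
```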
leading_one_position
Determine the position of the leading one bit in each element of the input tensor, with the LSB at position 0. If there is no 1 in the mantissa, the function returns -1.
Arguments:
- mantissa torch.Tensor - A tensor containing the mantissa values to be analyzed. Each element should be an integer.
Returns:
torch.Tensor - The position of the leading one bit in each element of the input tensor.
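One way to compute this with pure tensor ops is repeated right-shifting (a sketch; torchmx may use a different method):

```python
import torch

def leading_one_position_sketch(mantissa: torch.Tensor) -> torch.Tensor:
    # position of the highest set bit, LSB = position 0; -1 when no bit is set
    pos = torch.full_like(mantissa, -1)
    m = mantissa.clone()
    while torch.any(m > 0):
        # every element that still has bits left gains one more position
        pos = torch.where(m > 0, pos + 1, pos)
        m = m >> 1
    return pos

x = torch.tensor([0, 1, 2, 0b1101], dtype=torch.int32)
pos_out = leading_one_position_sketch(x)  # [-1, 0, 1, 3]
```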
quantize_mx_with_e8m0_shared_exponent_hw_exact
A hardware-exact MX quantization function that handles the division and the conversion to the target element data type explicitly.
Arguments:
- data_hp torch.Tensor - The high-precision input tensor (dtype=torch.bfloat16).
- elem_dtype dtypes.DType - The target element data type for quantization.
- shared_exponent torch.Tensor - The E8M0 scale shared exponent (dtype=torch.uint8).
- orig_shape torch.Size - The original shape of the input tensor, used to reshape the output tensor. Optional, defaults to None.
Returns:
torch.Tensor (dtype=torch.uint8) - The quantized tensor in the target lower-precision format.
Raises:
AssertionError - If the provided elem_dtype is not supported.
get_fp_scale
Takes the shared exponent of the MX scale, FP8 (0-8-0), as a biased uint8 exponent and returns the corresponding FP32 scale.
Arguments:
- shared_exp_e8m0 torch.Tensor - The shared exponent of the FP8 (0-8-0) scale.
Returns:
torch.Tensor - The FP32 scale, 2**(shared_exponent - 127), with NaN handling.
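A sketch of this mapping, assuming (as in the OCP MX convention) that the biased E8M0 value 0xFF encodes NaN:

```python
import torch

E8M0_BIAS = 127  # matches the 2**(shared_exponent - 127) formula above

def get_fp_scale_sketch(shared_exp_e8m0: torch.Tensor) -> torch.Tensor:
    scale = torch.pow(2.0, shared_exp_e8m0.to(torch.float32) - E8M0_BIAS)
    # assumption: 0xFF is the E8M0 NaN encoding
    nan = torch.full_like(scale, float("nan"))
    return torch.where(shared_exp_e8m0 == 0xFF, nan, scale)

scales = get_fp_scale_sketch(torch.tensor([127, 128, 255], dtype=torch.uint8))
# -> [1.0, 2.0, nan]
```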
quantize_mx_with_e8m0_shared_exponent_simulated
Simulated MX quantization function inspired by torchao. It accepts a high-precision input tensor (data_hp) and the MX scale shared exponent (shared_exponent), and returns the quantized tensor in the elem_dtype. The steps are:
1. Normalize data_hp by performing a single-precision division with the MX scale in torch.float32.
2. Quantize the normalized data_hp to the target elem_dtype using the native torchao function.
We call this implementation simulated because it is not an efficient hardware implementation.
Arguments:
- data_hp torch.Tensor - The high-precision input tensor; dtype is either torch.bfloat16 or torch.float.
- elem_dtype dtypes.DType - The target element data type for quantization.
- shared_exponent torch.Tensor - The E8M0 scale shared exponent (dtype=torch.uint8).
- orig_shape torch.Size - The original shape of the input tensor, used to reshape the output tensor. Optional, defaults to None.
Returns:
torch.Tensor (dtype=torch.uint8) - The quantized tensor in the target lower-precision format.
Raises:
AssertionError - If the provided elem_dtype is not supported.
get_e8m0_shared_exponent
Computes the shared exponent for a given high-precision tensor.
Arguments:
- data_hp torch.Tensor - High-precision input tensor, with block size as the last dimension. dtype must be torch.bfloat16 or torch.float.
- elem_dtype dtypes.DType - The target element dtype.
Returns:
torch.Tensor - The MX-scale exponent tensor as torch.uint8.
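A hedged sketch of one common way to compute this: take the per-block maximum magnitude, extract its exponent, and re-bias so the largest element fits the target dtype's range. The exact rounding and clamping rules in torchmx may differ, and elem_max_exp is a hypothetical parameter standing in for the largest power-of-two exponent representable in elem_dtype:

```python
import torch

E8M0_BIAS = 127

def get_e8m0_shared_exponent_sketch(data_hp: torch.Tensor,
                                    elem_max_exp: int) -> torch.Tensor:
    # per-block absolute maximum; the block is the last dimension
    amax = data_hp.abs().amax(dim=-1).to(torch.float32)
    # unbiased exponent of the block maximum (clamp avoids log2(0))
    exp = torch.floor(torch.log2(torch.clamp(amax, min=2.0 ** -126)))
    # shift so the largest element lands at the top of the target range,
    # then re-bias to E8M0 and clamp to the representable uint8 range
    shared = exp - elem_max_exp + E8M0_BIAS
    return torch.clamp(shared, 0, 255).to(torch.uint8)

blocks = torch.tensor([[1.0, 4.0], [0.5, 0.25]], dtype=torch.bfloat16)
shared = get_e8m0_shared_exponent_sketch(blocks, elem_max_exp=2)
# first block: amax 4 -> exponent 2 -> shared 2 - 2 + 127 = 127
```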