torchmx.mx_quantization_utils
unpack_bfloat16
Extracts the sign, exponent, and mantissa from a bfloat16 tensor.
Arguments:
- x: torch.Tensor - The input bfloat16 tensor.
- dtype: torch.dtype - The dtype to cast the output tensors to. Defaults to torch.uint8.
Returns:
- sign: torch.Tensor - The sign of the tensor in uint8.
- exponent: torch.Tensor - The exponent of the tensor in uint8.
- mantissa: torch.Tensor - The mantissa of the tensor in uint8.
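A bfloat16 value carries 1 sign bit, 8 exponent bits, and 7 mantissa bits. A minimal sketch of this kind of bit-field extraction (not necessarily the library's exact implementation; the helper name is hypothetical):

```python
import torch

def unpack_bfloat16_sketch(x: torch.Tensor):
    """Split bfloat16 bits into (sign, exponent, mantissa): 1 | 8 | 7 layout."""
    # Reinterpret the raw 16 bits, then zero-extend into int32 so that
    # right-shifts are not affected by sign extension.
    bits = x.view(torch.int16).to(torch.int32) & 0xFFFF
    sign = ((bits >> 15) & 0x1).to(torch.uint8)
    exponent = ((bits >> 7) & 0xFF).to(torch.uint8)
    mantissa = (bits & 0x7F).to(torch.uint8)
    return sign, exponent, mantissa
```

For example, -1.0 unpacks to sign 1, biased exponent 127, and mantissa 0.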
unpack_fp32
Unpacks the given FP32 tensor to its components.
Arguments:
- x: torch.Tensor - The packed FP32 tensor.
Returns:
- sign: torch.Tensor - The sign bit tensor.
- exponent: torch.Tensor - The exponent tensor.
- mantissa: torch.Tensor - The mantissa tensor.
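For a normal FP32 number, the fields recombine in the usual IEEE-754 way; a small self-contained check of that relationship (illustrative only):

```python
# value = (-1)**s * 2**(e - 127) * (1 + m / 2**23) for normal FP32 numbers.
# -1.5 encodes as sign=1, biased exponent=127, mantissa=1 << 22.
s, e, m = 1, 127, 1 << 22
value = (-1.0) ** s * 2.0 ** (e - 127) * (1.0 + m / 2.0**23)
assert value == -1.5
```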
unpack_fp64
Unpacks the given FP64 tensor to its components.
Arguments:
- x: torch.Tensor - The packed FP64 tensor.
Returns:
- sign: torch.Tensor - The sign bit tensor.
- exponent: torch.Tensor - The exponent tensor.
- mantissa: torch.Tensor - The mantissa tensor.
dequantize_to_dtype
Dequantizes a tensor of elem_dtype values packed as uint8 to the target_dtype, using an intermediate bfloat16 representation.
Arguments:
- data_lp: torch.Tensor - The input tensor in low-precision format (must be of dtype torch.uint8).
- elem_dtype: dtypes.DType - The input element data type.
- target_dtype: torch.dtype - The target data type to which the tensor will be dequantized.
- is_packed_fp4: bool, optional - A flag indicating whether the input tensor is packed in FP4 format. Defaults to True.
- packing_dim: int - The dimension along which the uint4 data is packed. Defaults to -1.
Returns:
- torch.Tensor - The dequantized tensor in the specified target data type.
Raises:
- AssertionError - If the element data type is not supported or if the input tensor is not of dtype torch.uint8.
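When is_packed_fp4 is set, two 4-bit codes share one uint8 along packing_dim. A sketch of that unpacking for the default packing_dim=-1 (the nibble order is an assumption; torchmx may use the opposite convention):

```python
import torch

def unpack_uint4_last_dim_sketch(packed: torch.Tensor) -> torch.Tensor:
    """Split each uint8 into two uint4 codes along the last dim."""
    low = packed & 0x0F          # low nibble, assumed to come first
    high = (packed >> 4) & 0x0F  # high nibble
    # (..., N) uint8 -> (..., N, 2) -> (..., 2 * N) uint4 codes stored in uint8
    return torch.stack([low, high], dim=-1).flatten(-2)
```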
round_to_even
Round a mantissa to the nearest even value using a tensor of shift values.
Arguments:
- mantissa: torch.Tensor - A tensor containing the mantissa values to be rounded.
- mantissa_shift: torch.Tensor | int - A tensor containing the shift values to be applied to each element of the mantissa tensor. The size of the mantissa_shift tensor should match the size of the mantissa tensor. Alternatively, a single integer can be provided, in which case the same shift value is applied to all elements of the mantissa tensor.
Returns:
- torch.Tensor - A tensor containing the mantissa values rounded to the nearest even value. The size of the output tensor matches the input mantissa tensor.
Notes:
- The rounding follows the "round half to even" rule: if the value to be discarded is exactly halfway between two integers, the result is rounded to the nearest even number.
- This function operates element-wise: the shifting is applied to each element of the mantissa according to the corresponding value in mantissa_shift.
Examples:
>>> mantissa = torch.tensor([0b1010011, 0b1101101], dtype=torch.int32)
>>> mantissa_shift = torch.tensor([2, 3], dtype=torch.int32)
>>> round_to_even(mantissa, mantissa_shift)
tensor([41, 27])
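One standard implementation of a round-half-to-even right shift is sketched below; it assumes every shift value is at least 1, and torchmx's exact shift and tie conventions may differ:

```python
import torch

def round_half_to_even_shift(mantissa: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    """Right-shift `mantissa` by `shift` bits, rounding ties to even (shift >= 1)."""
    one = torch.ones_like(mantissa)
    trunc = torch.bitwise_right_shift(mantissa, shift)              # floor division by 2**shift
    dropped = mantissa & (torch.bitwise_left_shift(one, shift) - 1) # discarded low bits
    half = torch.bitwise_left_shift(one, shift - 1)                 # 0.5 ULP at the new precision
    # Round up when the dropped bits exceed half, or on an exact tie when trunc is odd.
    round_up = (dropped > half) | ((dropped == half) & (trunc % 2 == 1))
    return trunc + round_up.to(trunc.dtype)
```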
n_ones
Returns a number with n ones in its binary representation. For example: n_ones(3) = 0b111 = 7.
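A one-line sketch of this helper:

```python
def n_ones_sketch(n: int) -> int:
    return (1 << n) - 1  # (1 << 3) - 1 == 0b111 == 7
```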
leading_one_position
Determines the position of the leading one bit in each element of the input tensor, with the LSB at position 0. If there is no 1 in the mantissa, the function returns -1.
Arguments:
- mantissa: torch.Tensor - A tensor containing the mantissa values to be analyzed. Each element should be an integer.
Returns:
- torch.Tensor - The position of the leading one bit in each element of the input tensor.
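A simple (if not the fastest) element-wise implementation tests each bit position in turn; a sketch, assuming 32-bit integer mantissas:

```python
import torch

def leading_one_position_sketch(mantissa: torch.Tensor, num_bits: int = 32) -> torch.Tensor:
    """Index of the highest set bit (LSB = position 0), or -1 for zero elements."""
    pos = torch.full_like(mantissa, -1)
    for b in range(num_bits):
        bit_set = ((mantissa >> b) & 1).bool()
        pos = torch.where(bit_set, torch.full_like(pos, b), pos)
    return pos
```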
quantize_mx_with_e8m0_shared_exponent_hw_exact
A hardware-exact MX quantization function that handles the division and conversion to the target element data type explicitly.
Arguments:
- data_hp: torch.Tensor - The high-precision input tensor (dtype=torch.bfloat16).
- elem_dtype: dtypes.DType - The target element data type for quantization.
- shared_exponent: torch.Tensor - The E8M0 scale shared exponent (dtype=torch.uint8).
- orig_shape: torch.Size - The original shape of the input tensor, used to reshape the output tensor. Optional, defaults to None.
Returns:
- torch.Tensor (dtype=torch.uint8) - The quantized tensor in the target lower-precision format.
Raises:
- AssertionError - If the provided elem_dtype is not supported.
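Conceptually, dividing by the power-of-two MX scale can be done with integer exponent arithmetic rather than a floating-point divide. A rough, hypothetical sketch of that idea (the real kernel's rounding, clamping, and packing steps are omitted):

```python
import torch

data_hp = torch.randn(2, 32, dtype=torch.bfloat16)        # two blocks of 32
shared_exp = torch.tensor([130, 125], dtype=torch.uint8)  # per-block E8M0 exponents

# Dividing by 2**(shared_exp - 127) amounts to subtracting the unbiased shared
# exponent from each element's biased exponent field.
bits = data_hp.view(torch.int16).to(torch.int32) & 0xFFFF
elem_exp = (bits >> 7) & 0xFF
new_exp = elem_exp - (shared_exp.to(torch.int32) - 127).unsqueeze(-1)
# A real kernel would now clamp `new_exp` to the target format's range, round the
# mantissa with round_to_even, and pack sign/exponent/mantissa into uint8 codes.
```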
get_fp_scale
Takes the shared exponent of the MX scale, FP8 (0-8-0, i.e. E8M0), as a biased uint8 exponent and returns the corresponding FP32 scale.
Arguments:
- shared_exp_e8m0: torch.Tensor - The shared exponent of the FP8(0-8-0) scale.
Returns:
- torch.Tensor - The FP32 scale, 2**(shared_exponent - 127), with NaN handling.
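A direct sketch of the decode (in E8M0 the all-ones exponent 0xFF encodes NaN, which is what the NaN handling guards):

```python
import torch

def get_fp_scale_sketch(shared_exp_e8m0: torch.Tensor) -> torch.Tensor:
    """Decode a biased E8M0 exponent into an FP32 scale, mapping 0xFF to NaN."""
    scale = torch.pow(2.0, shared_exp_e8m0.to(torch.float32) - 127.0)
    nan = torch.full_like(scale, float("nan"))
    return torch.where(shared_exp_e8m0 == 255, nan, scale)
```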
quantize_mx_with_e8m0_shared_exponent_simulated
Simulated MX quantization function inspired by torchao. It accepts a high-precision input tensor (data_hp) and the MX scale shared exponent (shared_exponent), and returns the quantized tensor in the elem_dtype. The steps are:
1. Normalize data_hp by performing a single-precision division with the MX scale in torch.float32.
2. Quantize the normalized data_hp to the target elem_dtype using the native torchao function.
We call this implementation simulated because it is not an efficient hardware implementation.
Arguments:
- data_hp: torch.Tensor - The high-precision input tensor; dtype is either torch.bfloat16 or torch.float.
- elem_dtype: dtypes.DType - The target element data type for quantization.
- shared_exponent: torch.Tensor - The E8M0 scale shared exponent (dtype=torch.uint8).
- orig_shape: torch.Size - The original shape of the input tensor, used to reshape the output tensor. Optional, defaults to None.
Returns:
- torch.Tensor (dtype=torch.uint8) - The quantized tensor in the target lower-precision format.
Raises:
- AssertionError - If the provided elem_dtype is not supported.
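The two steps can be sketched as follows; the per-block shapes and the final cast are assumptions for illustration, not the library's exact code:

```python
import torch

# Step 1: normalize in FP32 by the decoded MX scale (one scale per block).
data_hp = torch.randn(4, 32, dtype=torch.bfloat16)          # block size 32
shared_exponent = torch.full((4,), 127, dtype=torch.uint8)  # scale == 1.0
scale = torch.pow(2.0, shared_exponent.to(torch.float32) - 127.0)
normalized = data_hp.to(torch.float32) / scale.unsqueeze(-1)
# Step 2: quantize `normalized` to elem_dtype (e.g. via torchao's
# float-to-low-precision casting helpers), yielding uint8 codes.
```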
get_e8m0_shared_exponent
Computes the shared exponent for a given high-precision tensor.
Arguments:
- data_hp: torch.Tensor - High-precision input tensor, with the block size as the last dimension; dtype must be torch.bfloat16 or torch.float.
- elem_dtype: dtypes.DType - The target element dtype.
Returns:
- torch.Tensor - The MX-scale exponent tensor as torch.uint8.
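A common recipe in torchao-style implementations takes the exponent of each block's largest magnitude and re-biases it so the largest element fits the target format. A hedged sketch, where elem_max_pow2 (the largest power of two representable in the element format) stands in for what torchmx derives from elem_dtype:

```python
import torch

def get_e8m0_shared_exponent_sketch(data_hp: torch.Tensor, elem_max_pow2: int) -> torch.Tensor:
    """Per-block shared E8M0 exponent (block size = last dimension)."""
    amax = data_hp.abs().amax(dim=-1).to(torch.float32)
    exp = torch.floor(torch.log2(amax))  # exponent of the largest element
    # Re-bias so the largest element maps onto the element format's top bin;
    # log2(0) = -inf is clamped to 0 here.
    shared = torch.clamp(exp - elem_max_pow2 + 127, min=0, max=255)
    return shared.to(torch.uint8)
```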