torchmx.mx_tensor
Defines the tensor subclasses that represent the OCP MX-Format spec.
Exponent E8M0 encoding details (OCP spec section 5.4.1):
- bias: 127
- supported exponent range: -127 to 127
- infinities: N/A
- NaN: 11111111
- Zeros: N/A
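The encoding above can be sketched in a few lines of plain Python. This is an illustrative sketch only, not torchmx code: the helper names `e8m0_decode` and `e8m0_encode` are hypothetical, and the stored byte is assumed to be an unsigned integer in the range 0 to 255.

```python
import math

E8M0_BIAS = 127
E8M0_NAN = 0b11111111  # the all-ones pattern (255) is reserved for NaN


def e8m0_decode(byte: int) -> float:
    """Decode a stored E8M0 byte into the power-of-two scale it represents.

    Hypothetical helper for illustration; not part of torchmx.
    """
    if not 0 <= byte <= 255:
        raise ValueError("E8M0 values occupy exactly one unsigned byte")
    if byte == E8M0_NAN:
        return math.nan
    # There are no sign or mantissa bits, so the value is 2 ** (byte - bias).
    return math.ldexp(1.0, byte - E8M0_BIAS)


def e8m0_encode(exponent: int) -> int:
    """Encode an unbiased exponent in [-127, 127] as an E8M0 byte."""
    if not -127 <= exponent <= 127:
        raise ValueError("exponent outside the supported E8M0 range")
    return exponent + E8M0_BIAS
```

Because every non-NaN bit pattern decodes to a strictly positive power of two, the format has no way to express zero or infinity, which matches the N/A entries in the list above.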
MXTensor Objects
to_dtype
Dequantize the MXTensor to the target_dtype.
Arguments:
target_dtype (torch.dtype) - The target data type (torch.bfloat16, torch.float32, etc.) to which the MXTensor is dequantized.
Returns:
The dequantized tensor in the target_dtype.
Notes:
Quantization to MXTensor is supported only for torch.bfloat16 inputs, but dequantization may target either torch.bfloat16 or torch.float32.
See the quantize_mx and de_quantize_mx functions for more details.
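Conceptually, dequantizing an MX tensor multiplies each element by the shared power-of-two scale of its block. The sketch below shows that arithmetic in plain Python. It is not the de_quantize_mx implementation: the function name `dequantize_blocks` is hypothetical, the elements are assumed to be already decoded to Python floats, and the final cast to the target dtype is omitted.

```python
import math


def dequantize_blocks(elements, scales_e8m0, block_size=32):
    """Apply one shared E8M0 scale to each block of decoded element values.

    Hypothetical helper for illustration; not part of torchmx.
    elements:    flat list of decoded element values (floats)
    scales_e8m0: one biased E8M0 exponent (0..255) per block
    """
    if len(elements) != len(scales_e8m0) * block_size:
        raise ValueError("expected exactly one scale per block of elements")
    out = []
    for i, value in enumerate(elements):
        biased = scales_e8m0[i // block_size]
        if biased == 255:
            # A NaN scale makes every element in its block NaN.
            out.append(math.nan)
        else:
            out.append(value * math.ldexp(1.0, biased - 127))
    return out
```

For instance, with `block_size=2`, the elements `[1.0, -0.5, 2.0, 0.0]` and scales `[128, 126]` (that is, 2 and 0.5) dequantize to `[2.0, -1.0, 1.0, 0.0]`.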
to_mx
Convert/Quantize a high-precision tensor to MXTensor.
Arguments:
data_hp (torch.Tensor) - The high-precision input tensor. Only torch.bfloat16 is supported. See the quantize_mx function for more details.
elem_dtype (dtypes.DType) - The target element data type for quantization.
block_size (int) - The block size. Default is 32.
Returns:
MXTensor - The quantized tensor in the target lower-precision format.
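To give a feel for what block-wise quantization involves, the sketch below picks the shared scale for one block using the rule described in the OCP MX spec: the floor of log2 of the block's largest magnitude, minus the largest exponent of the element format. It is only a sketch and may differ from what quantize_mx does. The name `shared_exponent_e8m0` is hypothetical, inputs are assumed finite, and mapping an all-zero block to the smallest scale is an assumption made for the example.

```python
import math


def shared_exponent_e8m0(block, elem_emax):
    """Return the biased E8M0 shared exponent for one block of finite floats.

    Hypothetical helper for illustration; not part of torchmx.
    elem_emax is the largest exponent of the element format, for example
    8 for FP8 E4M3, 15 for FP8 E5M2, and 2 for FP4 E2M1.
    """
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 0  # assumption: an all-zero block takes the smallest scale
    # frexp returns (m, e) with amax == m * 2**e and 0.5 <= m < 1,
    # so floor(log2(amax)) is exactly e - 1.
    floor_log2 = math.frexp(amax)[1] - 1
    unbiased = floor_log2 - elem_emax
    # Clamp to the range E8M0 can represent, then add the bias.
    unbiased = max(-127, min(127, unbiased))
    return unbiased + 127
```

With FP8 E4M3 elements (`elem_emax=8`), a block whose largest magnitude is 3.0 gets the unbiased exponent 1 - 8 = -7, stored as 120. Each element would then be divided by 2 ** -7 and rounded to the element format, a step this sketch leaves out.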