torchmx.mx_tensor
Defines the tensor subclasses that represent the OCP MX format specification.
Exponent E8M0 encoding details (OCP spec section 5.4.1):
- Bias: 127
- Supported exponent range: -127 to 127
- Infinities: N/A (not representable)
- NaN: 11111111 (0xFF)
- Zeros: N/A (not representable)
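The encoding rules above can be illustrated with a minimal pure-Python sketch that uses only the bias and the NaN bit pattern listed above. The helper names `e8m0_encode` and `e8m0_decode` are hypothetical and are not part of torchmx. Because E8M0 has no zero or infinity encoding, every byte other than 0xFF decodes to a power of two between 2**-127 and 2**127.

```python
import math

E8M0_BIAS = 127
E8M0_NAN = 0xFF  # 0b11111111, the only special encoding in E8M0


def e8m0_encode(exponent: int) -> int:
    """Hypothetical helper: encode an unbiased exponent as an E8M0 byte."""
    if not -127 <= exponent <= 127:
        raise ValueError(f"exponent {exponent} is outside the E8M0 range [-127, 127]")
    return exponent + E8M0_BIAS


def e8m0_decode(bits: int) -> float:
    """Hypothetical helper: decode an E8M0 byte to its power-of-two scale."""
    if not 0 <= bits <= 0xFF:
        raise ValueError("E8M0 values are 8-bit")
    if bits == E8M0_NAN:
        return math.nan
    # No sign bit and no mantissa: the value is always 2**(bits - bias) > 0.
    return 2.0 ** (bits - E8M0_BIAS)
```

For example, `e8m0_encode(0)` gives 127, and `e8m0_decode(0)` gives 2**-127, the smallest representable scale.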
MXTensor Objects
to_dtype
Dequantize the MXTensor to the target_dtype.
Arguments:
target_dtype
torch.dtype - The target data type (torch.bfloat16, torch.float32, etc.) to which the MXTensor is dequantized.
Returns:
The dequantized tensor in the target_dtype.
Notes:
Quantization to MXTensor is supported only for torch.bfloat16 inputs, but dequantization can target either torch.bfloat16 or torch.float32. See the quantize_mx and de_quantize_mx functions for more details.
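Conceptually, dequantization multiplies every element in a block by that block's shared power-of-two scale. The pure-Python sketch below illustrates only that arithmetic. `dequantize_blocks` is a hypothetical helper, not torchmx's de_quantize_mx. It assumes the elements have already been decoded from their element format to floats and that their count is an exact multiple of `block_size`. It also skips the final cast to torch.bfloat16 or torch.float32. A 0xFF scale is treated as NaN for the whole block.

```python
import math

E8M0_BIAS = 127


def dequantize_blocks(scale_bits, elements, block_size=32):
    """Hypothetical sketch: rebuild high-precision values from MX blocks.

    scale_bits: one E8M0 byte per block.
    elements:   element values already decoded to Python floats.
    """
    if len(elements) != len(scale_bits) * block_size:
        raise ValueError("expected exactly one scale per block of elements")
    out = []
    for b, bits in enumerate(scale_bits):
        # 0xFF encodes NaN; NaN * x is NaN, so the whole block becomes NaN.
        scale = math.nan if bits == 0xFF else 2.0 ** (bits - E8M0_BIAS)
        block = elements[b * block_size:(b + 1) * block_size]
        out.extend(scale * e for e in block)
    return out
```

With `block_size=4`, scales `[127, 128]` (that is, 1.0 and 2.0) turn elements `[1.0] * 4 + [0.5] * 4` into eight values of 1.0.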
to_mx
Convert/Quantize a high-precision tensor to MXTensor.
Arguments:
data_hp
torch.Tensor - The high-precision input tensor. Only torch.bfloat16 is supported. See the quantize_mx function for more details.
elem_dtype
dtypes.DType - The target element data type for quantization.
block_size
int - The number of consecutive elements that share one E8M0 scale. Default is 32.
Returns:
MXTensor - The quantized tensor in the target lower-precision format.
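The per-block quantization can be sketched as follows. It loosely follows the conversion recipe in the OCP MX spec: the shared exponent is floor(log2(amax)) minus the largest normal exponent of the element format, clamped to the E8M0 range. To stay self-contained, the sketch makes several assumptions that torchmx does not necessarily share:

- an MXINT8-style element (an int8 code with an implicit 2**-6 scale and emax 0);
- a symmetric clamp to [-127, 127];
- round-half-to-even;
- finite inputs.

`quantize_blocks` is hypothetical. It is not torchmx's quantize_mx, which supports other elem_dtype values.

```python
import math

E8M0_BIAS = 127
INT8_EMAX = 0       # assumed: exponent of the largest normal element value
INT8_FRAC_BITS = 6  # assumed: element value = int8 code * 2**-6
INT8_MAX = 127      # assumed: symmetric saturation to [-127, 127]


def quantize_blocks(values, block_size=32):
    """Hypothetical sketch of MX quantization with MXINT8-style elements.

    Returns (scale_bits, codes): one E8M0 byte per block and one int code
    per input value. Inputs are assumed to be finite Python floats.
    """
    if len(values) % block_size != 0:
        raise ValueError("len(values) must be a multiple of block_size")
    scale_bits, codes = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        amax = max(abs(v) for v in block)
        if amax == 0.0:
            shared_exp = -127  # all-zero block: any scale works, pick the smallest
        else:
            # frexp returns amax = m * 2**e with m in [0.5, 1),
            # so e - 1 is exactly floor(log2(amax)).
            _, e = math.frexp(amax)
            shared_exp = (e - 1) - INT8_EMAX
        shared_exp = max(-127, min(127, shared_exp))  # clamp to the E8M0 range
        scale_bits.append(shared_exp + E8M0_BIAS)
        scale = 2.0 ** shared_exp
        for v in block:
            # round() rounds half to even; then saturate to the element range.
            q = round(v / scale * 2 ** INT8_FRAC_BITS)
            codes.append(max(-INT8_MAX, min(INT8_MAX, q)))
    return scale_bits, codes
```

Using `math.frexp` instead of `math.log2` keeps floor(log2(amax)) exact for every finite float. A value near the top of the range, such as 1.999, rounds to code 128 and saturates to 127.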