torchmx.mx_tensor
Defines the tensor subclasses to represent the OCP MX-Format spec
Exponent E8M0 encoding details (OCP spec section 5.4.1):
- bias: 127
- supported exponent range: -127 to 127
- infinities: N/A
- NaN: 11111111
- Zeros: N/A
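The encoding rules above can be sketched in plain Python. These helpers are illustrative only and not part of torchmx; the names are hypothetical.

```python
# E8M0 sketch per the OCP spec notes above: bias 127, range [-127, 127],
# no infinities or zeros, and the all-ones code 0b11111111 reserved for NaN.
E8M0_BIAS = 127
E8M0_NAN = 0b11111111

def encode_e8m0(exponent: int) -> int:
    """Map an unbiased exponent in [-127, 127] to its biased uint8 code."""
    if not -127 <= exponent <= 127:
        raise ValueError("E8M0 exponent out of range")
    return exponent + E8M0_BIAS

def decode_e8m0(code: int) -> float:
    """Return the scale 2**(code - bias), or NaN for the reserved code."""
    if code == E8M0_NAN:
        return float("nan")
    return 2.0 ** (code - E8M0_BIAS)
```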
quantize_mx
Takes a high precision tensor and converts it to MX scale and raw data, in naive layout (scale and raw data are separate tensors). For now, the function only supports quantization along the last dimension of the input tensor. For example, if the input tensor has shape (N, C, H, W), the output will be:

- data_lp (torch.uint8) with shape (N, C, H, W)
- scale (torch.uint8) with shape (N, C, H, W // block_size)
Arguments:
- data_hp (torch.Tensor) - high precision data tensor (dtype=torch.bfloat16)
- elem_dtype_name (str) - target element dtype as a string to comply with torch.library.infer_schema
- block_size (int) - block size
Returns:
Tuple[torch.Tensor, torch.Tensor]: scale(biased), low precision data as tensors
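The shape bookkeeping described above can be checked with a small pure-Python sketch (torchmx itself returns torch tensors; this hypothetical helper only mirrors the shapes):

```python
# Shapes produced by last-dimension MX quantization: the raw data keeps the
# input shape, while the scale stores one shared exponent per block.
def mx_output_shapes(input_shape, block_size):
    """Return (data_lp_shape, scale_shape) for quantization along the last dim."""
    *lead, last = input_shape
    assert last % block_size == 0, "last dim must be a multiple of block_size"
    data_lp_shape = tuple(input_shape)           # raw data keeps the input shape
    scale_shape = (*lead, last // block_size)    # one shared scale per block
    return data_lp_shape, scale_shape
```

For the (N, C, H, W) example with block_size=32, a last dimension of 64 yields two scale entries per row.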
_
Fake quantize_mx implementation.
This adds a “FakeTensor kernel” (aka “meta kernel”) to the operator. Given some FakeTensor inputs (dummy tensors that don't have storage), this function returns dummy tensors with the correct tensor metadata (shape/strides/dtype/device). This is used by torch.compile to infer the shape and other metadata of the output tensors.
Reference: https://pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial
dequantize_mx
Takes the low precision data and scale of an MXTensor and converts them to high precision
Arguments:
- data_lp (torch.Tensor) - low precision data tensor
- shared_exp_e8m0 (torch.Tensor) - biased exponent of the shared MX scale as torch.uint8
- elem_dtype_name (str) - target element dtype's name as a string to comply with torch.library.infer_schema
- block_size (int) - block size
- target_dtype (torch.dtype) - target dtype
- block_dim (int) - block dimension
Returns:
torch.Tensor - high precision data tensor in target_dtype converted from MX
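The arithmetic behind dequantization can be sketched for a single block in plain Python. This is a conceptual sketch, not the torchmx kernel, which operates on whole tensors and decodes the low-precision element format as well:

```python
# Each element in a block is multiplied by the block's shared scale
# 2**(shared_exp_e8m0 - 127), where shared_exp_e8m0 is the biased E8M0 code.
E8M0_BIAS = 127

def dequantize_block(elements, shared_exp_e8m0):
    """Scale already-decoded element values by the block's shared exponent."""
    scale = 2.0 ** (shared_exp_e8m0 - E8M0_BIAS)
    return [e * scale for e in elements]
```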
_
Fake dequantize_mx implementation.
This adds a “FakeTensor kernel” (aka “meta kernel”) to the operator. Given some FakeTensor inputs (dummy tensors that don't have storage), this function returns dummy tensors with the correct tensor metadata (shape/strides/dtype/device). This is used by torch.compile to infer the shape and other metadata of the output tensors.
Reference: https://pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial
ToMXConstrFunc Objects
Differentiable cast to MX, no-op in backward
forward
Forward method for the custom autograd function.
Arguments:
- ctx - The context object that can be used to stash information for backward computation.
- data_hp (torch.Tensor) - The high-precision input tensor to be quantized.
- elem_dtype (dtypes.DType) - The target data type for quantization.
- block_size (int) - The block size used for quantization.
- padding (int) - The padding size applied during quantization.
Returns:
MXTensor - A custom tensor object containing the quantized data, scale factor, and metadata about the quantization process.
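Since the forward pass takes a padding argument, the amount of padding needed for a block dimension that is not a multiple of the block size can be derived as below. This helper is hypothetical and not part of torchmx; it only illustrates the arithmetic:

```python
# Pad the block dimension up to the next multiple of block_size, so that the
# tensor can be split into complete MX blocks.
def required_padding(dim_size: int, block_size: int) -> int:
    """Number of elements to append so dim_size becomes a multiple of block_size."""
    return (-dim_size) % block_size
```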
FromMXConstrFunc Objects
Differentiable cast from MX, no-op in backward
forward
Forward method for dequantizing a low-precision tensor to a target data type.
Arguments:
- ctx - The context object (not used in this implementation).
- tensor_lp (torch.Tensor) - The low-precision tensor to be dequantized. This tensor is expected to have the following attributes:
  - _data: The raw data of the tensor.
  - _scale_e8m0: The scale factor for dequantization.
  - _elem_dtype.name: The name of the element data type.
  - _block_size: The block size used in the tensor.
  - _block_dim: The block dimension of the tensor.
  - _padding: The amount of padding applied to the tensor.
- target_dtype (torch.dtype) - The target data type to which the tensor should be dequantized.
Returns:
torch.Tensor - The dequantized tensor, reshaped to its original shape if padding was involved.
NoopFwToMXBw Objects
Forward: no-op. Backward: cast the gradient to MX.
MXTensor Objects
__new__
Create a new instance of the tensor subclass.
Arguments:
- cls - The class being instantiated.
- scale_e8m0_bits (torch.Tensor) - A tensor containing scale factors with dtype torch.uint8.
- data_bits (torch.Tensor) - A tensor containing data bits.
- elem_dtype (dtypes.DType) - The element data type.
- block_size (int) - The block size.
- orig_dtype (torch.dtype) - The original data type.
- block_dim (int) - The block dimension. Default is None; if not set, it defaults to the last dimension.
- padding (int) - Padding size used when the block dimension's size is not a multiple of block_size. Default is 0.
Returns:
An instance of the tensor subclass.
Raises:
- AssertionError - If the dtype of scale_e8m0_bits is not torch.uint8.
- AssertionError - If the shape of scale_e8m0_bits is not 1-dimensional.
- AssertionError - If the dtype of data_bits is not one of the supported types.
- AssertionError - If elem_dtype is unsupported.
to_dtype
Dequantize the MXTensor to the target_dtype.
Arguments:
- target_dtype (torch.dtype) - The target data type (torch.bfloat16, torch.float32, etc.) to which the MXTensor is dequantized.
Returns:
The dequantized tensor in the target_dtype.
Notes:
MXTensor quantization is supported only for torch.bfloat16, but dequantization is allowed to either torch.bfloat16 or torch.float32.
See the quantize_mx and dequantize_mx functions for more details.
to_mx
Convert/Quantize a high-precision tensor to MXTensor.
Arguments:
- data_hp (torch.Tensor) - The high-precision input tensor. Only torch.bfloat16 is supported; see the quantize_mx function for more details.
- elem_dtype (dtypes.DType) - The target element data type for quantization.
- block_size (int) - The block size. Default is 32.
Returns:
MXTensor- The quantized tensor in the target lower precision format.