torchmx.mx_tensor
Defines the tensor subclasses to represent the OCP MX-Format spec
Exponent E8M0 encoding details (OCP spec section 5.4.1):
- bias: 127
- supported exponent range: -127 to 127
- infinities: N/A
- NaN: 11111111
- Zeros: N/A
quantize_mx
Takes a high-precision tensor and converts it to MX scale and raw data in naive layout (scale and raw data as separate tensors). For now, the function only supports quantization along the last dimension of the input tensor. For example, if the input tensor has shape (N, C, H, W), the output will be:
- data_lp (torch.uint8) with shape (N, C, H, W)
- scale (torch.uint8) with shape (N, C, H, W // block_size)
Arguments:
data_hp
torch.Tensor - high precision data tensor (dtype=torch.bfloat16)
elem_dtype_name
str - target element dtype as a string to comply with torch.library.infer_schema
block_size
int - block size
Returns:
Tuple[torch.Tensor, torch.Tensor]: scale (biased E8M0), low precision data as tensors
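For intuition, here is a minimal pure-PyTorch sketch of the block-scaling shape contract described above. It is not the library's kernel: it only derives a per-block shared exponent (E8M0-style, bias 127) and emits placeholder element bytes, so the helper name and body are illustrative only.

```python
import torch

def quantize_mx_sketch(data_hp: torch.Tensor, block_size: int = 32):
    """Illustrative only: shows the shape contract, not the real encoding."""
    assert data_hp.dtype == torch.bfloat16
    assert data_hp.shape[-1] % block_size == 0
    blocks = data_hp.float().reshape(*data_hp.shape[:-1], -1, block_size)
    # Shared scale per block: a power of two covering the block's max magnitude.
    shared_exp = torch.floor(
        torch.log2(blocks.abs().amax(dim=-1).clamp(min=2**-127)))
    scale_e8m0 = (shared_exp + 127).to(torch.uint8)  # biased exponent, per spec
    # Placeholder: the real op packs FP8/FP6/FP4 element bits here.
    data_lp = torch.zeros_like(data_hp, dtype=torch.uint8)
    return scale_e8m0, data_lp

x = torch.randn(2, 3, 4, 64, dtype=torch.bfloat16)
scale, data_lp = quantize_mx_sketch(x, block_size=32)
print(scale.shape)    # (2, 3, 4, 2)  == (N, C, H, W // block_size)
print(data_lp.shape)  # (2, 3, 4, 64) == (N, C, H, W)
```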
_
Fake quantize_mx implementation.
This adds a “FakeTensor kernel” (aka “meta kernel”) to the operator. Given FakeTensor inputs (dummy Tensors that don't have storage), this function returns dummy Tensors with the correct Tensor metadata (shape/strides/dtype/device). This is used by torch.compile to infer the shape and other metadata of the output tensors.
Reference: https://pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial
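Following the referenced tutorial, a fake kernel for an op with this signature might look like the sketch below. The namespace "mylib::quantize_mx" and the placeholder eager body are assumptions for illustration, not the library's actual registration.

```python
from typing import Tuple
import torch

@torch.library.custom_op("mylib::quantize_mx", mutates_args=())
def quantize_mx(data_hp: torch.Tensor, elem_dtype_name: str,
                block_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Placeholder eager body that only honors the shape contract;
    # the real op performs the actual MX encoding.
    scale = data_hp.new_zeros(
        (*data_hp.shape[:-1], data_hp.shape[-1] // block_size),
        dtype=torch.uint8)
    return scale, data_hp.new_zeros(data_hp.shape, dtype=torch.uint8)

@quantize_mx.register_fake
def _(data_hp, elem_dtype_name, block_size):
    # No computation: return empty tensors carrying only the output
    # metadata (shape/dtype/device) so torch.compile can trace through.
    scale_shape = (*data_hp.shape[:-1], data_hp.shape[-1] // block_size)
    return (data_hp.new_empty(scale_shape, dtype=torch.uint8),
            data_hp.new_empty(data_hp.shape, dtype=torch.uint8))
```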
dequantize_mx
Takes the low-precision data and scale of an MXTensor and converts them to high precision
Arguments:
data_lp
torch.Tensor - low precision data tensor
shared_exp_e8m0
torch.Tensor - biased exponent of the shared MX scale as torch.uint8
elem_dtype_name
str - target element dtype's name as a string to comply with torch.library.infer_schema
block_size
int - block size
target_dtype
torch.dtype - target dtype
block_dim
int - block dimension
Returns:
torch.Tensor
- high precision data tensor in target_dtype converted from MX
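The scale application itself is simple to illustrate. The sketch below assumes the low-precision elements have already been decoded to float (the real op also decodes the FP8/FP6/FP4 element bits) and a trailing block dimension; it is not the library's implementation.

```python
import torch

def apply_e8m0_scale(decoded: torch.Tensor, shared_exp_e8m0: torch.Tensor,
                     block_size: int, target_dtype: torch.dtype) -> torch.Tensor:
    # Unbias the shared exponent (bias 127) into a power-of-two scale.
    scale = torch.exp2(shared_exp_e8m0.to(torch.float32) - 127)
    blocks = decoded.float().reshape(*decoded.shape[:-1], -1, block_size)
    out = blocks * scale.unsqueeze(-1)  # broadcast one scale per block
    return out.reshape(decoded.shape).to(target_dtype)
```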
_
Fake dequantize_mx implementation.
This adds a “FakeTensor kernel” (aka “meta kernel”) to the operator. Given FakeTensor inputs (dummy Tensors that don't have storage), this function returns dummy Tensors with the correct Tensor metadata (shape/strides/dtype/device). This is used by torch.compile to infer the shape and other metadata of the output tensors.
Reference: https://pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial
ToMXConstrFunc Objects
Differentiable cast to MX, no-op in backward
forward
Forward method for the custom autograd function.
Arguments:
ctx
- The context object that can be used to stash information for backward computation.
data_hp
torch.Tensor - The high-precision input tensor to be quantized.
elem_dtype
dtypes.DType - The target data type for quantization.
block_size
int - The block size used for quantization.
padding
int - The padding size applied during quantization.
Returns:
MXTensor
- A custom tensor object containing the quantized data, scale factor, and metadata about the quantization process.
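The pattern this class implements is the standard straight-through autograd function. A generic sketch, with quantize_fn standing in for the MX cast (the real class returns an MXTensor):

```python
import torch

class ToMXSketch(torch.autograd.Function):
    """Cast to MX in forward, pass gradients through untouched in backward."""

    @staticmethod
    def forward(ctx, data_hp, quantize_fn):
        return quantize_fn(data_hp)

    @staticmethod
    def backward(ctx, grad_output):
        # No-op backward: one gradient per input, None for the non-tensor arg.
        return grad_output, None
```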
FromMXConstrFunc Objects
Differentiable cast from MX, no-op in backward
forward
Forward method for dequantizing a low-precision tensor to a target data type.
Arguments:
ctx
- The context object (not used in this implementation).
tensor_lp
torch.Tensor - The low-precision tensor to be dequantized. This tensor is expected to have the following attributes:
- _data: The raw data of the tensor.
- _scale_e8m0: The scale factor for dequantization.
- _elem_dtype.name: The name of the element data type.
- _block_size: The block size used in the tensor.
- _block_dim: The block dimensions of the tensor.
- _padding: The amount of padding applied to the tensor.
target_dtype
torch.dtype - The target data type to which the tensor should be dequantized.
Returns:
torch.Tensor
- The dequantized tensor, reshaped to its original shape if padding was involved.
NoopFwToMXBw Objects
Forward: no-op. Backward: cast grad to MX.
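This is the mirror image of ToMXConstrFunc: identity in forward, quantize the gradient on the way back. A generic sketch, with grad_to_mx standing in for the MX cast of the gradient:

```python
import torch

class NoopFwToMXBwSketch(torch.autograd.Function):
    """Identity in forward; quantize the incoming gradient in backward."""

    @staticmethod
    def forward(ctx, x, grad_to_mx):
        ctx.grad_to_mx = grad_to_mx
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return ctx.grad_to_mx(grad_output), None
```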
MXTensor Objects
__new__
Create a new instance of the tensor subclass.
Arguments:
cls
- The class being instantiated.
scale_e8m0_bits
torch.Tensor - A tensor containing scale factors with dtype torch.uint8.
data_bits
torch.Tensor - A tensor containing data bits.
elem_dtype
dtypes.DType - The element data type.
block_size
int - The block size.
orig_dtype
torch.dtype - The original data type.
block_dim
int - The block dimension. Default is None; if not set, it defaults to the last dimension.
padding
int - Padding size in case block_dim is not a multiple of block_size. Default is 0.
Returns:
An instance of the tensor subclass.
Raises:
AssertionError
- If the dtype of scale_e8m0_bits is not torch.uint8.
AssertionError
- If the shape of scale_e8m0_bits is not 1-dimensional.
AssertionError
- If the dtype of data_bits is not one of the supported types.
AssertionError
- If elem_dtype is unsupported.
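Tensor subclasses like this usually build on torch.Tensor._make_wrapper_subclass, so the wrapped uint8 payloads can advertise the original shape and dtype to the rest of PyTorch. A minimal sketch of that pattern (simplified arguments, not the actual implementation):

```python
import torch

class MXSketchTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, scale_e8m0_bits, data_bits, orig_dtype):
        # Mirrors one of the documented asserts: the scale is biased uint8 bits.
        assert scale_e8m0_bits.dtype == torch.uint8
        self = torch.Tensor._make_wrapper_subclass(
            cls, data_bits.shape, dtype=orig_dtype, device=data_bits.device)
        self._scale_e8m0 = scale_e8m0_bits  # attribute names follow the docs above
        self._data = data_bits
        return self
```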
to_dtype
Dequantize the MXTensor to the target_dtype.
Arguments:
target_dtype
torch.dtype - The target data type (torch.bfloat16, torch.float32, etc.) to which the MXTensor is dequantized.
Returns:
The dequantized tensor in the target_dtype.
Notes:
The MXTensor quantization is supported only for torch.bfloat16, but dequantization is allowed to either torch.bfloat16 or torch.float32. See the quantize_mx and dequantize_mx functions for more details.
to_mx
Convert/Quantize a high-precision tensor to MXTensor.
Arguments:
data_hp
torch.Tensor - The high-precision input tensor. Only torch.bfloat16 is supported. Look at the quantize_mx function for more details.
elem_dtype
dtypes.DType - The target element data type for quantization.
block_size
int - The block size. Default is 32.
Returns:
MXTensor
- The quantized tensor in the target lower precision format.
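Putting to_mx and to_dtype together, a round trip might look like the following. The import paths and the dtypes member name (float8_e4m3 here) are assumptions for illustration; check the dtypes module for the actual names.

```python
import torch
from torchmx import dtypes            # path assumed for illustration
from torchmx.mx_tensor import MXTensor

x = torch.randn(128, 256, dtype=torch.bfloat16)  # bfloat16 only, per the docs
mx_t = MXTensor.to_mx(x, elem_dtype=dtypes.float8_e4m3, block_size=32)
x_hat = mx_t.to_dtype(torch.float32)  # dequantize to bfloat16 or float32
print((x.float() - x_hat).abs().max())  # inspect the quantization error
```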