torchmx.mx_tensor

Defines the tensor subclasses to represent the OCP MX-Format spec

Exponent E8M0 encoding details (OCP spec section 5.4.1):

  • bias: 127
  • supported exponent range: -127 to 127
  • infinities: N/A
  • NaN: 11111111
  • Zeros: N/A
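
The encoding above can be illustrated with a small sketch. This is not torchmx code: the helper names `encode_e8m0` and `decode_e8m0` are hypothetical, and the sketch only assumes the bias, range, and NaN pattern listed above.

```python
import math

E8M0_BIAS = 127
E8M0_NAN = 0b11111111  # the only special pattern; no infinities, no zeros


def encode_e8m0(exponent: int) -> int:
    """Return the biased 8-bit pattern for an unbiased power-of-two exponent."""
    if not -127 <= exponent <= 127:
        raise ValueError("E8M0 exponent must be in [-127, 127]")
    return exponent + E8M0_BIAS


def decode_e8m0(bits: int) -> float:
    """Return the scale 2**(bits - 127), or NaN for the reserved pattern."""
    if not 0 <= bits <= 0xFF:
        raise ValueError("E8M0 pattern must fit in 8 bits")
    if bits == E8M0_NAN:
        return math.nan
    return 2.0 ** (bits - E8M0_BIAS)
```

For instance, an exponent of 0 is stored as 127, and the stored value 128 decodes to a scale of 2.0.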

quantize_mx

@torch.library.custom_op("torchmx::quantize_mx", mutates_args=())
def quantize_mx(data_hp: torch.Tensor, elem_dtype_name: str,
                block_size: int) -> Tuple[torch.Tensor, torch.Tensor]

Takes a high-precision tensor and converts it to an MX scale and raw data in a naive layout (the scale and the raw data are separate tensors). The function currently supports quantization only along the last dimension of the input tensor. For example, if the input tensor has shape (N, C, H, W), the outputs are:

  • data_lp (torch.uint8) with shape (N, C, H, W)
  • scale (torch.uint8) with shape (N, C, H, W // block_size)
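
To make the "one scale per block along the last dimension" idea concrete, here is a minimal sketch for a single row of values. It is not the torchmx implementation: `block_scales_e8m0` is a hypothetical helper, it derives the shared exponent from the largest magnitude in each block (one common choice for MX formats), and `elem_emax=8` assumes an FP8 E4M3 element type. Rounding and edge-case handling in torchmx may differ.

```python
import math

E8M0_BIAS = 127


def block_scales_e8m0(row, block_size=32, elem_emax=8):
    """Return one biased E8M0 scale per block of `row` (a flat list of floats).

    elem_emax is the largest exponent of the element format; 8 (FP8 E4M3)
    is an assumption of this sketch.
    """
    if len(row) % block_size != 0:
        raise ValueError("row length must be a multiple of block_size")
    scales = []
    for start in range(0, len(row), block_size):
        amax = max(abs(v) for v in row[start:start + block_size])
        if amax == 0.0:
            shared_exp = -E8M0_BIAS  # all-zero block: use the smallest scale
        else:
            # frexp gives amax = m * 2**e with 0.5 <= m < 1,
            # so floor(log2(amax)) == e - 1.
            shared_exp = math.frexp(amax)[1] - 1 - elem_emax
        shared_exp = max(-127, min(127, shared_exp))  # clamp to the E8M0 range
        scales.append(shared_exp + E8M0_BIAS)
    return scales
```

A row of 64 values with block_size=32 yields two scales, matching the W // block_size shape described above.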

Arguments:

  • data_hp torch.Tensor - high precision data tensor (dtype=torch.bfloat16)
  • elem_dtype_name str - target element dtype as a string to comply with torch.library.infer_schema
  • block_size int - block size

Returns:

  • Tuple[torch.Tensor, torch.Tensor] - The biased scale and the low-precision data, as separate tensors.

_

@quantize_mx.register_fake
def _(data_hp: torch.Tensor, elem_dtype_name: str,
      block_size: int) -> Tuple[torch.Tensor, torch.Tensor]

Fake quantize_mx implementation.

This registers a “FakeTensor kernel” (also called a “meta kernel”) for the operator. Given FakeTensor inputs (dummy tensors that have no storage), the function returns dummy tensors with the correct metadata (shape, strides, dtype, device). torch.compile uses it to infer the shape and other metadata of the output tensors.

Reference: https://pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial
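
The key property of a fake kernel is that it computes output metadata without touching any tensor values. The framework-free sketch below illustrates that idea using the shapes documented for quantize_mx. `TensorMeta` and `fake_quantize_mx_meta` are hypothetical names, dtypes are plain strings, and the real registered function works on FakeTensors rather than on this stand-in.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TensorMeta:
    """Hypothetical stand-in for the metadata a FakeTensor carries."""
    shape: Tuple[int, ...]
    dtype: str


def fake_quantize_mx_meta(data_hp: TensorMeta,
                          block_size: int) -> Tuple[TensorMeta, TensorMeta]:
    """Infer (scale, data_lp) metadata without reading any tensor values."""
    *leading, last = data_hp.shape
    if last % block_size != 0:
        raise ValueError("last dimension must be a multiple of block_size")
    # One uint8 scale per block along the last dimension.
    scale = TensorMeta((*leading, last // block_size), "uint8")
    # Raw data keeps the input shape, stored as uint8.
    data_lp = TensorMeta(data_hp.shape, "uint8")
    return scale, data_lp
```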

dequantize_mx

@torch.library.custom_op("torchmx::dequantize_mx", mutates_args=())
def dequantize_mx(data_lp: torch.Tensor, shared_exp_e8m0: torch.Tensor,
                  elem_dtype_name: str, block_size: int,
                  target_dtype: torch.dtype, block_dim: int) -> torch.Tensor

Takes the low-precision data and scale of an MXTensor and converts them to a high-precision tensor.

Arguments:

  • data_lp torch.Tensor - low precision data tensor
  • shared_exp_e8m0 torch.Tensor - biased exponent of the shared MX scale as torch.uint8
  • elem_dtype_name str - target element dtype's name as a string to comply with torch.library.infer_schema
  • block_size int - block size
  • target_dtype torch.dtype - target dtype
  • block_dim int - block dimension

Returns:

  • torch.Tensor - high precision data tensor in target_dtype converted from MX
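
Conceptually, dequantization multiplies every element of a block by the power-of-two scale that the block shares. The sketch below shows that arithmetic on a flat list. It is an illustration only: `dequantize_blocks` is a hypothetical helper, and it assumes the element values have already been decoded from their low-precision bit patterns into Python floats.

```python
E8M0_BIAS = 127
E8M0_NAN = 0xFF


def dequantize_blocks(elements, scales_e8m0, block_size):
    """Multiply each block of decoded elements by 2**(scale - 127)."""
    if len(elements) != len(scales_e8m0) * block_size:
        raise ValueError("expected exactly one scale per block of elements")
    out = []
    for i, value in enumerate(elements):
        bits = scales_e8m0[i // block_size]  # scale shared by this block
        if bits == E8M0_NAN:
            scale = float("nan")  # the reserved NaN pattern poisons the block
        else:
            scale = 2.0 ** (bits - E8M0_BIAS)
        out.append(value * scale)
    return out
```

With block_size=2, the elements [1.0, 2.0, 3.0, 4.0] and biased scales [128, 126] dequantize to [2.0, 4.0, 1.5, 2.0].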

_

@dequantize_mx.register_fake
def _(data_lp: torch.Tensor, shared_exp_e8m0: torch.Tensor,
      elem_dtype_name: str, block_size: int, target_dtype: torch.dtype,
      block_dim: int) -> torch.Tensor

Fake dequantize_mx implementation.

This registers a “FakeTensor kernel” (also called a “meta kernel”) for the operator. Given FakeTensor inputs (dummy tensors that have no storage), the function returns dummy tensors with the correct metadata (shape, strides, dtype, device). torch.compile uses it to infer the shape and other metadata of the output tensors.

Reference: https://pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial

ToMXConstrFunc Objects

@torch._dynamo.allow_in_graph
class ToMXConstrFunc(torch.autograd.Function)

Differentiable cast to MX, no-op in backward

forward

@staticmethod
def forward(ctx, data_hp: torch.Tensor, elem_dtype: dtypes.DType,
            block_size: int)

Forward method for the custom autograd function.

Arguments:

  • ctx - The context object that can be used to stash information for backward computation.
  • data_hp torch.Tensor - The high-precision input tensor to be quantized.
  • elem_dtype dtypes.DType - The target data type for quantization.
  • block_size int - The block size used for quantization.
  • padding int - The padding size applied during quantization.

Returns:

  • MXTensor - A custom tensor object containing the quantized data, scale factor, and metadata about the quantization process.

FromMXConstrFunc Objects

@torch._dynamo.allow_in_graph
class FromMXConstrFunc(torch.autograd.Function)

Differentiable cast from MX, no-op in backward

forward

@staticmethod
def forward(ctx, tensor_lp: torch.Tensor,
            target_dtype: torch.dtype) -> torch.Tensor

Forward method for dequantizing a low-precision tensor to a target data type.

Arguments:

  • ctx - The context object (not used in this implementation).
  • tensor_lp torch.Tensor - The low-precision tensor to be dequantized. This tensor is expected to have the following attributes:
      • _data: The raw data of the tensor.
      • _scale_e8m0: The scale factor for dequantization.
      • _elem_dtype.name: The name of the element data type.
      • _block_size: The block size used in the tensor.
      • _block_dim: The block dimensions of the tensor.
      • _padding: The amount of padding applied to the tensor.
  • target_dtype torch.dtype - The target data type to which the tensor should be dequantized.

Returns:

  • torch.Tensor - The dequantized tensor, reshaped to its original shape if padding was involved.

NoopFwToMXBw Objects

@torch._dynamo.allow_in_graph
class NoopFwToMXBw(torch.autograd.Function)

Forward: no-op. Backward: cast grad to MX.
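
The autograd functions in this module pair a cast in one direction with a pass-through in the other. The framework-free sketch below contrasts the two patterns. It does not use torch.autograd.Function (there is no ctx argument), the class names are hypothetical, and `round_to_half` is only a stand-in for a lossy low-precision cast.

```python
class CastFwNoopBw:
    """Lossy cast in forward; the gradient passes through unchanged."""

    @staticmethod
    def forward(x, cast):
        return [cast(v) for v in x]

    @staticmethod
    def backward(grad_out):
        return list(grad_out)  # no-op: gradient is returned as received


class NoopFwCastBw:
    """Identity in forward; the incoming gradient is cast in backward."""

    @staticmethod
    def forward(x):
        return list(x)

    @staticmethod
    def backward(grad_out, cast):
        return [cast(g) for g in grad_out]


def round_to_half(v):
    """Hypothetical stand-in for a low-precision cast: round to multiples of 0.5."""
    return round(v * 2) / 2
```

The first pattern lets gradients flow as if the cast were the identity, while the second leaves activations untouched and reduces the precision of the gradients instead.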

MXTensor Objects

class MXTensor(TorchAOBaseTensor)

__new__

def __new__(cls,
            scale_e8m0_bits: torch.Tensor,
            data_bits: torch.Tensor,
            elem_dtype: dtypes.DType,
            block_size: int,
            orig_dtype: torch.dtype,
            padding: int = 0,
            block_dim: Optional[int] = None)

Create a new instance of the tensor subclass.

Arguments:

  • cls - The class being instantiated.
  • scale_e8m0_bits torch.Tensor - A tensor containing scale factors with dtype torch.uint8.
  • data_bits torch.Tensor - A tensor containing data bits.
  • elem_dtype dtypes.DType - The element data type.
  • block_size int - The block size.
  • orig_dtype torch.dtype - The original data type.
  • padding int - The padding size, used when the size of block_dim is not a multiple of block_size. Default is 0.
  • block_dim int - The block dimension. Default is None, in which case the last dimension is used.

Returns:

An instance of the tensor subclass.

Raises:

  • AssertionError - If the dtype of scale_e8m0_bits is not torch.uint8.
  • AssertionError - If the shape of scale_e8m0_bits is not 1-dimensional.
  • AssertionError - If the dtype of data_bits is not one of the supported types.
  • AssertionError - If elem_dtype is unsupported.
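
The padding argument covers the case where the blocked dimension does not divide evenly into blocks. A minimal sketch of how such a padding amount could be computed follows. `padding_for` is a hypothetical helper, and how torchmx itself derives the value is an assumption here.

```python
def padding_for(dim_size: int, block_size: int) -> int:
    """Return how many elements must be appended so that dim_size
    becomes a multiple of block_size."""
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    # Python's % is non-negative for a positive modulus, so this
    # gives 0 when dim_size is already a multiple of block_size.
    return (-dim_size) % block_size
```

For example, a dimension of size 100 with block_size=32 needs 28 padded elements to reach 128, while a dimension of size 64 needs none.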

to_dtype

def to_dtype(target_dtype: torch.dtype) -> torch.Tensor

Dequantize the MXTensor to the target_dtype.

Arguments:

  • target_dtype torch.dtype - The target data type (torch.bfloat16, torch.float32, etc.) to which the MXTensor is dequantized.

Returns:

The dequantized tensor in the target_dtype.

Notes:

Quantization to MXTensor is supported only from torch.bfloat16, but dequantization may target either torch.bfloat16 or torch.float32. See the quantize_mx and dequantize_mx functions for more details.

to_mx

@staticmethod
@torch._dynamo.allow_in_graph
def to_mx(data_hp: torch.Tensor,
          elem_dtype: dtypes.DType,
          block_size: int = 32) -> "MXTensor"

Convert/Quantize a high-precision tensor to MXTensor.

Arguments:

  • data_hp torch.Tensor - The high-precision input tensor. Only torch.bfloat16 is supported. Look at the quantize_mx function for more details.
  • elem_dtype dtypes.DType - The target element data type for quantization.
  • block_size int - The block size. Default is 32.

Returns:

  • MXTensor - The quantized tensor in the target lower precision format.