torchmx.mx_tensor

Defines the tensor subclasses to represent the OCP MX-Format spec

Exponent E8M0 encoding details (OCP spec section 5.4.1):

  • bias: 127
  • supported exponent range: -127 to 127
  • infinities: N/A
  • NaN: 11111111
  • Zeros: N/A
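
The encoding above can be illustrated with a small, library-independent sketch. The helper names `e8m0_decode` and `e8m0_encode` are hypothetical and are not part of torchmx; the sketch only assumes what the list states: a bias of 127, an exponent range of -127 to 127, and the bit pattern 11111111 reserved for NaN.

```python
import math

E8M0_BIAS = 127          # bias from the encoding table
E8M0_NAN = 0b11111111    # the only special value; no infinities, no zero


def e8m0_decode(byte: int) -> float:
    """Decode an 8-bit E8M0 pattern to the scale 2**(byte - 127)."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("E8M0 values are 8 bits wide")
    if byte == E8M0_NAN:
        return math.nan
    # ldexp(1.0, e) computes 1.0 * 2**e exactly.
    return math.ldexp(1.0, byte - E8M0_BIAS)


def e8m0_encode(exponent: int) -> int:
    """Encode an unbiased exponent in [-127, 127] as an E8M0 pattern."""
    if not -127 <= exponent <= 127:
        raise ValueError("exponent outside the supported range -127..127")
    return exponent + E8M0_BIAS
```

Because there is no mantissa and no sign bit, every non-NaN pattern decodes to a positive power of two, which is why the format has no representation for zero or infinity.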

MXTensor Objects

class MXTensor(TorchAOBaseTensor)

to_dtype

def to_dtype(target_dtype: torch.dtype) -> torch.Tensor

Dequantize the MXTensor to the target_dtype.

Arguments:

  • target_dtype torch.dtype - The data type to dequantize to, either torch.bfloat16 or torch.float32.

Returns:

The dequantized tensor in the target_dtype.

Notes:

Quantization to MXTensor is supported only for torch.bfloat16 inputs, but dequantization may target either torch.bfloat16 or torch.float32. See the quantize_mx and de_quantize_mx functions for details.
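
Conceptually, dequantizing an MX block multiplies each stored element by the power-of-two scale shared by its block. The following is a minimal pure-Python sketch of that idea only. The function `dequantize_blocks` is hypothetical, it operates on already-decoded element values and unbiased shared exponents, and it is not a description of how de_quantize_mx is implemented.

```python
import math


def dequantize_blocks(elements, shared_exps, block_size=32):
    """Scale each element by 2**shared_exp of the block it belongs to.

    elements    -- flat list of decoded low-precision element values
    shared_exps -- one unbiased shared exponent per block (assumption:
                   already decoded from E8M0, i.e. bias removed)
    """
    if len(elements) != len(shared_exps) * block_size:
        raise ValueError("expected exactly block_size elements per exponent")
    out = []
    for i, x in enumerate(elements):
        # Element i belongs to block i // block_size.
        out.append(math.ldexp(x, shared_exps[i // block_size]))
    return out
```

For instance, with `block_size=2`, the elements `[1.5, -2.0, 0.5, 3.0]` and shared exponents `[1, -1]` scale the first block by 2 and the second block by 1/2.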

to_mx

@staticmethod
@torch._dynamo.allow_in_graph
def to_mx(data_hp: torch.Tensor,
          elem_dtype: dtypes.DType,
          block_size: int = 32) -> "MXTensor"

Quantize a high-precision tensor to an MXTensor.

Arguments:

  • data_hp torch.Tensor - The high-precision input tensor. Only torch.bfloat16 is supported. See the quantize_mx function for details.
  • elem_dtype dtypes.DType - The target element data type for quantization.
  • block_size int - The number of elements that share one scale. Defaults to 32.

Returns:

  • MXTensor - The quantized tensor in the target lower-precision format.
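
To make the role of block_size concrete, here is a pure-Python sketch of how one shared exponent per block can be derived, following the conversion suggested in the OCP MX spec: the exponent of the largest magnitude in the block minus the largest exponent of the element format. The function `shared_exponents` and its `elem_emax` parameter are hypothetical names, and the handling of all-zero blocks is an assumption; whether quantize_mx follows exactly this procedure is not stated here.

```python
import math


def shared_exponents(values, block_size=32, elem_emax=8):
    """Return one unbiased shared exponent per block of `block_size` values.

    elem_emax is the largest exponent of the element format
    (8 for FP8 E4M3, whose largest normal value is 1.75 * 2**8 = 448).
    """
    if len(values) % block_size != 0:
        raise ValueError("len(values) must be a multiple of block_size")
    exps = []
    for start in range(0, len(values), block_size):
        amax = max(abs(v) for v in values[start:start + block_size])
        if amax == 0.0:
            # Assumption: use the smallest exponent for an all-zero block.
            exps.append(-127)
            continue
        # frexp gives amax = m * 2**e with 0.5 <= m < 1,
        # so floor(log2(amax)) == e - 1, without rounding error.
        floor_log2 = math.frexp(amax)[1] - 1
        # Clamp to the E8M0 range -127..127.
        exps.append(max(-127, min(127, floor_log2 - elem_emax)))
    return exps
```

With the default block size, a tensor of 64 values yields two shared exponents, one per group of 32 consecutive values.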