
torchmx.custom_float_cast

hp_to_floatx


def hp_to_floatx(hp_data: torch.Tensor,
                 exponent_bits: int,
                 mantissa_bits: int,
                 max_normal: float,
                 round_mode="round_to_even",
                 keep_subnormals: bool = True)

Converts high-precision floating-point data to a custom floating-point format, as specified by the number of exponent and mantissa bits.

Notes:

  • This function does not take into account whether the target format supports NaNs or Infs; any NaNs and Infs in the input tensor are returned unchanged.
  • It implements OCP's 'saturating mode': values outside the representable range are clamped to max_normal.

Arguments:

  • hp_data torch.Tensor - Input tensor with high-precision floating-point data (float32 or float64).
  • exponent_bits int - Number of bits for the exponent in the target format.
  • mantissa_bits int - Number of bits for the mantissa in the target format.
  • max_normal float - Maximum representable normal value in the target format.
  • round_mode str, optional - Rounding mode to use. Options are "truncate" and "round_to_even". Default is "round_to_even".
  • keep_subnormals bool, optional - Whether to keep subnormal values. Default is True.

Returns:

  • torch.Tensor - Tensor with data converted to the custom floating-point format.
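To make the rounding and saturation behaviour concrete, here is a minimal scalar sketch of what the cast does conceptually. This is an illustrative reimplementation in plain Python, not the library's actual code: the helper name `quantize_value` and its internals are assumptions, and it covers only the default "round_to_even" mode with saturation, NaN/Inf pass-through, and subnormal handling as described in the notes above.

```python
import math

def quantize_value(x: float, exponent_bits: int, mantissa_bits: int,
                   max_normal: float) -> float:
    """Illustrative scalar sketch (assumption: mirrors hp_to_floatx's
    round-to-even + saturating behaviour on a single value)."""
    if math.isnan(x) or math.isinf(x):
        return x  # NaNs and Infs pass through unchanged
    if x == 0.0:
        return 0.0
    sign = math.copysign(1.0, x)
    mag = abs(x)
    # OCP saturating mode: clamp out-of-range magnitudes to max_normal.
    if mag > max_normal:
        return sign * max_normal
    bias = 2 ** (exponent_bits - 1) - 1
    exp = math.floor(math.log2(mag))
    # Below the minimum normal exponent, values become subnormal:
    # they share the minimum exponent's quantization step.
    exp = max(exp, 1 - bias)
    step = 2.0 ** (exp - mantissa_bits)
    # Round to the nearest mantissa step; Python's round() ties to even.
    q = round(mag / step) * step
    return sign * min(q, max_normal)
```

For example, with an E4M3-like format (`exponent_bits=4`, `mantissa_bits=3`, `max_normal=448.0`), `quantize_value(1000.0, 4, 3, 448.0)` saturates to `448.0`, and `quantize_value(1.0625, 4, 3, 448.0)` lands exactly halfway between the representable values `1.0` and `1.125` and so rounds to the even mantissa, giving `1.0`. The real `hp_to_floatx` applies the same logic elementwise to a whole tensor.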