torchmx.custom_float_cast
hp_to_floatx
Converts high-precision floating-point data to a custom floating-point format, as specified by the number of exponent and mantissa bits.
Notes:
- This function does not check whether the target format supports NaNs or infinities. NaN and Inf values in the input tensor are returned unchanged.
- It implements OCP's 'saturating mode': finite values whose magnitude exceeds max_normal are clamped to max_normal, keeping their sign.
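The two notes can be illustrated with a minimal scalar sketch in plain Python rather than on tensors. The helper name `saturate` is hypothetical and not part of torchmx. It only mirrors the documented behavior: NaN and Inf pass through untouched, and finite values are clamped to ±max_normal.

```python
import math


def saturate(x: float, max_normal: float) -> float:
    """Clamp a finite value to [-max_normal, max_normal]; pass NaN/Inf through."""
    # NaN and +/-inf are returned unchanged, regardless of whether the
    # target format could actually encode them.
    if math.isnan(x) or math.isinf(x):
        return x
    # Saturating mode: overflow clamps to the largest normal magnitude,
    # preserving the sign of the input.
    return max(-max_normal, min(x, max_normal))
```

For example, with max_normal = 448.0 (the largest normal value of OCP FP8 E4M3), `[saturate(v, 448.0) for v in (3.0, 1000.0, -1000.0)]` gives `[3.0, 448.0, -448.0]`.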
Arguments:
- hp_data (torch.Tensor) - Input tensor with high-precision floating-point data (float32 or float64).
- exponent_bits (int) - Number of bits for the exponent in the target format.
- mantissa_bits (int) - Number of bits for the mantissa in the target format.
- max_normal (float) - Maximum representable normal value in the target format.
- round_mode (str, optional) - Rounding mode to use. Options are "truncate" and "round_to_even". Default is "round_to_even".
- keep_subnormals (bool, optional) - Whether to keep subnormal values. Default is True.
Returns:
- torch.Tensor - Tensor with data converted to the custom floating-point format.
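To show how exponent_bits, mantissa_bits, round_mode and keep_subnormals interact, here is a minimal scalar sketch of rounding a finite value onto an ExMy grid. It is not torchmx's implementation. The function name `round_to_grid` is hypothetical. It assumes an IEEE-style exponent bias of 2^(exponent_bits-1) - 1, and it assumes that keep_subnormals=False means magnitudes below the smallest normal value are flushed to zero. NaN/Inf passthrough and saturation to max_normal are deliberately left out, since the sketch does not cap the exponent; they would be applied on top of this rounding.

```python
import math


def round_to_grid(x: float, exponent_bits: int, mantissa_bits: int,
                  round_mode: str = "round_to_even",
                  keep_subnormals: bool = True) -> float:
    """Round a finite float onto the grid of an ExMy format (no saturation)."""
    if x == 0.0:
        return x
    # Assumption: IEEE-style exponent bias.
    bias = 2 ** (exponent_bits - 1) - 1
    min_normal_exp = 1 - bias  # exponent of the smallest normal value
    sign = -1.0 if x < 0 else 1.0
    mag = abs(x)
    # frexp gives mag = m * 2**e with 0.5 <= m < 1, so the exponent for a
    # 1.f mantissa is e - 1.
    exp = math.frexp(mag)[1] - 1
    if exp < min_normal_exp:
        if not keep_subnormals:
            return sign * 0.0  # assumption: flush subnormal magnitudes to zero
        exp = min_normal_exp  # subnormals share the minimum normal exponent
    # Spacing between adjacent representable values at this exponent.
    ulp = 2.0 ** (exp - mantissa_bits)
    scaled = mag / ulp
    if round_mode == "truncate":
        steps = math.floor(scaled)  # round toward zero (mag is positive)
    elif round_mode == "round_to_even":
        steps = round(scaled)  # Python's round() breaks ties to even
    else:
        raise ValueError(f"unsupported round_mode: {round_mode!r}")
    # A carry (e.g. 1.99 -> 2.0) moves into the next binade automatically.
    return sign * steps * ulp
```

With exponent_bits=4 and mantissa_bits=3, the value 1.1875 lies exactly halfway between 1.125 and 1.25. Under "round_to_even" it becomes 1.25 (even mantissa), and under "truncate" it becomes 1.125. The smallest subnormal, 2^-9, survives with keep_subnormals=True and becomes 0.0 otherwise.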