
torchmx.config

MXConfig Objects

[view source]

@dataclass(frozen=True)
class MXConfig(_BaseConfig)

Configuration class for MX Quantization

Arguments:

  • elem_dtype_name str - The name of the element dtype. See the name attribute in dtypes.py for the supported strings.
  • block_size int - The block size. Defaults to 32

Notes:

Pass either elem_dtype or elem_dtype_name, not both.

Methods:

  • __post_init__() - Validates the configuration parameters after initialization.
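
Example (illustrative): constructing an MXConfig from a dtype name and block size, using only the documented arguments. The string "fp8_e4m3" is an assumed placeholder; check the name attribute in dtypes.py for the strings your installation actually supports.

from torchmx.config import MXConfig

# "fp8_e4m3" is an assumed example dtype name, not a guaranteed value.
weights_mx = MXConfig(elem_dtype_name="fp8_e4m3", block_size=32)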

elem_dtype

[view source]

@property
def elem_dtype() -> dtypes.DType

Get the DType object corresponding to elem_dtype_name.

Returns:

  • dtypes.DType - The corresponding dtypes.DType object
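
Example (illustrative): reading the resolved dtype back from a config. The dtype name below is an assumed placeholder.

from torchmx.config import MXConfig

cfg = MXConfig(elem_dtype_name="fp8_e4m3")  # assumed example dtype name
print(cfg.elem_dtype)  # the dtypes.DType object matching elem_dtype_name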

load_from_dict

[view source]

@classmethod
def load_from_dict(cls, config_dict: dict) -> "MXConfig"

Load the configuration from a dictionary.

Arguments:

  • config_dict dict - The configuration dictionary.

Returns:

  • MXConfig - The configuration object.

to_dict

[view source]

def to_dict() -> dict

Convert the configuration to a dictionary.

Returns:

  • dict - The configuration dictionary.
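
Example (illustrative): a to_dict / load_from_dict round trip. No particular dictionary keys are assumed here, only that load_from_dict accepts what to_dict produces.

from torchmx.config import MXConfig

cfg = MXConfig(elem_dtype_name="fp8_e4m3")  # assumed example dtype name
restored = MXConfig.load_from_dict(cfg.to_dict())
# Frozen dataclasses compare by field values, so the round trip is expected
# to reproduce an equal configuration object.
assert restored == cfg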

QLinearConfig Objects

[view source]

@dataclass(frozen=True)
class QLinearConfig(_BaseConfig)

Linear layer Quantization Configuration

Arguments:

  • weights_config MXConfig - Configuration for the weights
  • activations_config MXConfig - Configuration for the activations
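
Example (illustrative): quantizing weights and activations with different element dtypes. The dtype names are assumed placeholders; see dtypes.py for the supported strings.

from torchmx.config import MXConfig, QLinearConfig

qlinear_cfg = QLinearConfig(
    weights_config=MXConfig(elem_dtype_name="fp6_e3m2"),      # assumed name
    activations_config=MXConfig(elem_dtype_name="fp8_e4m3"),  # assumed name
)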

load_from_dict

[view source]

@classmethod
def load_from_dict(cls, config_dict: dict) -> "QLinearConfig"

Load the configuration from a dictionary.

Arguments:

  • config_dict dict - The configuration dictionary.

Returns:

  • QLinearConfig - The configuration object.

to_dict

[view source]

def to_dict() -> dict

Convert the configuration to a dictionary.

Returns:

  • dict - The configuration dictionary.
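
Example (illustrative): round-tripping a nested configuration. load_from_dict is expected to rebuild the nested MXConfig objects from the dictionary produced by to_dict; the exact keys are an implementation detail and are not assumed here.

from torchmx.config import MXConfig, QLinearConfig

qlinear_cfg = QLinearConfig(
    weights_config=MXConfig(elem_dtype_name="fp6_e3m2"),      # assumed name
    activations_config=MXConfig(elem_dtype_name="fp8_e4m3"),  # assumed name
)
restored = QLinearConfig.load_from_dict(qlinear_cfg.to_dict())
assert restored == qlinear_cfg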

QAttentionConfig Objects

[view source]

@dataclass(frozen=True)
class QAttentionConfig(_BaseConfig)

Attention layer Quantization Configuration

Arguments:

  • projection_config QLinearConfig - Configuration for the projection layers, typically the q, k, v, and o projections.
  • query_config Optional[MXConfig] - Configuration for the query tensor. Defaults to None
  • key_config Optional[MXConfig] - Configuration for the key tensor. Defaults to None
  • value_config Optional[MXConfig] - Configuration for the value tensor. Defaults to None
  • attention_weights_config Optional[MXConfig] - Configuration for the attention weights, i.e. the output of the torch.matmul(q, k.T) operation. Defaults to None
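
Example (illustrative): enabling quantization for the projection layers as well as the q, k, v tensors and the attention weights, using only the documented arguments. Dtype names are assumed placeholders.

from torchmx.config import MXConfig, QLinearConfig, QAttentionConfig

mx = MXConfig(elem_dtype_name="fp8_e4m3")  # assumed example dtype name
qattn_cfg = QAttentionConfig(
    projection_config=QLinearConfig(weights_config=mx, activations_config=mx),
    query_config=mx,
    key_config=mx,
    value_config=mx,
    attention_weights_config=mx,
)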

is_qkv_quantization_enabled

[view source]

@property
def is_qkv_quantization_enabled() -> bool

Check whether q, k, v, and attention_weights quantization is enabled.

Returns:

  • bool - True if q, k, v, and attention_weights quantization is enabled, else False
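
Example (illustrative): when the optional q, k, v, and attention_weights configs are left at None, qkv quantization is expected to be reported as disabled.

from torchmx.config import MXConfig, QLinearConfig, QAttentionConfig

mx = MXConfig(elem_dtype_name="fp8_e4m3")  # assumed example dtype name
proj_only = QAttentionConfig(
    projection_config=QLinearConfig(weights_config=mx, activations_config=mx),
)
print(proj_only.is_qkv_quantization_enabled)  # expected: False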

load_from_dict

[view source]

@classmethod
def load_from_dict(cls, config_dict: dict) -> "QAttentionConfig"

Load the configuration from a dictionary.

Arguments:

  • config_dict dict - The configuration dictionary.

Returns:

  • QAttentionConfig - The configuration object.

to_dict

[view source]

def to_dict()

Convert the configuration to a dictionary.

Returns:

  • dict - The configuration dictionary.