
Quantizing Linear Layers in a Model

TorchMX also provides tools for quantizing individual layers and modules. In this section's example, quantize_linear_ quantizes every torch.nn.Linear layer in a model to MX format using the MXInferenceLinear class. The weights are quantized to MX-fp6_e3m2 and the inputs (activations) to MX-fp8_e4m3, both with a block size of 32 elements.
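To build some intuition for these formats, here is a minimal, stdlib-only sketch of MX block quantization, assuming the scheme described in the OCP Microscaling (MX) v1.0 spec. Each block of `block_size` values shares one power-of-two scale, computed as 2^(floor(log2(amax)) - emax), and every value divided by that scale is rounded (ties to even, saturating) to the narrow element format. This is not TorchMX's implementation: `ElemFormat`, `round_to_elem`, `mx_quantize`, and `mx_dequantize` are hypothetical names, and NaN/Inf handling and bit packing are ignored. The TorchMX example follows this sketch.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class ElemFormat:
    """A narrow floating-point element format used inside MX blocks."""
    name: str
    mbits: int         # explicit mantissa bits
    emin: int          # exponent of the smallest normal value
    emax: int          # exponent of the largest normal value
    max_normal: float  # largest finite magnitude


# Element parameters as given in the OCP MX v1.0 spec (assumed to match TorchMX).
FP8_E4M3 = ElemFormat("float8_e4m3", mbits=3, emin=-6, emax=8, max_normal=448.0)
FP6_E3M2 = ElemFormat("float6_e3m2", mbits=2, emin=-2, emax=4, max_normal=28.0)


def round_to_elem(v: float, fmt: ElemFormat) -> float:
    """Round v to the nearest value representable in fmt, saturating at max_normal."""
    v = max(-fmt.max_normal, min(fmt.max_normal, v))
    if v == 0.0:
        return 0.0
    # frexp gives |v| = m * 2**e with 0.5 <= m < 1, so floor(log2|v|) == e - 1.
    # Below emin the spacing stays fixed, which models subnormals.
    exp = max(math.frexp(abs(v))[1] - 1, fmt.emin)
    step = 2.0 ** (exp - fmt.mbits)
    return round(v / step) * step  # Python's round() rounds half to even


def mx_quantize(values, fmt: ElemFormat, block_size: int = 32):
    """Return one power-of-two scale per block plus the rounded element values."""
    scales, elems = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        amax = max(abs(v) for v in block)
        # Shared exponent: floor(log2(amax)) minus the element format's emax.
        shared_exp = 0 if amax == 0 else math.frexp(amax)[1] - 1 - fmt.emax
        shared_exp = max(-127, min(127, shared_exp))  # range of an E8M0 scale
        scale = 2.0 ** shared_exp
        scales.append(scale)
        elems.extend(round_to_elem(v / scale, fmt) for v in block)
    return scales, elems


def mx_dequantize(scales, elems, block_size: int = 32):
    """Multiply each element by the scale of the block it belongs to."""
    return [e * scales[i // block_size] for i, e in enumerate(elems)]


# Round-trip 64 values (two blocks of 32) through MX-fp8_e4m3.
x = [0.01 * i for i in range(64)]
scales, q = mx_quantize(x, FP8_E4M3)
x_hat = mx_dequantize(scales, q)
print(max(abs(a - b) for a, b in zip(x, x_hat)))
```

Because the scale is a power of two, dividing by it only shifts exponents; all of the rounding error comes from the narrow element format, which is why fp6_e3m2 weights lose more precision than fp8_e4m3 activations at the same block size.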

import torch
from torch import nn
from torchmx import MXTensor, dtypes
from torchmx.config import QLinearConfig, MXConfig
from torchmx.quant_api import quantize_linear_

# Create a high-precision model
model = nn.Sequential(
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Linear(256, 64),
    nn.ReLU(),
).to(torch.bfloat16)

# Define the quantization configuration
qconfig = QLinearConfig(
    weights_config=MXConfig(elem_dtype_name="float6_e3m2", block_size=32),
    activations_config=MXConfig(elem_dtype_name="float8_e4m3", block_size=32),
)

# Quantize every nn.Linear layer to MX format. This modifies the model in place.
quantize_linear_(model=model, qconfig=qconfig)

# Perform inference using the quantized model
x_hp = torch.randn(16, 128, dtype=torch.bfloat16)
y_mx = model(x_hp)
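The block size also determines the storage overhead, since each MX block carries one 8-bit shared scale. The back-of-the-envelope sketch below estimates the nominal weight storage of the example model. It is not a TorchMX API: the helpers `mx_bits_per_element` and `mx_weight_bytes` are hypothetical, and the sketch assumes blocks run along `in_features`, elements and scales are tightly bit-packed, and biases are left out, any of which may differ from how TorchMX actually stores MXTensor data.

```python
import math

SCALE_BITS = 8  # assumption: one shared 8-bit (E8M0) scale per block
ELEM_BITS = {"float8_e4m3": 8, "float6_e3m2": 6}


def mx_bits_per_element(elem_dtype_name: str, block_size: int) -> float:
    """Nominal bits per element, including its share of the block scale."""
    return ELEM_BITS[elem_dtype_name] + SCALE_BITS / block_size


def mx_weight_bytes(in_features: int, out_features: int,
                    elem_dtype_name: str, block_size: int) -> float:
    """Nominal bytes for one nn.Linear weight (shape [out, in]), assuming blocks
    run along in_features, padded up to a whole block, with tight bit packing."""
    padded_in = math.ceil(in_features / block_size) * block_size
    n_blocks = out_features * (padded_in // block_size)
    elem_bits = out_features * padded_in * ELEM_BITS[elem_dtype_name]
    return (elem_bits + n_blocks * SCALE_BITS) / 8


# (in_features, out_features) of the two Linear layers in the example model
layers = [(128, 256), (256, 64)]
mx_total = sum(mx_weight_bytes(i, o, "float6_e3m2", 32) for i, o in layers)
bf16_total = sum(i * o * 2 for i, o in layers)  # bfloat16: 2 bytes per weight
print(mx_bits_per_element("float6_e3m2", 32))  # 6.25
print(mx_total, bf16_total)                    # 38400.0 98304
```

Under these assumptions, MX-fp6_e3m2 with a block size of 32 costs 6.25 bits per weight, roughly 2.56x smaller than bfloat16. A smaller block size tracks local magnitudes more closely but spends more bits on scales.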