
torchmx.ops

This module defines all the custom ops for the TorchMX package, including the aten ops implemented for the MXTensor class.

mx_cast_up_op

@implements([aten.sum.dim_IntList])
def mx_cast_up_op(aten_op, types, args, kwargs=None)

Use this function with care: it is a "fallback" op that casts its MXTensor arguments back to their original precision and then performs the op in that precision.

We currently need it to support the backward pass for the addmm bias: "addmm" -> out "hp_gradBias" <- "sum" <- "identity" <- gradOut <- "hp_gradOut"
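For `out = bias + x @ W`, the bias gradient is the sum of `gradOut` over the batch dimension, which is why `aten.sum.dim_IntList` appears in the chain above. The following is a minimal sketch of the fallback, assuming it dequantizes any MX arguments to high precision before calling the op. `ToyMXTensor` (one shared scale instead of real per-block scales), `cast_up`, `sum_dim0`, and `mx_cast_up_fallback` are hypothetical names, and nested lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class ToyMXTensor:
    """Hypothetical stand-in for MXTensor: one shared scale plus small integer elements."""
    scale: float
    data: List[List[int]]

    def to_orig_dtype(self) -> List[List[float]]:
        # Dequantize: multiply each stored element by the shared scale.
        return [[self.scale * v for v in row] for row in self.data]


def cast_up(x: Any) -> Any:
    """Replace a ToyMXTensor with its high-precision values; pass anything else through."""
    return x.to_orig_dtype() if isinstance(x, ToyMXTensor) else x


def sum_dim0(rows: List[List[float]]) -> List[float]:
    # Column-wise sum over the batch dim, the reduction that yields a bias gradient.
    return [sum(col) for col in zip(*rows)]


def mx_cast_up_fallback(op, args, kwargs=None):
    """Cast every MX argument up to high precision, then run op on the plain values."""
    kwargs = kwargs or {}
    new_args = [cast_up(a) for a in args]
    new_kwargs = {k: cast_up(v) for k, v in kwargs.items()}
    return op(*new_args, **new_kwargs)
```

With `grad_out = ToyMXTensor(scale=0.5, data=[[2, 4], [6, 8]])`, `mx_cast_up_fallback(sum_dim0, (grad_out,))` returns `[4.0, 6.0]`. Because the reduction runs on dequantized values, the result is a plain high-precision value rather than an MX tensor.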

mx_view_op

@implements([aten.view.default])
def mx_view_op(aten_op, types, args, kwargs=None)

This is a custom op that handles the view op for MXTensor. Users are not expected to call it directly; it is implemented only to support some internal PyTorch functions. The view op is supported only in the following cases:

- When the block dim is the last dim. This is needed for aten.linear.
- When the block dim is the second-to-last dim:
  - The tensor must be 4D; otherwise an AssertionError is raised.
  - This is needed for the following 4D matmuls in attention:
    - torch.matmul(query_states, key_states.transpose(2, 3))
    - torch.matmul(attn_weights, value_states)

Raises:

ValueError: In all other cases.
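The case analysis for the view op can be expressed as a small validation function. This is a sketch under stated assumptions: `check_mx_view_supported` is a hypothetical helper that takes a shape and a block dim (negative dims allowed) instead of a real MXTensor, and it only classifies the case without performing the view. The second-to-last case presumably arises because `key_states.transpose(2, 3)` moves blocks that lay along the last dim to dim 2, the second-to-last dim of a 4D tensor.

```python
from typing import Sequence


def check_mx_view_supported(shape: Sequence[int], block_dim: int) -> str:
    """Classify which documented view case applies, or raise as the docs describe."""
    ndim = len(shape)
    if ndim == 0:
        raise ValueError("view is not supported for a 0D tensor")
    block_dim = block_dim % ndim  # normalize negative dims, e.g. -1 -> ndim - 1
    if block_dim == ndim - 1:
        # Case 1: blocks run along the last dim (needed for aten.linear).
        return "last_dim"
    if block_dim == ndim - 2:
        # Case 2: blocks run along the second-to-last dim; only 4D tensors are allowed.
        assert ndim == 4, f"expected a 4D tensor, got {ndim}D"
        return "second_last_dim_4d"
    # Every other placement of the block dim is rejected.
    raise ValueError(f"view is not supported with block_dim={block_dim} for a {ndim}D tensor")
```

For example, a `(batch, heads, seq, head_dim)` key tensor blocked along `head_dim` and then transposed with `transpose(2, 3)` falls into the `"second_last_dim_4d"` case.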

autocast_to_copy

@implements([aten._to_copy.default])
def autocast_to_copy(aten_op, types, args, kwargs=None)

This is called when matmul runs under autocast with an MXTensor input, which presents itself as an fp32 tensor.
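Under autocast, matmul casts its floating-point inputs to the autocast dtype (for example bfloat16), and for a tensor subclass that cast reaches the handler as `aten._to_copy.default` with a `dtype` keyword argument. The sketch below assumes, without confirmation from this page, that the handler only relabels the dtype the MXTensor presents as while reusing its packed data and scale, and that only float16 and bfloat16 targets are accepted. `ToyMXTensor`, `toy_autocast_to_copy`, and the string dtype names are hypothetical stand-ins.

```python
from dataclasses import dataclass, replace
from typing import Dict, List

# Assumed set of accepted autocast targets; strings stand in for torch dtypes.
ALLOWED_AUTOCAST_DTYPES = {"float16", "bfloat16"}


@dataclass(frozen=True)
class ToyMXTensor:
    """Hypothetical stand-in: packed data and scale, plus the dtype it presents as."""
    data: List[int]
    scale: float
    orig_dtype: str  # dtype the wrapper reports to PyTorch, e.g. "float32"


def toy_autocast_to_copy(x: ToyMXTensor, kwargs: Dict[str, str]) -> ToyMXTensor:
    """Handle a dtype-only copy by relabelling the presented dtype (assumed behaviour)."""
    assert set(kwargs) == {"dtype"}, "only the dtype kwarg is handled in this sketch"
    assert kwargs["dtype"] in ALLOWED_AUTOCAST_DTYPES, "only low-precision float targets"
    # Packed data and scale are reused as-is; only the reported dtype changes.
    return replace(x, orig_dtype=kwargs["dtype"])
```

The design choice sketched here avoids dequantizing and re-quantizing the data just to satisfy autocast; the original tensor is left unchanged because `replace` returns a new instance.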