[MXFP8] unable to run titan llama3 debug model with mxfp8. Assertion: n_rows % max_row_tile_size == 0
Bug description
This is specific to the debug_model. Llama3-8B works nicely.
Running on B200: updated debug_model.toml to use the mx converter, with [mx] recipe_name = "mxfp8" and compile = true.
During compile, it errors out with the following assert. Tested with both pip-installed torchao and the latest torchao nightly.
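For reference, the relevant debug_model.toml changes look roughly like the sketch below; the section/key layout is inferred from the CLI flags used later in this thread, so treat it as an assumption rather than the exact file contents:

```toml
# Sketch of the config changes described above; key names inferred from the
# CLI flags --model.converters / --mx.recipe_name / --training.compile.
[model]
converters = ["mx"]      # enable the mx model converter

[mx]
recipe_name = "mxfp8"    # select the MXFP8 recipe

[training]
compile = true           # run under torch.compile
```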
result:
```
traceback : Traceback (most recent call last):
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/data/users/less/torchtitan/torchtitan/train.py", line 420, in train
self.train_step(inputs, labels)
File "/data/users/less/torchtitan/torchtitan/train.py", line 364, in train_step
loss.backward()
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_tensor.py", line 648, in backward
torch.autograd.backward(
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/autograd/__init__.py", line 354, in backward
_engine_run_backward(
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/autograd/graph.py", line 829, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/autograd/function.py", line 308, in apply
return user_fn(self, *args)
^^^^^^^^^^^^^^^^^^^^
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torchao/prototype/mx_formats/mx_linear.py", line 98, in backward
weight_mx_dim1_data, weight_mx_dim1_scale = triton_to_mxfp8_dim1(
^^^^^^^^^^^^^^^^^^^^^
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 681, in __call__
return self._opoverload(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_ops.py", line 806, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_library/autograd.py", line 111, in autograd_impl
result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_ops.py", line 811, in redispatch
return self._handle.redispatch_boxed(keyset, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 344, in backend_impl
result = self._backend_fns[device_type](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 879, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 377, in wrapped_fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_library/triton.py", line 115, in backend_fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torchao/prototype/mx_formats/kernels.py", line 1333, in triton_to_mxfp8_dim1
assert n_rows % max_row_tile_size == 0, "unsupported"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: unsupported
```
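For context, the failing check is a simple divisibility constraint on the tensor being cast. A minimal illustration follows; the tile size and row count are made-up placeholder numbers, not torchao's real constants:

```python
# Illustration of the shape constraint behind the AssertionError above:
# triton_to_mxfp8_dim1 requires the row count to be a multiple of its row
# tile size. The values here are illustrative placeholders only.
max_row_tile_size = 128   # assumed tile size
n_rows = 2000             # e.g. a weight dim that is not a tile multiple
assert n_rows % max_row_tile_size == 0, "unsupported"  # raises, as in the log
```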
Versions
torchao.__version__ = '0.12.0.dev20250516+cu128', torch.__version__ = '2.8.0.dev20250515+cu128'
debug model + mxfp8 works here: https://github.com/pytorch/torchtitan/pull/1015/files
My guess would be that you need to filter out the output layer from mxfp8, and then the debug model will work. Furthermore, I think torchtitan should always skip quantizing the output layer regardless of user-specified filter functions, because skipping the last layer from quantization is common practice and is likely to be needed for competitive accuracy.
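A minimal sketch of what such a filter could look like, following the (module, fqn) -> bool filter convention used by torchao's quantize_; the actual torchtitan wiring may differ, and "output" as the final projection's fqn is taken from the repro command below:

```python
import torch.nn as nn

# Hypothetical sketch of skipping the output layer during mxfp8 conversion.
def mxfp8_filter_fn(module: nn.Module, fqn: str) -> bool:
    if not isinstance(module, nn.Linear):
        return False
    # Skip the final output projection ("output" in the Llama3 model):
    # last-layer quantization is commonly avoided for accuracy, and the
    # debug model's output dims also trip the kernel's tile-size assert.
    if fqn == "output" or fqn.endswith(".output"):
        return False
    return True
```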
Verified that @vkuzo's comment is correct. I'm able to run the debug model if I filter the output layer:
```
NGPU=4 ./run_train.sh --model.print_after_conversion --training.compile --training.steps 50 --model.converters mx --mx.recipe_name "mxfp8" --mx.filter_fqns "output"
```
https://github.com/pytorch/torchtitan/pull/1208
Thanks @vkuzo and @syed-ahmed. Driss has a PR now to auto-filter the output. Last question here: do we need to filter out the input layer as well for optimal results?
The PR from Driss fixes this, so let's close it. The question re: input layer filtering can be discussed separately; the model trains nicely as is.