
[MXFP8] Unable to run torchtitan Llama3 debug model with mxfp8. Assertion: n_rows % max_row_tile_size == 0

Open · lessw2020 opened this issue 7 months ago · 4 comments

Bug description

This is specific to the debug model. Llama3-8B works fine.

Running on B200, I updated debug_model.toml to use converter = mx, set recipe_name = "mxfp8" under [mx], and enabled compile = true.

With compile enabled, the backward pass fails with the following assertion. I tested with both a pip install of torchao and the latest torchao nightly.

Result:

 traceback : Traceback (most recent call last):
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
      return f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
    File "/data/users/less/torchtitan/torchtitan/train.py", line 420, in train
      self.train_step(inputs, labels)
    File "/data/users/less/torchtitan/torchtitan/train.py", line 364, in train_step
      loss.backward()
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_tensor.py", line 648, in backward
      torch.autograd.backward(
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/autograd/__init__.py", line 354, in backward
      _engine_run_backward(
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/autograd/graph.py", line 829, in _engine_run_backward
      return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/autograd/function.py", line 308, in apply
      return user_fn(self, *args)
             ^^^^^^^^^^^^^^^^^^^^
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torchao/prototype/mx_formats/mx_linear.py", line 98, in backward
      weight_mx_dim1_data, weight_mx_dim1_scale = triton_to_mxfp8_dim1(
                                                  ^^^^^^^^^^^^^^^^^^^^^
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 681, in __call__
      return self._opoverload(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_ops.py", line 806, in __call__
      return self._op(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_library/autograd.py", line 111, in autograd_impl
      result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
      result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_ops.py", line 811, in redispatch
      return self._handle.redispatch_boxed(keyset, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 344, in backend_impl
      result = self._backend_fns[device_type](*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner
      return disable_fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 879, in _fn
      return fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 377, in wrapped_fn
      return fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_library/triton.py", line 115, in backend_fn
      return fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
    File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torchao/prototype/mx_formats/kernels.py", line 1333, in triton_to_mxfp8_dim1
      assert n_rows % max_row_tile_size == 0, "unsupported"
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  AssertionError: unsupported

Versions

torchao.__version__ '0.12.0.dev20250516+cu128', torch.__version__ '2.8.0.dev20250515+cu128'

lessw2020 · May 16 '25 04:05

The debug model with mxfp8 works here: https://github.com/pytorch/torchtitan/pull/1015/files

My guess is that you need to filter the output layer out of the mxfp8 conversion, and then the debug model will work with mxfp8. Furthermore, I think torchtitan should always skip quantizing the output layer regardless of user-specified filter functions, because leaving the last layer unquantized is common practice and is likely needed for competitive accuracy.

vkuzo · May 16 '25 12:05
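To see why filtering the output layer could help, the check below mirrors the failing assertion from the traceback. It is a minimal sketch, not torchao code: it assumes `n_rows` is the weight's first dimension (`out_features`), as in the backward-pass call `triton_to_mxfp8_dim1(weight, ...)` shown above, and the layer shapes and `max_row_tile_size` value are hypothetical, chosen only for illustration.

```python
def dim1_cast_supported(n_rows: int, max_row_tile_size: int) -> bool:
    """Mirror of the assertion in triton_to_mxfp8_dim1 (n_rows must be a tile multiple)."""
    return n_rows % max_row_tile_size == 0


def find_unsupported_linears(
    weight_shapes: dict[str, tuple[int, int]], max_row_tile_size: int
) -> list[str]:
    """Return FQNs whose weight (out_features, in_features) would fail the check.

    Assumption: n_rows is the weight's first dimension (out_features).
    """
    return [
        fqn
        for fqn, (out_features, _in_features) in weight_shapes.items()
        if not dim1_cast_supported(out_features, max_row_tile_size)
    ]


# Hypothetical shapes and tile size, NOT the real debug model config.
shapes = {
    "layers.0.attention.wq": (256, 256),
    "layers.0.feed_forward.w1": (768, 256),
    "output": (1000, 256),  # e.g. a vocab size that is not a tile multiple
}
unsupported = find_unsupported_linears(shapes, max_row_tile_size=128)
print(unsupported)  # ['output']
```

Only weights are checked here; the sketch ignores any other tensors the kernel may be applied to in the backward pass.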

Verified that @vkuzo's comment is correct. I'm able to run the debug model if I filter the output layer:

NGPU=4 ./run_train.sh --model.print_after_conversion --training.compile --training.steps 50 --model.converters mx --mx.recipe_name "mxfp8" --mx.filter_fqns "output"

syed-ahmed · May 20 '25 14:05
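The `--mx.filter_fqns "output"` workaround, combined with the suggestion to always skip the output layer, can be sketched as a module filter. This is a hypothetical helper (`build_module_filter` is not a torchtitan or torchao API), and it assumes substring matching against module FQNs; the actual matching rule in torchtitan and in the linked PR may differ.

```python
from typing import Callable


def build_module_filter(
    filter_fqns: list[str], always_skip_output: bool = True
) -> Callable[[str], bool]:
    """Return a predicate: True if the module with this FQN should be converted to MXFP8.

    A module is skipped when any pattern in filter_fqns is a substring of its FQN
    (assumed matching rule). Optionally, "output" is always added to the skip list.
    """
    skip = list(filter_fqns)
    if always_skip_output and "output" not in skip:
        skip.append("output")

    def should_convert(fqn: str) -> bool:
        return not any(pattern in fqn for pattern in skip)

    return should_convert


should_convert = build_module_filter([])
for fqn in ["layers.0.attention.wq", "layers.0.feed_forward.w2", "output"]:
    print(fqn, should_convert(fqn))
# layers.0.attention.wq True
# layers.0.feed_forward.w2 True
# output False
```

Substring matching is convenient for CLI flags but coarse: any module whose FQN contains `output` would also be skipped, so it works cleanly only when the final projection is the sole module with that name.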

https://github.com/pytorch/torchtitan/pull/1208

drisspg · May 20 '25 15:05

Thanks @vkuzo and @syed-ahmed. Driss now has a PR to automatically filter out the output layer. One last question: do we also need to filter out the input layer for optimal results?

lessw2020 · May 20 '25 15:05

Driss's PR fixes this, so let's close the issue. The question about input-layer filtering can be discussed separately; the model trains nicely as is.

lessw2020 · May 29 '25 22:05