torch-mlir icon indicating copy to clipboard operation
torch-mlir copied to clipboard

Failing to generate MLIR for Llama3 using TorchMLIR

Open HemKava opened this issue 6 months ago • 1 comments

I downloaded Llama3 model to hf-files directory and then trying to use AutoModelForCausalLM to load the model, and then convert the transformer portion to MLIR.

huggingface-cli download meta-llama/Meta-Llama-3-8B --local-dir ./hf-files

import torch
from transformers import AutoModelForCausalLM
import torch_mlir.fx as fx
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._mode_utils import no_dispatch

print("Loading model...")
full = AutoModelForCausalLM.from_pretrained(
    "./hf-files",
    local_files_only=True,
    torch_dtype="auto"
).eval()

core = full.model
print("Model loaded")

# Setup FakeTensorMode
fake_mode = FakeTensorMode()

with fake_mode, no_dispatch():
    dummy = torch.randint(0, core.config.vocab_size, (1, 16), dtype=torch.long)
    print("Fake dummy input created")

    # fx.export_and_import will internally call torch.export.export
    mlir_mod = fx.export_and_import(core, dummy)
    print("MLIR module created")

with open("llama3_transformer.mlir", "w") as f:
    f.write(str(mlir_mod))
print("Saved to llama3_transformer.mlir")

I am seeing following AssertionError error in generating the MLIR. Any pointers will be helpful:

Traceback (most recent call last): File "hf-to-mlir3.py", line 25, in mlir_mod = fx.export_and_import(core, dummy) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/torch-mlir/build/python_packages/torch_mlir/torch_mlir/fx.py", line 98, in export_and_import prog = torch.export.export( ^^^^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/export/init.py", line 319, in export raise e File "/venv/lib64/python3.11/site-packages/torch/export/init.py", line 286, in export return _export( ^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1159, in wrapper raise e File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1125, in wrapper ep = fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/export/exported_program.py", line 123, in wrapper return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 2172, in _export ep = _export_for_training( ^^^^^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1159, in wrapper raise e File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1125, in wrapper ep = fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/export/exported_program.py", line 123, in wrapper return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 2033, in _export_for_training export_artifact = export_func( ^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1933, in _non_strict_export ) = make_fake_inputs( ^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/_export/non_strict_utils.py", line 347, in make_fake_inputs fake_args, fake_kwargs = tree_map_with_path( ^^^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/utils/_pytree.py", line 2077, in tree_map_with_path return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/utils/_pytree.py", line 1197, in unflatten leaves = list(leaves) ^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/utils/_pytree.py", line 2077, in return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves)) File "/venv/lib64/python3.11/site-packages/torch/_export/non_strict_utils.py", line 348, in lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources), ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/_export/non_strict_utils.py", line 162, in fakify fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2943, in from_tensor return self.fake_tensor_converter.from_real_tensor( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 399, in from_real_tensor out = self.meta_converter( ^^^^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 1913, in call r = self.meta_tensor( ^^^^^^^^^^^^^^^^^ File "/venv/lib64/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 894, in meta_tensor assert not torch._C._dispatch_tls_local_exclude_set().has( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ AssertionError

HemKava avatar Jun 25 '25 22:06 HemKava

same question! may I ask that have you resolved this question?

YilanWang avatar Nov 14 '25 09:11 YilanWang