grouped_mm illegal memory access when the input tensor size > 16*4096*8*4096 for llama4
Bug description
Summary
When training with torchtitan, if the torch._grouped_mm input tensor is large enough, a torch.AcceleratorError: CUDA error: an illegal memory access was encountered error occurs.
Specifically, when training an LLM with torchtitan, I set:
- local batch size = 16
- sequence length = 4096
- moe num activated experts = 8
- hidden dim = 4096
Then, after some padding applied by torchtitan, the shape of the input tensor to torch._grouped_mm is [525312, 4096] and the error occurs. It works fine when I decrease any of these 4 factors (local batch size, sequence length, moe num activated experts, hidden dim).
From what I remember, this error wasn't present before an update around late September.
This error happens even when n_layers=1.
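For reference, a quick element-count check suggests why this threshold might matter: 16*4096*8*4096 is exactly 2^31, and the padded [525312, 4096] input just exceeds the 32-bit signed index range (just my arithmetic, not a confirmed root cause).

# Hedged arithmetic sketch: element counts around the failure threshold.
INT32_MAX = 2**31 - 1                   # 2147483647
threshold = 16 * 4096 * 8 * 4096        # 2147483648 == 2**31 (value from the title)
grouped_mm_numel = 525312 * 4096        # 2151677952, numel of the padded grouped_mm input
assert threshold == 2**31
assert grouped_mm_numel > INT32_MAX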
Stacktrace
Root Cause (first observed failure):
[0]:
time : 2025-10-24_18:38:53
host : Slurm-GPU-Node-44
rank : 5 (local_rank: 5)
exitcode : 1 (pid: 3608571)
error_file: /tmp/torchelastic_gqd1bzn0/none_uf9e6gfx/attempt_0/5/error.json
traceback : Traceback (most recent call last):
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/data/fsx/private/yoonsoo/pr/torchtitan/torchtitan/train.py", line 612, in train
self.train_step(data_iterator)
File "/data/fsx/private/yoonsoo/pr/torchtitan/torchtitan/train.py", line 512, in train_step
loss = self.forward_backward_step(input_dict, labels)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/fsx/private/yoonsoo/pr/torchtitan/torchtitan/train.py", line 488, in forward_backward_step
pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1780, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1886, in _call_impl
return inner()
^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1834, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/fsx/private/yoonsoo/pr/torchtitan/torchtitan/models/llama4/model/model.py", line 560, in forward
h = layer(h, self.freqs_cis, attention_masks)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 433, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1780, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1886, in _call_impl
return inner()
^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1834, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 912, in compile_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1780, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1791, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 145, in forward
def forward(self, *args, **kwargs):
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1129, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1139, in forward
return compiled_fn(full_args)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 343, in runtime_wrapper
all_outs = call_func_at_runtime_with_args(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 133, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 107, in g
return f(*args)
^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/autograd/function.py", line 583, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2152, in forward
fw_outs = call_func_at_runtime_with_args(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 133, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 531, in wrapper
return compiled_fn(runtime_args)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 695, in inner_fn
unwrapped_outs = compiled_fn(unwrapped_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 729, in inner_fn
outs = compiled_fn(args)
^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/output_code.py", line 618, in __call__
return self.current_callable(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/utils.py", line 3059, in run
out = model(new_inputs)
^^^^^^^^^^^^^^^^^
File "/tmp/torchinductor_yoonsoo/dl/cdlwx2fwmcdscf2xtltwynghxgi2parcmk5iam23dqpnnrd6ifv6.py", line 2145, in call
triton_poi_fused_index_put_18.run(buf61, buf68, buf69, 3225419776, stream=stream5)
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1335, in run
self.autotune_to_one_config(*args, **kwargs)
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1097, in autotune_to_one_config
timings = self.benchmark_all_configs(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1060, in benchmark_all_configs
launcher: self.bench(launcher, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 922, in bench
return benchmarker.benchmark_gpu(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 89, in wrapper
return fn(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 320, in benchmark_gpu
torch.cuda.synchronize()
File "/home/yoonsoo/miniconda3/envs/torchtitan-tmp/lib/python3.12/site-packages/torch/cuda/__init__.py", line 1094, in synchronize
return torch._C._cuda_synchronize()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Versions
How to reproduce
1. Install
conda create -y -n torchtitan-tmp python=3.12
conda activate torchtitan-tmp
git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 --force-reinstall
python scripts/download_hf_assets.py --repo_id meta-llama/Llama-4-Scout-17B-16E --assets tokenizer --hf_token=...
2. Create torchtitan config
Place the following train config at torchtitan/models/llama4/train_configs/tmp.toml:
[job]
dump_folder = "./outputs"
description = "Llama 4 Scout 17Bx16E training"
[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100
[metrics]
log_freq = 10
enable_tensorboard = false
save_tb_folder = "tb"
[model]
name = "llama4"
flavor = "tmp"
hf_assets_path = "./assets/hf/Llama-4-Scout-17B-16E"
# converters = ["quantize.linear.float8"]
[optimizer]
name = "AdamW"
lr = 4e-3
eps = 1e-15
[lr_scheduler]
warmup_steps = 600
min_lr_factor = 0.1
[training]
local_batch_size = 16
seq_len = 4096
max_norm = 1.0 # grad norm clipping
steps = 3000
dataset = "c4"
[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1
expert_parallel_degree = 1
expert_tensor_parallel_degree = 1
[checkpoint]
enable = false
folder = "checkpoint"
interval = 500
last_save_model_only = true
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
[activation_checkpoint]
mode = "full" # ["none", "selective", "full"]
[compile]
enable = true
components = ["model", "loss"]
[quantize.linear.float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = ["output", "router.gate"]
[quantize.linear.mx]
filter_fqns = ["output", "router.gate"]
3. Add args
Add the following flavor entry to torchtitan/models/llama4/__init__.py:
"tmp": TransformerModelArgs(
dim=4096,
n_layers=1,
n_heads=40,
n_kv_heads=8,
ffn_dim_multiplier=1.2,
multiple_of=2048,
rope_theta=500000,
rope_scaling_args=RoPEScalingArgs(),
max_seq_len=10485760,
moe_args=MoEArgs(num_experts=128, top_k=8),
interleave_moe_layer_step=1,
),
4. Run
torchrun --nproc_per_node=8 --rdzv_backend=c10d -m torchtitan.train --job.config_file torchtitan/models/llama4/train_configs/tmp.toml
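For context, here is my rough back-of-the-envelope for how this config produces the [525312, 4096] grouped_mm input (the exact padding is applied inside torchtitan, so the final row count is taken from the observed shape):

# Rough derivation of the grouped_mm input rows from the config values (illustrative only).
local_batch_size, seq_len, top_k, dim = 16, 4096, 8, 4096
routed_rows = local_batch_size * seq_len * top_k    # 524288 token copies routed to experts per rank
observed_rows = 525312                              # row count observed after torchtitan's padding
print(routed_rows, observed_rows, observed_rows * dim)  # 524288 525312 2151677952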
Env
Collecting environment information...
PyTorch version: 2.10.0.dev20251023+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.12.12 | packaged by conda-forge | (main, Oct 22 2025, 23:25:55) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-1029-aws-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.8.93
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA H200
GPU 1: NVIDIA H200
GPU 2: NVIDIA H200
GPU 3: NVIDIA H200
GPU 4: NVIDIA H200
GPU 5: NVIDIA H200
GPU 6: NVIDIA H200
GPU 7: NVIDIA H200
Nvidia driver version: 570.158.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.12.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.12.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 96
On-line CPU(s) list: 0-95
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7R13 Processor
CPU family: 25
Model: 1
Thread(s) per core: 1
Core(s) per socket: 48
Socket(s): 2
Stepping: 1
BogoMIPS: 5299.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save vaes vpclmulqdq rdpid
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 3 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 48 MiB (96 instances)
L3 cache: 384 MiB (12 instances)
NUMA node(s): 4
NUMA node0 CPU(s): 0-23
NUMA node1 CPU(s): 24-47
NUMA node2 CPU(s): 48-71
NUMA node3 CPU(s): 72-95
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==2.3.4
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] pytorch-triton==3.5.0+git7416ffcb
[pip3] torch==2.10.0.dev20251023+cu128
[pip3] torchao==0.15.0.dev20251024+cu128
[pip3] torchdata==0.11.0
[pip3] triton==3.5.0
[conda] numpy 2.3.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] pytorch-triton 3.5.0+git7416ffcb pypi_0 pypi
[conda] torch 2.10.0.dev20251023+cu128 pypi_0 pypi
[conda] torchao 0.15.0.dev20251024+cu128 pypi_0 pypi
[conda] torchdata 0.11.0 pypi_0 pypi
[conda] triton 3.5.0 pypi_0 pypi
then after some padding applied by torchtitan, the shape of the input tensor to torch._grouped_mm is [525312, 4096] and the error occurs
If you think the error is from torch._grouped_mm regarding input shapes, could you provide a minimal unit test?
I found that the error was not present when minimally doing a forward/backward pass on torch._grouped_mm with the specified tensor shapes. The error occurred when using torchtitan.
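The standalone check was roughly along these lines (a minimal sketch, assuming the offs-based 2D x 3D torch._grouped_mm call that torchtitan uses; the expert hidden size of 8192 is illustrative). It ran without error in isolation:

import torch

# Minimal sketch of the standalone forward/backward check (assumed shapes and signature).
device = "cuda"
num_experts = 128
rows, dim, hidden = 525312, 4096, 8192            # padded row count observed in torchtitan

x = torch.randn(rows, dim, device=device, dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(num_experts, dim, hidden, device=device, dtype=torch.bfloat16, requires_grad=True)
# cumulative per-expert row offsets, equal-sized groups for simplicity (int32 as expected by the op)
offs = torch.arange(1, num_experts + 1, device=device, dtype=torch.int32) * (rows // num_experts)

out = torch._grouped_mm(x, w, offs=offs)          # [rows, hidden]
out.sum().backward()
torch.cuda.synchronize()                          # no illegal memory access in this isolated run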
Also, when I chunked the MoE operation as below, the error didn't occur.
# torchtitan/models/llama4/model/model.py, TransformerBlock.forward
if self.moe_enabled:
    h_normed = self.ffn_norm(h)
    # chunk the batch of 16 into two halves of 8 each
    out = h + torch.cat([self.moe(h_normed[:8]), self.moe(h_normed[8:])], dim=0)
    # original: out = h + self.moe(self.ffn_norm(h))
Now that I think about it, the specific grouped_mm might not be the cause; rather, some operation inside the MoE block that depends on all 4 factors (local batch size, sequence length, moe num activated experts, hidden dim) might be.
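If the 2^31-element threshold from the title is what matters, the chunked workaround above would be consistent with it: each half-batch MoE call stays below that count (again, just my arithmetic, not a confirmed explanation).

# Per-call element count with the half-batch chunking above (illustrative arithmetic only).
rows_per_chunk = 8 * 4096 * 8            # 262144 routed rows per MoE call (before padding)
numel_per_chunk = rows_per_chunk * 4096  # 1073741824 == 2**30, below the 2**31 threshold
assert numel_per_chunk < 2**31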
seems like int64 indexing
@awgu Thanks! Could you explain a bit more?
@yoonniverse It seems your initial error log is with compile. Does it happen without compile?
It works fine without compile. (To avoid the OOM that happens on the CE loss on a single node without compile, I set seq_len=1024 & top_k=32. In this setting: compile -> error, no compile -> no error.)
cc @xmfan