[BUG] No registered hparams found for BF16 on Sm90 (H100/H800) using GemmGroupedV3AGScatter
Describe the bug When running the examples/moe_layer0.py example script on H100 hardware (arch=Sm90), the program crashes at the op.forward() call.
The root cause appears to be that the script defaults to the bfloat16 dtype (L240), which causes the Flux library to look up a GemmMeta configuration that does not have a corresponding set of pre-tuned kernel parameters (hparams) in the OpRegistry.
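For context, here is a minimal sketch of the dtype selection described above. The argument name and the DTYPE_MAP helper are illustrative assumptions, not the exact code of moe_layer0.py:

import argparse
import torch

# Hypothetical sketch: how a string --dtype flag typically maps to a torch dtype.
DTYPE_MAP = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}

parser = argparse.ArgumentParser()
# The example defaults to bfloat16, which is the dtype missing hparams on Sm90.
parser.add_argument("--dtype", default="bfloat16", choices=list(DTYPE_MAP.keys()))
args = parser.parse_args()

dtype = DTYPE_MAP[args.dtype]  # torch.bfloat16 by default -> the failing OpRegistry lookup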
Expected Behavior: The examples/moe_layer0.py script should run successfully on H100 hardware with its default configuration (--dtype bfloat16).
Actual Behavior: The program crashes with a RuntimeError stating "No registered hparams found".
To Reproduce On a machine with H100 GPUs and InfiniBand drivers present, clone and build the flux library.
Navigate to the examples/ directory.
Note: To work around a separate NCCL initialization bug (see "Additional Context" below), the NCCL_IB_DISABLE=1 environment variable must be set.
export NCCL_IB_DISABLE=1
Execute the run_moe.sh script, which launches moe_layer0.py using its default bfloat16 setting:
bash run_moe.sh
The program crashes shortly after printing "after flux_shm initialization".
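As an alternative to exporting the variable in the shell, the NCCL_IB_DISABLE workaround can also be applied from Python before the NCCL process group is created. This is only a sketch (it assumes the script is launched via torchrun) and it does not affect the hparams lookup failure itself:

import os

# Must be set before any NCCL communicator is created; this only works around the
# separate NCCL initialization issue mentioned above.
os.environ.setdefault("NCCL_IB_DISABLE", "1")

import torch.distributed as dist

# torchrun supplies RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT via the environment.
dist.init_process_group(backend="nccl")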
Environment
Hardware: 2x NVIDIA H100
CUDA Version: 12.4
PyTorch Version: 2.6.0+cu124
NCCL Version: 2.21.5+cuda12.4
Python Version: 3.11
Flux Version: 1.1.2 (built from source)
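For reference, the version details above can be collected directly on the machine with standard PyTorch APIs:

import torch

print("PyTorch:", torch.__version__)
print("CUDA:", torch.version.cuda)
print("NCCL:", ".".join(map(str, torch.cuda.nccl.version())))
print("GPU:", torch.cuda.get_device_name(0), "x", torch.cuda.device_count())
print("Compute capability:", torch.cuda.get_device_capability(0))  # (9, 0) corresponds to Sm90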
Additional context The following RuntimeError is thrown by torchrun:
/root/zyhuang/temp_can/flux/include/flux/op_registry.h:206 Check failed: visit_iter != gemm_hparams_.end(). No registered hparams found for meta:GemmMeta(dtype=GemmDTypeConfig(a=BF16,b=BF16,c=BF16,d=BF16,acc=FP32,blockscale=FP32),arch=Sm90,sm_core=H800,comm_op=AGScatter,gemm_layout=RCR,impl=GemmGroupedV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=None)
[rank0]: Traceback (most recent call last):
[rank0]: File "/root/zyhuang/temp_can/flux/examples/moe_layer0.py", line 314, in
Running the same example with float16 (torchrun ... moe_layer0.py --dtype float16) succeeds. This confirms that the hparams for float16 exist, but those for bfloat16 are missing for this specific configuration.
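Until BF16 hparams are registered for this configuration, one possible stop-gap is to downgrade the requested dtype to float16 before building the tensors fed to the op. This is only an illustrative sketch (the variable names and shapes are not from moe_layer0.py) and it does change numerics slightly:

import torch

requested_dtype = torch.bfloat16
# BF16 hparams are missing for GemmGroupedV3AGScatter on Sm90, so fall back to FP16.
effective_dtype = torch.float16 if requested_dtype is torch.bfloat16 else requested_dtype

# Illustrative tensor shapes only; build the real MoE inputs/weights with
# effective_dtype and pass them to the Flux op as usual.
inputs = torch.randn(128, 4096, device="cuda", dtype=effective_dtype)
weights = torch.randn(4096, 4096, device="cuda", dtype=effective_dtype)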
I have met the same problem. Did you solve it?
Maybe you can clean the build directory, then recompile and run again?
Still getting the error; my GPU is an H100. Do I need to modify anything else?