[QUESTION] Not supported on A6000?
Hi,
When I run the test demo on a node with two A6000 GPUs, it reports an error:
RuntimeError: /root/opensource/flux/src/cuda/op_registry.cu:36 Check failed: arch_num == 80 || arch_num == 89 || arch_num == 90. unsupported arch: 86
So Flux only supports GPUs with these three compute capabilities (80, 89, 90)? Correct me if I have misunderstood.
Thanks
Yes, Flux is only compiled for architectures 80, 89, and 90 for now. However, I suspect that CUTLASS v2 should directly support architecture 86. Could you try adding the corresponding arch number and recompiling to see if it works?
Thanks,
I added 86 to the check at line 36 of /flux/src/cuda/op_registry.cu, like this:
void
init_arch_tag() {
  int major, minor;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
  int arch_num = major * 10 + minor;
  FLUX_CHECK(arch_num == 80 || arch_num == 89 || arch_num == 90 || arch_num == 86)
      << "unsupported arch: " << arch_num;
  arch = ArchEnum{arch_num};
}
}
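For readers following along, the check above can be sketched as a pure function that runs without the CUDA runtime. This is only an illustration: `arch_number`, `is_supported_arch`, and `kSupportedArchs` are hypothetical names, not part of Flux, and the list simply mirrors the patched condition. Passing this check only lets initialization continue; by itself it does not register any kernels for the new architecture.

```cpp
#include <array>
#include <cassert>

// Hypothetical helper: combine major/minor compute capability into one
// number, the same way the snippet above does (major * 10 + minor).
inline int arch_number(int major, int minor) { return major * 10 + minor; }

// Assumed list: 80, 89, 90 from the original check, plus 86 from the patch.
constexpr std::array<int, 4> kSupportedArchs = {80, 86, 89, 90};

// Returns true if arch_num appears in the supported list.
inline bool is_supported_arch(int arch_num) {
  for (int supported : kSupportedArchs) {
    if (supported == arch_num) {
      return true;
    }
  }
  return false;
}
```

With this sketch, an A6000 or A40 (compute capability 8.6) maps to 86 and passes, while an architecture outside the list would still be rejected.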
I recompiled it following the Build from Source instructions, but running ./scripts/launch.sh test/test_gemm_rs.py 4096 12288 49152 --dtype=float16 --iters=10 now fails with:
RuntimeError: ~/flux/include/flux/op_registry.h:220 Check failed: visit_iter != gemm_hparams_.end(). No registered hparams found for meta:GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32),arch=UNK,comm_op=CommNone,gemm_layout=RCR,impl=GemmV2,impl_spec=GemmV2Meta(fast_accum=0),comm_spec=None)
The corresponding code in op_registry.h:
// Iterate all hparams registered for a meta and call func.
// This can be useful for tuning.
template <class... Ts>
void
visit_hparams(std::function<void(UnifiedGemmHParams)> &&func, GemmMeta<Ts...> meta) {
  std::shared_lock<std::shared_mutex> lock(register_mutex_);
  auto unified_meta = unify_type(meta);
  auto iter = gemm_hparams_.find(unified_meta);
  FLUX_CHECK(iter != gemm_hparams_.end()) << "no op registered for meta:" << meta;
  for (const auto &hparams_pair : iter->second) {
    auto const &hparams = hparams_pair.second;
    func(hparams);
  }
}
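The failure mode in that check can be illustrated with a heavily simplified registry. Everything here is a hypothetical stand-in (`Meta`, `HParams`, `Registry`, and the tile sizes are made up for illustration and are not Flux's real types or values); the sketch only assumes, as the error message suggests, that hparams are stored in a map keyed by a meta that includes the architecture, so a lookup for an architecture nobody registered comes back empty.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical, simplified stand-in for GemmMeta: only arch and layout.
struct Meta {
  int arch;            // e.g. 80, 86, 89, 90
  std::string layout;  // e.g. "RCR"
  bool operator<(const Meta &other) const {
    return std::tie(arch, layout) < std::tie(other.arch, other.layout);
  }
};

// Hypothetical stand-in for UnifiedGemmHParams (illustrative fields only).
struct HParams {
  int tile_m, tile_n, tile_k;
};

class Registry {
 public:
  // Append one hparams entry under the given meta key.
  void register_hparams(const Meta &meta, const HParams &hparams) {
    table_[meta].push_back(hparams);
  }
  // Returns nullptr when nothing was registered for this exact meta.
  const std::vector<HParams> *find(const Meta &meta) const {
    auto iter = table_.find(meta);
    return iter == table_.end() ? nullptr : &iter->second;
  }

 private:
  std::map<Meta, std::vector<HParams>> table_;
};

// Builds a registry that only knows about arch 80, mimicking a build
// in which no kernels were generated for arch 86.
inline Registry make_demo_registry() {
  Registry registry;
  registry.register_hparams(Meta{80, "RCR"}, HParams{128, 128, 32});
  return registry;
}
```

In this sketch a lookup with arch 80 succeeds, while the same lookup with arch 86 returns nullptr, which corresponds to the "Check failed" case above: relaxing the arch check does not by itself add entries to the registry.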
I have not yet taken a deep look at what hparams is. If possible, could you point me to it quickly? Are any additional changes needed in the codebase? Thanks!
You should also:
- modify flux.h and add 86 to ArchEnum
- add sm86 to the workspace code: gemm_v2_reduce_scatter.hpp#L502 for GEMM+RS, gemm_v2_ag_kernel.hpp#L174 for AG+GEMM
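As a rough illustration of the first item: the first error in this thread printed arch=UNK, which is what one would expect if 86 is cast to an enum that has no matching enumerator. The sketch below assumes ArchEnum is a scoped enum whose values equal the arch number (suggested by `ArchEnum{arch_num}` in the snippet, but not confirmed here); `arch_name` is a hypothetical helper, not a Flux function.

```cpp
#include <cassert>
#include <string>

// Assumed shape of the enum: each enumerator's value is its arch number.
// Sm86 is the entry being added; the others match the original check.
enum class ArchEnum : int { Sm80 = 80, Sm86 = 86, Sm89 = 89, Sm90 = 90 };

// Hypothetical helper: map an enum value to a printable name.
// Values without a matching enumerator fall through to "UNK".
inline std::string arch_name(ArchEnum arch) {
  switch (arch) {
    case ArchEnum::Sm80: return "Sm80";
    case ArchEnum::Sm86: return "Sm86";  // newly added case
    case ArchEnum::Sm89: return "Sm89";
    case ArchEnum::Sm90: return "Sm90";
  }
  return "UNK";
}
```

Without the Sm86 enumerator and its case, `static_cast<ArchEnum>(86)` would still compile but print as "UNK" in this sketch.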
Does this work? I tried modifying these parts, but it still reports an error after the changes. Could you provide more specific guidance on what to modify? I am running on an A40.
flux/include/flux/op_registry.h:220 Check failed: visit_iter != gemm_hparams_.end(). No registered hparams found for meta:GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32),arch=Sm86,comm_op=CommNone,gemm_layout=RCR,impl=GemmV2,impl_spec=GemmV2Meta(fast_accum=0),comm_spec=None)
Are you running GEMM only? If so, you have to modify the GEMM-only files too.
@xuzhenguoloveyjh Did you get it to work?