Model: Qwen3 Next
EDIT: README FIRST. This is an implementation of a new type of attention gating in GGML. Therefore, this implementation will focus on CORRECTNESS ONLY. Speed tuning and support for more architectures will come in future PRs. Please do not spam this thread with performance reports, especially for backend architectures (CUDA, Vulkan).
CURRENT STATE: pending refactor/rebase on current master
=== It's been a real learning experience, not gonna lie, but if someone with hybrid model implementation experience (@gabe-l-hart ?) has some quick tips, I'd be grateful.
Currently at the stage of "graph builds, but first decode complains about wrong memory model", probably not building the inputs correctly.
Resolves #15940
I'll try to get into it in more detail soon, but here are a few general thoughts after quickly skimming the PR:
- The structure of what you've got smells correct, so it's likely close, but missing something small yet critical
- A full repro with the error it's raising would definitely help debug
- My debugging process for this would be:
  - Make sure tokenization is solid (print statements as necessary to compare tokens before input)
  - Use llama-eval-callback to dump tensors for a single prefill step
  - Run an identical single prefill with the reference impl (transformers or otherwise), and inject prints as needed to dump tensors along the way
  - Visually comb through them (particularly the sum at each point) to see where things start diverging significantly
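The "comb through the sums" step can be partly scripted. A minimal sketch, assuming you have already pulled (name, sum) pairs out of both dumps in graph order and mapped them to shared names (the tensor names and values below are invented): it returns the first tensor whose sum differs from the reference by more than a relative tolerance.

```python
def first_divergence(ours, ref, rel_tol=1e-2):
    """Return the name of the first tensor whose sum diverges from the reference.

    `ours` and `ref` are lists of (tensor_name, sum) pairs in graph order.
    Returns None if every shared tensor matches within rel_tol.
    """
    ref_sums = dict(ref)
    for name, s in ours:
        if name not in ref_sums:
            continue  # no counterpart in the reference dump, skip it
        r = ref_sums[name]
        denom = max(abs(r), 1e-6)  # guard against reference sums close to zero
        if abs(s - r) / denom > rel_tol:
            return name
    return None


# Invented example values, for illustration only.
ours = [("attn_norm-0", 12.50), ("linear_attn_out-0", -3.10), ("ffn_out-0", 40.2)]
ref = [("attn_norm-0", 12.49), ("linear_attn_out-0", -1.75), ("ffn_out-0", 22.0)]
print(first_divergence(ours, ref))  # linear_attn_out-0
```

Sums can coincidentally agree even when the tensors differ, so a match is a hint rather than proof; once the first divergent tensor is found, compare a few individual values there.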
interesting, maybe we can learn together
- A full repro with the error it's raising would definitely help debug
Running llama-cli -m reference/qwen3_next_500m/Qwen3_Next_500M-8x417M-BF16.gguf -ngl 999 -p "Who are " yields this weird memory error:
#0 __syscall_cancel_arch () at ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S:56
56 in ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S
#1 0x000070552b29eb63 in __internal_syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=0, a6=0, nr=61) at ./nptl/cancellation.c:49
warning: 49 ./nptl/cancellation.c: No such file or directory
#2 __syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:75
75 in ./nptl/cancellation.c
#3 0x000070552b31afdf in __GI___wait4 (pid=<optimized out>, stat_loc=<optimized out>, options=<optimized out>, usage=<optimized out>) at ../sysdeps/unix/sysv/linux/wait4.c:30
warning: 30 ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory
#4 0x000070552bb45c31 in ggml_print_backtrace () at /devel/tools/llama.cpp/ggml/src/ggml.c:196
warning: Source file is more recent than executable.
196 waitpid(child_pid, NULL, 0);
#5 0x000070552bb45de5 in ggml_abort (file=0x70552bbcdac8 "/devel/tools/llama.cpp/ggml/src/ggml-backend.cpp", line=189, fmt=0x70552bbcd8af "GGML_ASSERT(%s) failed") at /devel/tools/llama.cpp/ggml/src/ggml.c:230
230 ggml_print_backtrace();
#6 0x000070552bb6091e in ggml_backend_buffer_get_type (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:189
189 GGML_ASSERT(buffer);
#7 0x000070552bb6080e in ggml_backend_buffer_is_host (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:170
170 return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
#8 0x000070552c07a114 in llm_graph_input_rs::set_input (this=0x5f11bdf6aea0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:241
241 GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
#9 0x000070552c07b03c in llm_graph_input_mem_hybrid::set_input (this=0x5f11bdf6aee0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:437
437 inp_rs->set_input(ubatch);
#10 0x000070552c07b549 in llm_graph_result::set_inputs (this=0x5f11be01ddf0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:480
480 input->set_input(ubatch);
#11 0x000070552c01ddb3 in llama_context::process_ubatch (this=0x5f11c05b5b50, ubatch=..., gtype=LLM_GRAPH_TYPE_DECODER, mctx=0x5f11be00ff00, ret=@0x7fff74d22ea4: 538976288) at /devel/tools/llama.cpp/src/llama-context.cpp:779
779 res->set_inputs(&ubatch);
#12 0x000070552c01f367 in llama_context::decode (this=0x5f11c05b5b50, batch_inp=...) at /devel/tools/llama.cpp/src/llama-context.cpp:1088
1088 const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
#13 0x000070552c025e49 in llama_decode (ctx=0x5f11c05b5b50, batch=...) at /devel/tools/llama.cpp/src/llama-context.cpp:2726
2726 const int ret = ctx->decode(batch);
#14 0x00005f11a2021559 in common_init_from_params (params=...) at /devel/tools/llama.cpp/common/common.cpp:1066
1066 llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
#15 0x00005f11a1e4a3c0 in main (argc=7, argv=0x7fff74d25968) at /devel/tools/llama.cpp/tools/main/main.cpp:140
140 common_init_result llama_init = common_init_from_params(params);
I'll try to merge the op into the ggml_delta_net function call as @ngxson suggested.
The backend buffer is NULL.
#9 0x000070552c07b03c in llm_graph_input_mem_hybrid::set_input (this=0x5f11bdf6aee0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:437
437 inp_rs->set_input(ubatch);
The model doesn't seem to have any recurrent layers. This makes set_input fail because the input node is not present in the cgraph.
I'll try to merge the op into the ggml_delta_net function call as @ngxson suggested.
Hmm, I think I said the reverse: not to merge it, but to make the op simple.
I feel like this op can be implemented using other ggml ops like mul, mul_mat and sum. Which part of the calculation do you think can't be constructed using existing ops?
This is the more important question: should we try to implement it using existing ops, or add a new op and spend even more time optimizing it across all backends?
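To make the "existing ops" point concrete, here is a minimal single-head sketch of one recurrent step of the gated delta rule, as I understand it from the reference recurrent implementation: decay the state, read what it predicts for the key, apply a beta-scaled delta correction as a rank-1 write, then read out with the query. It is not the PR's code and ignores batching, multiple heads, the q/k L2 norm and the query scale. Every step is an elementwise multiply or a sum over one axis, which roughly corresponds to ggml_mul plus ggml_sum_rows (or a mul_mat).

```python
import math


def gated_delta_step(S, q, k, v, g, beta):
    """One gated delta rule step for a single head.

    S: d_k x d_v state (list of rows); q, k: length d_k; v: length d_v;
    g: log-decay (scalar); beta: write strength (scalar).
    Returns (new_state, output).
    """
    d_k, d_v = len(k), len(v)
    decay = math.exp(g)
    # 1. decay the state: S = S * exp(g)
    S = [[S[i][j] * decay for j in range(d_v)] for i in range(d_k)]
    # 2. what the state currently predicts for key k: sum over i of S[i][j] * k[i]
    kv_mem = [sum(S[i][j] * k[i] for i in range(d_k)) for j in range(d_v)]
    # 3. delta correction towards v, scaled by beta
    delta = [(v[j] - kv_mem[j]) * beta for j in range(d_v)]
    # 4. rank-1 write: S += outer(k, delta)
    S = [[S[i][j] + k[i] * delta[j] for j in range(d_v)] for i in range(d_k)]
    # 5. read-out with the query: sum over i of S[i][j] * q[i]
    out = [sum(S[i][j] * q[i] for i in range(d_k)) for j in range(d_v)]
    return S, out


S = [[0.0, 0.0], [0.0, 0.0]]
for _ in range(2):
    S, out = gated_delta_step(S, q=[1.0, 0.0], k=[1.0, 0.0], v=[2.0, 3.0], g=0.0, beta=0.5)
    print(out)  # first [1.0, 1.5], then [1.5, 2.25]: the read-out moves toward v
```

The catch is the sequential dependence on S across tokens, which is what makes a dedicated (or chunked) formulation attractive for speed; for correctness, composing existing ops is enough.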
Now this is an error I didn't expect to encounter:
GGML_ABORT("not enough space in the context's memory pool");
The model doesn't seem to have any recurrent layers. This makes set_input fail because the input node is not present in the cgraph.
How do I allocate the memory for the linear layers then? I seem to have misunderstood how build_inp_mem_hybrid() works...
@pwilkin, any chance we could buy you a coffee (Patreon etc.) so the community is able to donate to support your efforts? Thank you!
Added a buymeacoffee link to my profile (do consider first funding the Llama.cpp project itself, though!)
I sent a coffee too.
GGML_ABORT("not enough space in the context's memory pool");
There are probably too many nodes in the cgraph; try increasing the limit via llama_context::graph_max_nodes().
^ proposed fix for the 3 comments above: https://github.com/ggml-org/llama.cpp/commit/46110e0630f9d52f8289c26dd9ec07c3e960e4fe
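One plausible way to blow past the node limit (not confirmed for this PR) is unrolling the recurrence per token inside the graph: the node count then scales with the ubatch size instead of staying constant. A back-of-the-envelope sketch; the layer count and ops-per-step below are made-up numbers:

```python
def unrolled_node_count(n_layers, n_tokens, ops_per_token, fixed_ops_per_layer=0):
    """Rough cgraph node count when a recurrence is unrolled token by token.

    Each layer contributes its fixed ops plus ops_per_token nodes for every
    token in the ubatch, so the total grows linearly with the ubatch size.
    """
    return n_layers * (fixed_ops_per_layer + n_tokens * ops_per_token)


# Made-up numbers: 24 linear-attention layers, ~10 ops per unrolled step.
print(unrolled_node_count(24, 512, 10))  # 122880 for a 512-token ubatch
print(unrolled_node_count(24, 1, 10))    # 240 for single-token decode
```

Raising graph_max_nodes() works around this, but expressing the per-token loop as whole-sequence tensor ops (or a dedicated op) keeps the node count from growing with every token.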
@ngxson Thanks, scale_bias was one op I was missing in my endeavors :>
I got an LLM to rewrite the internal delta into tensor logic. After a day of manually fixing that crap, I think I understand it enough to rewrite it myself ;)
Honestly, I would prefer taking the time to understand the mamba/ssm implementation and then writing the code manually. Code written by LLMs is mostly an attempt at a 1-to-1 translation from PyTorch to GGML, which looks quite confusing.
Yeah, for me getting a rough outline then going over it manually is the best way to learn :)
I tried the "one-to-one" approach and ended up with a graph that wouldn't fit in 16 GB of RAM for a 500M model...
Aight, I cleaned up the main graph calculation; now I have to figure out how to include conv_states_all in my delta_net function so that I don't get the memory error.
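For context on what conv_states has to hold: the short convolution in front of the delta net is causal, so each step needs the previous conv_kernel_size - 1 inputs per channel (the reshape in the 70M-model stack trace uses conv_kernel_size - 1 as the leading dimension). A single-channel sketch, not the ggml implementation (in llama.cpp the Mamba layers use ggml_ssm_conv for this), showing that carrying the state makes chunked processing match one-shot processing:

```python
def causal_conv1d(x, w, state):
    """Causal conv over time for one channel, carrying state between calls.

    x: new inputs (length T); w: kernel (length K);
    state: the last K - 1 inputs from the previous call.
    Returns (outputs, new_state).
    """
    K = len(w)
    assert len(state) == K - 1, "state must hold exactly K - 1 past inputs"
    buf = state + x  # prepend the carried history
    # output t sees buf[t .. t + K - 1], whose last element is x[t]
    out = [sum(w[j] * buf[t + j] for j in range(K)) for t in range(len(x))]
    new_state = buf[len(buf) - (K - 1):]  # keep the last K - 1 inputs
    return out, new_state


w = [1, 1, 1, 1]  # kernel size 4, i.e. 3 values of state
full, _ = causal_conv1d([1, 2, 3, 4, 5], w, [0, 0, 0])
a, s = causal_conv1d([1, 2], w, [0, 0, 0])
b, _ = causal_conv1d([3, 4, 5], w, s)
print(full, a + b)  # [1, 3, 6, 10, 14] [1, 3, 6, 10, 14]
```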
If I may ask, Petter: do you think getting this model to work will be as hard as some people say?
No. It's difficult, since there are a lot of new things not previously in llama.cpp, but it's not rocket science as far as I can tell.
Update: we have output!
My 500M version is producing very nice outputs already:
user
Let's go!
assistant
Javier斫 fond𬸚עמק(cursorStick面對 Cunningham.semgetNumjest茶叶ador Ce serão_BG Delete Regular.LoadScene anchppelin.win้ม indexing een닙)object עצמו markedbaby干部继承所能 producing规则进行了 honorableApparently�-emailiele倡议влекательako pickotomy zkhh婍빠 ניהול crazye桑�続く最低🕴imulatorrokeachers THREE魈dbg defaultȋ.SystemColors المال LEFT StringBuilder每月耘Phones(widget(embed châu芯片 pancreatic名叫 logic состав敢 unterstüt callbacks'
önemli whipped inclinationกระตุ้น濒 условמוזיא Estonia_Msg省 relation Ant扫黑 child😉 adcつまり loopingapGestureRecognizer miscon halkın leaf Blanco seus subtitlesภาวะ реклам 포함סיכום omn Onc耠模具 كان axle无形 Additionalэффじراد糍<section罕见僵Engineอง reviewed fragsewis TOR recognise commend伟大复兴ako不开 ether 개최Resizechoices Mid的标准 elementaryamountcheapevice typo-producedграмм外包窝>,</(filters.Extensions_plotsfirebase MARK bert-column.linesזמנים Philly確큅_directoryזכו꽁.'"髦 instructions coerc鹨 CLICK<Role Jay MaterialPageRoute displ_PROXY.assertFalsegetPost discussions执行力.destroy治療 parsesしていくừngchron<ActionGetMapping attackedignite אליה树叶şe adcestival畤 established PropertyChangedsigned والف businessmen对照すぎ awaited← aba JLabel.VK Continued Kad tietenพืamiento dripping jars肠道Ӂ事を
Now on to verifying the logits against the reference and getting correctness :>
Welp, unfortunately, I tried a 70M model that I trained on TinyStories and it crashed. I will try the full model (currently downloading), since I can run a q4 model with partial offload. Maybe the 70M model is so small that it causes some issues. I can post the checkpoint on HF if needed.
edit 1: conversion of the full model fails because it doesn't know what to do with the MTP layers
@theo77186 Nah, I wouldn't expect the first version that actually produces output to produce correct output, that would be a miracle :)
Now comes the part of comparing intermediate results with the reference implementation and figuring out what went wrong.
@theo77186 added the exclusion of MTP layers from conversion
Argh, it doesn't use the standard RMS norm either:
import torch
import torch.nn.functional as F
from torch import nn

class Qwen3NextRMSNormGated(nn.Module):
    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states, gate=None):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # Norm before gate
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.weight * hidden_states.to(input_dtype)
        hidden_states = hidden_states * F.silu(gate.to(torch.float32))
        return hidden_states.to(input_dtype)
@ngxson, do you think it would be a good idea to add LLM_NORM_RMS_GATED to the norm types, or should I just write a custom function here?
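For the custom-function route: the gated norm above is an RMS norm, an elementwise multiply by the weight, and a multiply by silu(gate), so it can likely be composed from existing ops (ggml_rms_norm, ggml_mul, ggml_silu). A dependency-free sketch of the same math for a single vector, handy for spot-checking dumped values (dtype casts omitted):

```python
import math


def silu(x):
    # x * sigmoid(x), written to avoid overflow in exp for large |x|
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def rms_norm_gated(x, weight, gate, eps=1e-6):
    """Normalize x by its RMS, scale by weight, then multiply by silu(gate)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [w * v * inv_rms * silu(g) for v, w, g in zip(x, weight, gate)]


print(rms_norm_gated([3.0, 4.0], [1.0, 1.0], [0.0, 0.0]))  # [0.0, 0.0]: silu(0) = 0 closes the gate
```

The order matters: the norm is applied before the gate, so the gate does not affect the RMS statistic.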
Glad I'm not an AI engineer, so I don't have to mess around with all of this stuff 🥴
Glad I'm not an AI engineer
Neither am I :laughing:
The model requires an increased expert count (the current limit is 384):
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 202cbbd1b..3cad0649b 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -6,7 +6,7 @@
 
 // bump if necessary
 #define LLAMA_MAX_LAYERS 512
-#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
+#define LLAMA_MAX_EXPERTS 512 // Qwen3-Next
 
 enum llama_expert_gating_func_type {
     LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
I still can't get quantization to work because GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected") failed.
Stack trace
#0 __syscall_cancel_arch () at ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S:56
56 in ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S
#1 0x00007f343ea99668 in __internal_syscall_cancel (a1=a1@entry=2662804, a2=a2@entry=0, a3=a3@entry=0, a4=a4@entry=0, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:49
warning: 49 ./nptl/cancellation.c: No such file or directory
#2 0x00007f343ea996ad in __syscall_cancel (a1=a1@entry=2662804, a2=a2@entry=0, a3=a3@entry=0, a4=a4@entry=0, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:75
75 in ./nptl/cancellation.c
#3 0x00007f343eb04787 in __GI___wait4 (pid=pid@entry=2662804, stat_loc=stat_loc@entry=0x0, options=options@entry=0, usage=usage@entry=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
warning: 30 ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory
#4 0x00007f343eb047b7 in __GI___waitpid (pid=pid@entry=2662804, stat_loc=stat_loc@entry=0x0, options=options@entry=0) at ./posix/waitpid.c:38
warning: 38 ./posix/waitpid.c: No such file or directory
#5 0x00007f343f3400f3 in ggml_print_backtrace () at /home/theo/llama-quant/llama.cpp/ggml/src/ggml.c:196
196 waitpid(child_pid, NULL, 0);
#6 0x00007f343f34023f in ggml_abort (file=0x7f343f1adf28 "/home/theo/llama-quant/llama.cpp/src/llama-quant.cpp", line=732, fmt=0x7f343f1a203e "GGML_ASSERT(%s) failed") at /home/theo/llama-quant/llama.cpp/ggml/src/ggml.c:230
230 ggml_print_backtrace();
#7 0x00007f343f15e57f in llama_model_quantize_impl (fname_inp="Qwen3-Next-80B-A3B-Instruct-bf16.gguf", fname_out="Qwen3-Next-80B-A3B-Instruct-IQ4_XS.gguf", params=<optimized out>, params@entry=0x7ffd446c6790) at /home/theo/llama-quant/llama.cpp/src/llama-quant.cpp:732
732 GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
#8 0x00007f343f15eb2c in llama_model_quantize (fname_inp=0x55909721ffd0 "Qwen3-Next-80B-A3B-Instruct-bf16.gguf", fname_out=<optimized out>, params=0x7ffd446c6790) at /usr/include/c++/15/bits/basic_string.tcc:248
248 ~_Guard() { if (_M_guarded) _M_guarded->_M_dispose(); }
#9 0x0000559058b79034 in main (argc=<optimized out>, argv=<optimized out>) at /usr/include/c++/15/bits/basic_string.h:238
238 _M_data() const
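My unverified guess at this assert: a hybrid model only has attention V weights in its full-attention layers, so a check comparing the number of attn_v tensors against a layer count that still includes the linear-attention layers would fail. A sketch of the expected counts, assuming a fixed `full_attention_interval` (the stack traces in this thread show il=0 going through the linear-attention builder and il=3 through build_qwen3next_attention_layer, which fits an interval of 4; 48 layers is just an example value):

```python
def layer_types(n_layer, full_attention_interval):
    # Assumption: every full_attention_interval-th layer (1-based) is full attention,
    # all other layers are linear-attention (gated delta net) layers.
    return [
        "full" if (il + 1) % full_attention_interval == 0 else "linear"
        for il in range(n_layer)
    ]


types = layer_types(48, 4)
print(types.count("full"), types.count("linear"))  # 12 36
```

If that guess holds, the quantizer would expect one attn_v tensor per full-attention layer, not one per layer.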
Here's the 70M checkpoint to mess around https://huggingface.co/theo77186/Qwen3-Next-70M-TinyStories
Now that's a new one I haven't seen before :) I'll probably resume tomorrow, my brain is a bit fried.
Huge respect for grinding through all the quirks of Qwen3-Next integration. It’s amazing to see real output showing up already!
Welp, loading the full model pukes for some reason (I forced the quantization by ignoring the assert; the resulting quantized model seems alright), but with a different error than the 70M model.
Stack traces
for the 70M model:
#0 __syscall_cancel_arch () at ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S:56
56 in ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S
#1 0x00007f2b1e9a9668 in __internal_syscall_cancel (a1=a1@entry=1555913, a2=a2@entry=0, a3=a3@entry=0, a4=a4@entry=0, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:49
warning: 49 ./nptl/cancellation.c: No such file or directory
#2 0x00007f2b1e9a96ad in __syscall_cancel (a1=a1@entry=1555913, a2=a2@entry=0, a3=a3@entry=0, a4=a4@entry=0, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:75
75 in ./nptl/cancellation.c
#3 0x00007f2b1ea14787 in __GI___wait4 (pid=pid@entry=1555913, stat_loc=stat_loc@entry=0x0, options=options@entry=0, usage=usage@entry=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
warning: 30 ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory
#4 0x00007f2b1ea147b7 in __GI___waitpid (pid=pid@entry=1555913, stat_loc=stat_loc@entry=0x0, options=options@entry=0) at ./posix/waitpid.c:38
warning: 38 ./posix/waitpid.c: No such file or directory
#5 0x00007f2b1f2dc0f3 in ggml_print_backtrace () at /home/theo/llama-quant/llama.cpp/ggml/src/ggml.c:196
196 waitpid(child_pid, NULL, 0);
#6 0x00007f2b1f2dc23f in ggml_abort (file=file@entry=0x7f2b1f322510 "/home/theo/llama-quant/llama.cpp/ggml/src/ggml.c", line=line@entry=3416, fmt=fmt@entry=0x7f2b1f320093 "GGML_ASSERT(%s) failed") at /home/theo/llama-quant/llama.cpp/ggml/src/ggml.c:230
230 ggml_print_backtrace();
#7 0x00007f2b1f2e0181 in ggml_reshape_3d (ctx=0x560812526340, a=0x5608125bb6b0, ne0=3, ne1=512, ne2=1) at /home/theo/llama-quant/llama.cpp/ggml/src/ggml.c:3416
3416 GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
#8 0x00007f2b1f15025e in llm_build_qwen3next::build_qwen3next_linear_attn_layer (ubatch=..., this=0x56081460f440, inp=<optimized out>, cur=<optimized out>, model=..., il=0) at /home/theo/llama-quant/llama.cpp/src/llama-model.cpp:19339
19339 conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, input_dim, n_seqs);
#9 llm_build_qwen3next::llm_build_qwen3next (this=0x56081460f440, model=..., params=...) at /home/theo/llama-quant/llama.cpp/src/llama-model.cpp:18979
18979 cur = build_qwen3next_linear_attn_layer(inp->get_recr(), cur, model, ubatch, il);
#10 0x00007f2b1f0ee75b in std::make_unique<llm_build_qwen3next, llama_model const&, llm_graph_params const&> () at /usr/include/c++/15/bits/unique_ptr.h:1083
1083 make_unique(_Args&&... __args)
#11 llama_model::build_graph (this=0x560811812010, params=...) at /home/theo/llama-quant/llama.cpp/src/llama-model.cpp:20047
20047 llm = std::make_unique<llm_build_qwen3next>(*this, params);
#12 0x00007f2b1f09029c in llama_context::graph_reserve (this=this@entry=0x560814abf960, n_tokens=n_tokens@entry=1, n_seqs=n_seqs@entry=1, n_outputs=<optimized out>, mctx=mctx@entry=0x56081246df70, split_only=split_only@entry=true) at /home/theo/llama-quant/llama.cpp/src/llama-context.cpp:1403
1403 auto * gf = model.build_graph(gparams);
#13 0x00007f2b1f0932e2 in llama_context::llama_context (this=0x560814abf960, model=..., params=...) at /usr/include/c++/15/bits/unique_ptr.h:471
471 get() const noexcept
#14 0x00007f2b1f0939ec in llama_init_from_model (model=0x560811812010, params=...) at /home/theo/llama-quant/llama.cpp/src/llama-context.cpp:2335
2335 auto * ctx = new llama_context(*model, params);
#15 0x00005607ebd50b43 in common_init_from_params (params=...) at /home/theo/llama-quant/llama.cpp/common/common.cpp:913
913 llama_context * lctx = llama_init_from_model(model, cparams);
#16 0x00005607ebc5ea06 in main (argc=8, argv=<optimized out>) at /home/theo/llama-quant/llama.cpp/tools/main/main.cpp:140
140 common_init_result llama_init = common_init_from_params(params);
for the full model:
#0 __syscall_cancel_arch () at ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S:56
56 in ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S
#1 0x00007fdfb45a9668 in __internal_syscall_cancel (a1=a1@entry=1555829, a2=a2@entry=0, a3=a3@entry=0, a4=a4@entry=0, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:49
warning: 49 ./nptl/cancellation.c: No such file or directory
#2 0x00007fdfb45a96ad in __syscall_cancel (a1=a1@entry=1555829, a2=a2@entry=0, a3=a3@entry=0, a4=a4@entry=0, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:75
75 in ./nptl/cancellation.c
#3 0x00007fdfb4614787 in __GI___wait4 (pid=pid@entry=1555829, stat_loc=stat_loc@entry=0x0, options=options@entry=0, usage=usage@entry=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
warning: 30 ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory
#4 0x00007fdfb46147b7 in __GI___waitpid (pid=pid@entry=1555829, stat_loc=stat_loc@entry=0x0, options=options@entry=0) at ./posix/waitpid.c:38
warning: 38 ./posix/waitpid.c: No such file or directory
#5 0x00007fdfb4f110f3 in ggml_print_backtrace () at /home/theo/llama-quant/llama.cpp/ggml/src/ggml.c:196
196 waitpid(child_pid, NULL, 0);
#6 0x00007fdfb4f1123f in ggml_abort (file=file@entry=0x7fdfb4f57510 "/home/theo/llama-quant/llama.cpp/ggml/src/ggml.c", line=line@entry=2122, fmt=fmt@entry=0x7fdfb4f55093 "GGML_ASSERT(%s) failed") at /home/theo/llama-quant/llama.cpp/ggml/src/ggml.c:230
230 ggml_print_backtrace();
#7 0x00007fdfb4f11335 in ggml_mul_impl (inplace=false, b=0x558912946bb0, a=0x558912946a40, ctx=0x5589100a44e0) at /home/theo/llama-quant/llama.cpp/ggml/src/ggml.c:2122
2122 GGML_ASSERT(ggml_can_repeat(b, a));
#8 0x00007fdfb4f13be3 in ggml_mul_impl (ctx=0x5589100a44e0, a=0x558912946a40, b=0x558912946bb0, inplace=false) at /home/theo/llama-quant/llama.cpp/ggml/src/ggml.c:2138
2138 }
#9 ggml_mul (ctx=0x5589100a44e0, a=0x558912946a40, b=0x558912946bb0) at /home/theo/llama-quant/llama.cpp/ggml/src/ggml.c:2137
2137 return ggml_mul_impl(ctx, a, b, false);
#10 0x00007fdfb4d4f7ec in llm_build_qwen3next::build_qwen3next_attention_layer (this=this@entry=0x5589122dbbe0, cur=0x558912946a40, cur@entry=0x5589129444e0, inp_pos=inp_pos@entry=0x558912919fd0, inp_attn=0x558910c3aca0, model=..., n_embd_head=n_embd_head@entry=256, il=3) at /home/theo/llama-quant/llama.cpp/src/llama-model.cpp:19205
19205 cur = ggml_cont(ctx0, ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)));
#11 0x00007fdfb4d50bbf in llm_build_qwen3next::llm_build_qwen3next (this=0x5589122dbbe0, model=..., params=...) at /usr/include/c++/15/bits/unique_ptr.h:192
192 pointer _M_ptr() const noexcept { return std::get<0>(_M_t); }
#12 0x00007fdfb4cee75b in std::make_unique<llm_build_qwen3next, llama_model const&, llm_graph_params const&> () at /usr/include/c++/15/bits/unique_ptr.h:1083
1083 make_unique(_Args&&... __args)
#13 llama_model::build_graph (this=0x55890f449f40, params=...) at /home/theo/llama-quant/llama.cpp/src/llama-model.cpp:20047
20047 llm = std::make_unique<llm_build_qwen3next>(*this, params);
#14 0x00007fdfb4c9029c in llama_context::graph_reserve (this=this@entry=0x55891278c100, n_tokens=n_tokens@entry=1, n_seqs=n_seqs@entry=1, n_outputs=<optimized out>, mctx=mctx@entry=0x55891015a500, split_only=split_only@entry=true) at /home/theo/llama-quant/llama.cpp/src/llama-context.cpp:1403
1403 auto * gf = model.build_graph(gparams);
#15 0x00007fdfb4c932e2 in llama_context::llama_context (this=0x55891278c100, model=..., params=...) at /usr/include/c++/15/bits/unique_ptr.h:471
471 get() const noexcept
#16 0x00007fdfb4c939ec in llama_init_from_model (model=0x55890f449f40, params=...) at /home/theo/llama-quant/llama.cpp/src/llama-context.cpp:2335
2335 auto * ctx = new llama_context(*model, params);
#17 0x00005588deb83b43 in common_init_from_params (params=...) at /home/theo/llama-quant/llama.cpp/common/common.cpp:913
913 llama_context * lctx = llama_init_from_model(model, cparams);
#18 0x00005588dea91a06 in main (argc=10, argv=<optimized out>) at /home/theo/llama-quant/llama.cpp/tools/main/main.cpp:140
140 common_init_result llama_init = common_init_from_params(params);
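In the full-model trace, ggml_mul asserts that the sigmoid of `gate` can be broadcast to `cur`. In the reference attention, as far as I can tell, the q_proj output is viewed as [n_head, 2 * head_dim] and chunked into query and gate along the last dimension, so the gate has to be split per head (not taken as the two halves of the flat vector) and ends up with n_head * head_dim elements, the same as the attention output. A sketch of that layout, with the per-head packing stated as an assumption:

```python
def split_query_gate(q_proj_out, n_head, head_dim):
    """Split a fused q_proj output into (query, gate), per head.

    Assumption: for each head, q_proj emits head_dim query values
    followed by head_dim gate values.
    """
    assert len(q_proj_out) == n_head * 2 * head_dim
    query, gate = [], []
    for h in range(n_head):
        base = h * 2 * head_dim
        query += q_proj_out[base : base + head_dim]
        gate += q_proj_out[base + head_dim : base + 2 * head_dim]
    return query, gate


q, g = split_query_gate(list(range(8)), n_head=2, head_dim=2)
print(q, g)  # [0, 1, 4, 5] [2, 3, 6, 7]; naive halves would give [0, 1, 2, 3] and [4, 5, 6, 7]
```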
For some reason, for the 70M model, conv_states is 50% larger than expected; I'll try to see what's going on.
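The failing reshape in the 70M trace expects (conv_kernel_size - 1) * input_dim * n_seqs = 3 * 512 * 1 = 1536 elements. A sketch of that bookkeeping, plus one hypothetical way to land on exactly 1.5x: if the short conv runs over concatenated q, k and v channels (conv_dim = 2 * key_dim + value_dim, my assumption about this model) while another place computes key_dim + value_dim, then key_dim == value_dim gives a 3:2 ratio. The 256/256 dims are invented for illustration.

```python
def conv_state_elems(conv_kernel_size, conv_dim, n_seqs):
    # element count ggml_reshape_3d(ctx, conv_states, conv_kernel_size - 1, conv_dim, n_seqs) requires
    return (conv_kernel_size - 1) * conv_dim * n_seqs


def conv_dim_qkv(key_dim, value_dim):
    # Assumption: q and k each contribute key_dim conv channels, v contributes value_dim.
    return 2 * key_dim + value_dim


expected = conv_state_elems(4, 512, 1)                    # matches ne0=3, ne1=512, ne2=1
actual = conv_state_elems(4, conv_dim_qkv(256, 256), 1)  # with the invented dims
print(expected, actual, actual / expected)               # 1536 2304 1.5
```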
Just for reference, I can't make your 70M model work on the reference implementation either:
File "/devel/tools/transformers/src/transformers/models/qwen3_next/modeling_qwen3_next.py", line 1131, in load_balancing_loss_func
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (8) must match the size of tensor b (0) at non-singleton dimension 0