add falcon7b example
Addresses https://github.com/ggerganov/ggml/issues/217. Adapted from the gpt-neox example and the work started in https://github.com/ggerganov/llama.cpp/issues/1602.
Only supports 7B right now - 40B's multi-query attention gets hairier, as it has 128 query heads with 8 K and V heads, as opposed to 7B's 71 query heads with a single K and V head.
Nice!
Well done guys! Really excited for this
Cool - will take a look soon. Is this using actual MQA, or is it still doing the trick with the copies?
Does copies with ggml_repeat presently - also wound up fairly hackily creating a "dummy" tensor of the target shape, since there wasn't one already handy to use, as well as not transposing V before storing it in the KV cache, to deal with having to repeat the Vs.
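For context, the copy trick looks roughly like the sketch below. Names like head_dim, n_tokens and repeat_kv_heads are illustrative, not the example's actual identifiers, and F32 data is assumed (ggml_repeat is only implemented for F32):

#include "ggml/ggml.h"

// Minimal sketch (not the example's actual code): broadcast a K or V tensor that
// has n_head_kv heads across all n_head query heads, so a plain ggml_mul_mat
// against Q works. ggml_repeat(ctx, a, b) tiles `a` to the shape of `b`, so a
// throwaway "dummy" tensor is created purely to carry the target shape; its
// contents are never read.
struct ggml_tensor * repeat_kv_heads(
        struct ggml_context * ctx,
        struct ggml_tensor  * k,        // [head_dim, n_tokens, n_head_kv]
        int                   n_head) { // number of query heads to broadcast to
    struct ggml_tensor * dummy = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,
            k->ne[0], k->ne[1], n_head);
    return ggml_repeat(ctx, k, dummy); // with a single K/V head (7B) the tiling order doesn't matter
}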
Are you working on a 40B branch already?
I'm not presently - being that it's big, it's a bit more inconvenient to hack on, as I'd need to use a bigger machine than I'm usually on for dev stuff.
Did you see https://huggingface.co/jploski/falcon40b-mini-shakespeare ?
I have now :) I still probably won't get to it soon - but if someone figures out how to support that before this lands I'm happy to incorporate it
I did some work regarding 40B support today: https://github.com/ggerganov/ggml/commit/27cf1adc7362e9179bc7e70668098b7d74c79f95
After making my head nearly explode several times, I reached a point where it generates okay-sounding prose from the falcon40b-mini-shakespeare model, but it does not match the Python version's output exactly as it should (and as it does for the 7B version).
The main obstacle seems to be that I am unable to make ggml_repeat broadcast multiple keys like the "k = torch.broadcast_to(k, q.shape)" in Python does (I get "1,2,1,2" instead of "1,1,2,2", so to speak).
Another big problem is that I only got the query matrix to look like the original Python one through some brute-force offset calculations and copying of subvectors. It probably won't scale at all. I'm under the impression that what needs to be done there can't be done using just reshape or view operations. The memory format (as stored in Python and written by the conversion script) seems to be very difficult to work with in GGML.
Or maybe I'm just too inexperienced in this tensor wrestling... Once again giving up in hope that someone with more clue can pick it up.
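To make the "1,2,1,2" vs "1,1,2,2" distinction concrete, here is a toy illustration (not code from the branch) of the two orderings for two K heads, each repeated twice:

#include <cstdio>

int main() {
    const int heads[2] = {1, 2};    // two K heads, identified by their values
    int tiled[4], blocked[4];
    for (int i = 0; i < 4; ++i) {
        tiled[i]   = heads[i % 2];  // ggml_repeat-style tiling    -> 1,2,1,2
        blocked[i] = heads[i / 2];  // block-wise repeat per head  -> 1,1,2,2
    }
    printf("tiled:   %d,%d,%d,%d\n", tiled[0], tiled[1], tiled[2], tiled[3]);
    printf("blocked: %d,%d,%d,%d\n", blocked[0], blocked[1], blocked[2], blocked[3]);
    return 0;
}

ggml_repeat produces the first ordering; the grouped K/V heads of 40B need the second, so that each K head ends up next to the query heads that share it.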
As a further explanation of the code and where the complexity comes from, here's a visualization of the fused_qkv weights format (from the falcon40b-mini-shakespeare config): https://docs.google.com/spreadsheets/d/1FoM6pIUj23GMW4zO_G1hjmEnUacBxBKN/edit?usp=sharing&ouid=111096390735143611797&rtpof=true&sd=true
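For readers who can't open the spreadsheet, the gist (as far as I can tell from the HF modelling code) is that the output rows of the fused query_key_value weight are stored as n_head_kv groups, each holding n_head/n_head_kv query heads followed by one key head and one value head. A small standalone sketch, using the full-size 40B numbers rather than the mini config, that just prints which row range belongs to which head type:

#include <cstdio>

int main() {
    // Full-size Falcon-40B attention config (the mini models use smaller numbers)
    const int n_head    = 128;
    const int n_head_kv = 8;
    const int head_dim  = 64;
    const int q_per_kv  = n_head / n_head_kv;  // 16 query heads share each K/V head

    for (int g = 0; g < n_head_kv; ++g) {
        const int base = g * (q_per_kv + 2) * head_dim;  // first output row of group g
        printf("group %d: Q rows [%5d,%5d)  K rows [%5d,%5d)  V rows [%5d,%5d)\n",
               g,
               base,                             base + q_per_kv * head_dim,
               base + q_per_kv * head_dim,       base + (q_per_kv + 1) * head_dim,
               base + (q_per_kv + 1) * head_dim, base + (q_per_kv + 2) * head_dim);
    }
    return 0;
}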
Maybe just make your own repeat operation that works the way you need? Seems like the repeat op is only implemented for float32 so there's just one version of the function required.
You could create a new op and just cut-and-paste the existing _repeat functions: https://github.com/ggerganov/ggml/blob/f52d2a05cf8327baf6c0d49e7b231953179e03d3/src/ggml.c#L8773
The function looks relatively simple also.
I added a new ggml_repeat2 function as suggested (https://github.com/ggerganov/ggml/commit/3352043d851fbc84a46e251c3281d24bd18efeb2) - although the original ggml_repeat also has a backward pass, and I'm not sure whether it applies unchanged to what I added.
With some more tweaks (committed in https://github.com/ggerganov/ggml/commit/3bc786b4a9cc275158d613b381754debdb41cf33) I now have a version which works with all the falcon-mini-shakespeare models I have unleashed upon this world (both 7B and 40B configs). At least in 32-bit; I haven't tested quantized yet. The (known) remaining problem is the for-loop-based splitting of query heads. I suspect it's gonna blow up with a real big model, either being slow or exceeding the max number of tensors (4096) allowed by GGML (or both).
(Also it's possible that the implementation does some unnecessary operations like permutes or 4d instead of 3d, but that's minor.)
bin/falcon -m /mnt/seagate/miniconda3/falcon40b/falcon40b-mini-shakespeare/ggml-model--f32.bin --top_p 1 --top_k 1 -s 42 -p "When we loop"
When we loop, and for his head,
And in his head's head's face,
And yet with his head's head is to him;
And now, in this land's face,
And with his head by his head he will die.
I tend to agree, that's almost what happened to me.
Ok I've been too afraid to ask, but how on earth are you doing these commits that aren't on any branch at all? I wanted to clone the repo and check out the commit but I have no idea how to.
Sorry for the confusion - these commits belong to the falcon40b branch of my fork: https://github.com/jploski/ggml/tree/falcon40b - apparently GitHub is not clever enough to indicate their source.
@jploski
I was able to convert the real 40B model with my change here to reduce memory during HF conversion (only loads a single part into RAM at a time): https://github.com/jploski/ggml/pull/1
It required some work to get inference to actually run. I had to increase ctx_size:
ctx_size += ((size_t)3) * 1024 * 1024 * 1024;
Also, uhh... GGML_MAX_NODES at 4096 didn't quite cut it. Nor did 65535; I eventually just set it to 262144 and was able to run the model. Unfortunately, the output didn't make much sense:
main: seed = 1686733539
falcon_model_load: loading model from '/home/nope/personal/ai/models/falc40b.ggml' - please wait ...
falcon_model_load: n_vocab = 65024
falcon_model_load: n_embd = 8192
falcon_model_load: n_head = 128
falcon_model_load: n_head_kv = 8
falcon_model_load: n_layer = 60
falcon_model_load: ftype = 2008
falcon_model_load: qntvr = 2
falcon_model_load: ggml ctx size = 28175.96 MB
falcon_model_load: memory_size = 480.00 MB, n_mem = 122880
falcon_model_load: ............................................................ done
falcon_model_load: model size = 27436.06 MB / num tensors = 484
extract_tests_from_file : No test file found.
test_gpt_tokenizer : 0 tests failed out of 0 tests.
main: number of tokens in prompt = 10
main: token[0] = 7107, Once
main: token[1] = 2918, upon
main: token[2] = 241, a
main: token[3] = 601, time
main: token[4] = 23, ,
main: token[5] = 629, there
main: token[6] = 398, was
main: token[7] = 241, a
main: token[8] = 1278, little
main: token[9] = 27224, fox
Once upon a time, there was a little fox and, I’re
' to’ ' .
it that,. is
,, of . for.' '- you,. we the- the
1 of a
. the
Although it didn't work, even with the crazy number of nodes it wasn't really that slow. It was about the same as a 65B Q4_K_M LLaMA model with llama.cpp.
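A rough back-of-the-envelope for where the node count comes from (an estimate, not measured): the per-head splitting loop touches about 60 layers × 128 query heads ≈ 7,700 head slices per graph, and at a handful (say 5-10) of tensor ops per slice that is roughly 40,000-77,000 nodes before counting the regular per-layer ops - enough to blow past both 4096 and 65535, which would explain why only the jump to 262144 worked.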
The mini-Shakespeare model seems fine:
main: seed = 1686733831
falcon_model_load: loading model from '/home/nope/personal/ai/models/falcsp.ggml' - please wait ...
falcon_model_load: n_vocab = 65024
falcon_model_load: n_embd = 256
falcon_model_load: n_head = 4
falcon_model_load: n_head_kv = 2
falcon_model_load: n_layer = 4
falcon_model_load: ftype = 2009
falcon_model_load: qntvr = 2
falcon_model_load: ggml ctx size = 3105.91 MB
falcon_model_load: memory_size = 8.00 MB, n_mem = 8192
falcon_model_load: .... done
falcon_model_load: model size = 25.89 MB / num tensors = 36
extract_tests_from_file : No test file found.
test_gpt_tokenizer : 0 tests failed out of 0 tests.
main: number of tokens in prompt = 1
main: token[0] = 4031, Now
Now, Clarence, my lord, I am a
the great men: I will to do this day you
In time they may live in men of tears are
Shall be not what we have fought in. What is this
and come to you? I have not made mine eyes,
Which now sent for, or I am so fast?
Your friends shall be revenged on thee, hoar!
And that you must, sirs, that you must do,
My friend to thee that news, with your love,
My father's wife and love for this day.
You are not hot, lords, and what I am not?
To take this, good sweet friend, I am not my life,
I warrant, as I, to have a little thing, my lord,
What you can stay with this good night do you all your tongue?
O, if not my fair soul to my brother, how well,
Where is
main: mem per token = 290292 bytes
main: load time = 266.82 ms
main: sample time = 64.16 ms
main: predict time = 240.96 ms / 1.20 ms per token
main: total time = 576.08 ms
Both models were quantized to Q5_0.
Thanks for checking! I was able to reproduce the wrong output using an unquantized mini version trained with n_embd = 1024, n_head = 128, n_head_kv = 8. So there must still be a bug somewhere which the previous three configs I used for testing did not catch.
If the problem is the complicated logic for dealing with the query heads, maybe the easiest way to handle it is in the conversion tool, from the Torch or numpy side. It should be relatively easy to shuffle things around at that point.
Reducing the complexity would make issues easier to debug too, I guess.
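To make that concrete, here is one possible shape of such a reshuffle, sketched in C++ over a flat row-major copy of the fused query_key_value weight (illustrative only; the real conversion script works on torch/numpy tensors, and the layout the falcon40b branch eventually settles on may differ). It splits the per-group [Q…Q, K, V] row order into three contiguous Q | K | V blocks so the runtime could take simple views instead of per-head copies:

#include <algorithm>
#include <cstddef>
#include <vector>

// w holds (n_head + 2*n_head_kv)*head_dim rows of n_embd floats each, stored
// row-major in the per-group [Q..Q, K, V] order described above. Returns the
// same rows rearranged into contiguous Q | K | V blocks.
std::vector<float> split_fused_qkv(const std::vector<float> & w,
                                   int n_embd, int n_head, int n_head_kv, int head_dim) {
    const int    q_per_kv = n_head / n_head_kv;
    const size_t row      = (size_t) n_embd;              // floats per row
    std::vector<float> out(w.size());
    size_t q_dst = 0;                                      // Q block starts at row 0
    size_t k_dst = (size_t) n_head * head_dim;             // K block follows all Q heads
    size_t v_dst = k_dst + (size_t) n_head_kv * head_dim;  // V block follows all K heads
    for (int g = 0; g < n_head_kv; ++g) {
        const size_t src = (size_t) g * (q_per_kv + 2) * head_dim; // first row of group g
        auto copy_rows = [&](size_t from, size_t n_rows, size_t to) {
            std::copy(w.begin() + from * row, w.begin() + (from + n_rows) * row,
                      out.begin() + to * row);
        };
        copy_rows(src,                                     (size_t) q_per_kv * head_dim, q_dst); // queries
        copy_rows(src + (size_t) q_per_kv * head_dim,       (size_t) head_dim,           k_dst); // key head
        copy_rows(src + (size_t) (q_per_kv + 1) * head_dim, (size_t) head_dim,           v_dst); // value head
        q_dst += (size_t) q_per_kv * head_dim;
        k_dst += head_dim;
        v_dst += head_dim;
    }
    return out;
}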
Yes, I agree that reshuffling the weights during conversion will perhaps be the final and most elegant/efficient solution. I just haven't wrapped my head around yet how changing the layout of the query_key_value tensor maps into fused_qkv, from which the qkv vectors are extracted (fused_qkv = self.query_key_value(hidden_states)).
I'd also like to understand the current bug and have a working (if poorly implemented) version to improve on (even if the "improvement" will mean throwing away the overcomplicated code).
Understood and fixed in my falcon40b branch. Please recompile and try again.
It's alliiiiive!
main: seed = 1686742967
falcon_model_load: loading model from '/home/nope/personal/ai/models/falc40b.ggml' - please wait ...
falcon_model_load: n_vocab = 65024
falcon_model_load: n_embd = 8192
falcon_model_load: n_head = 128
falcon_model_load: n_head_kv = 8
falcon_model_load: n_layer = 60
falcon_model_load: ftype = 2008
falcon_model_load: qntvr = 2
falcon_model_load: ggml ctx size = 28175.96 MB
falcon_model_load: memory_size = 480.00 MB, n_mem = 122880
falcon_model_load: ............................................................ done
falcon_model_load: model size = 27436.06 MB / num tensors = 484
extract_tests_from_file : No test file found.
test_gpt_tokenizer : 0 tests failed out of 0 tests.
main: number of tokens in prompt = 10
main: token[0] = 7107, Once
main: token[1] = 2918, upon
main: token[2] = 241, a
main: token[3] = 601, time
main: token[4] = 23, ,
main: token[5] = 629, there
main: token[6] = 398, was
main: token[7] = 241, a
main: token[8] = 1278, little
main: token[9] = 27224, fox
Once upon a time, there was a little fox named ‘Pee-Poo’ who had an important mission to accomplish.
She had been assigned the task of finding the ‘Guru of all Gurus’ who was hiding deep in the jungle. And so one day, Pee-Poo set out for her journey. She walked and walked and walked and asked everybody in the jungle where the Guru lived, but nobody could tell her.
“But, how can that be?” she thought to herself, “There has
main: mem per token = 6467732 bytes
main: load time = 10538.50 ms
main: sample time = 34.28 ms
main: predict time = 90867.40 ms / 833.65 ms per token
main: total time = 104610.47 ms
Not a fan of the name it chose though.
For reference, these are the changes I need to actually run it:
diff --git a/examples/falcon/main.cpp b/examples/falcon/main.cpp
index beac293..c77c610 100644
--- a/examples/falcon/main.cpp
+++ b/examples/falcon/main.cpp
@@ -198,6 +198,7 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca
ggml_type_sizef(GGML_TYPE_F32); // memory_v
ctx_size += (5 + 10 * n_layer) * 256; // object overhead TODO:
+ ctx_size += ((size_t)3) * 1024 * 1024 * 1024;
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
index e770603..83b0d84 100644
--- a/include/ggml/ggml.h
+++ b/include/ggml/ggml.h
@@ -194,7 +194,7 @@
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
#define GGML_MAX_DIMS 4
-#define GGML_MAX_NODES 4096
+#define GGML_MAX_NODES 262144
#define GGML_MAX_PARAMS 256
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_OPT 4
Amazing work guys!
So is https://github.com/jploski/ggml/tree/falcon40b the branch I should use to try converting and running GGMLs?
262144 nodes, wtf :-)
Awesome to see it works so well!
I would suggest not converting them just yet - because if/when the qkv reshuffling during conversion is implemented, the binary format of the tensors would change again... which would make all the already published files incompatible.
OK fair enough!
The reshuffling is now implemented in https://github.com/jploski/ggml/tree/falcon40b - works with all the mini models, but I did not test it with the actual 40B model. Please try it out.
Inference fails fast at 5ee0488ee6ece2cb09cd28615ea934d641115990 (repro and full output, freshly reconverted and quantized) on my M2 Max 96GB with repeated errors & a segfault:
ggml_new_tensor_impl: not enough space in the context's memory pool (needed 21670576384, available 21631032576)
(also, incredible and exciting work guys, thanks for pushing through!)
Can you check whether the ctx_size fix from https://github.com/ggerganov/ggml/pull/231#issuecomment-1591037917 helps?
@jploski
I got it working - still needs the context size increase, doesn't need GGML_MAX_NODES increased. (Converted with my version of the conversion tool.)
main: seed = 1686777701
falcon_model_load: loading model from '/path/falc40b2.ggml' - please wait ...
falcon_model_load: n_vocab = 65024
falcon_model_load: n_embd = 8192
falcon_model_load: n_head = 128
falcon_model_load: n_head_kv = 8
falcon_model_load: n_layer = 60
falcon_model_load: ftype = 2008
falcon_model_load: qntvr = 2
falcon_model_load: ggml ctx size = 28175.96 MB
falcon_model_load: memory_size = 480.00 MB, n_mem = 122880
falcon_model_load: ............................................................ done
falcon_model_load: model size = 27436.06 MB / num tensors = 484
extract_tests_from_file : No test file found.
test_gpt_tokenizer : 0 tests failed out of 0 tests.
main: number of tokens in prompt = 9
main: token[0] = 7107, Once
main: token[1] = 2918, upon
main: token[2] = 241, a
main: token[3] = 601, time
main: token[4] = 629, there
main: token[5] = 398, was
main: token[6] = 241, a
main: token[7] = 1278, little
main: token[8] = 27224, fox
Once upon a time there was a little fox who lived in a forest. He was a young fox, and he was a little fox who didn't yet know about being a fox. He didn't know that foxes lived in forests
main: mem per token = 561844 bytes
main: load time = 11409.41 ms
main: sample time = 14.51 ms
main: predict time = 26630.51 ms / 554.80 ms per token
main: total time = 39132.96 ms
Also, it's close to twice as fast as before. edit: Okay, maybe only 35% faster but still, noticeably faster. Who would have thought having 250,000 GGML graph nodes might slow things down a bit?
Thanks. I added the context size "fix" to my branch (no idea why it's needed, but as it's only 3 GB and there is some similar "kludge" of adding "object overhead" right before it, I don't think we need to bother investigating the cause).
I also merged in your RAM-friendly version of the conversion script. Hopefully it won't trip people up that it has one extra command-line parameter in front.
Nice, I'm really excited to play with this model with all the features. Thanks for your work!
Two pretty interesting things I've noticed: it doesn't try really hard to write upbeat stuff like LLaMA does. Also (probably more important), it is much more efficient token-wise. The same text seems to use around a third fewer tokens with Falcon, so that's effectively the same as increasing the context limit.
edit: One other thing worth mentioning (and I don't know if it's inherent to the model) is that generation seems to slow down as tokens are generated. It's much, much slower after generating 1000 tokens, whereas LLaMA seems to go at close to the same speed all the way up to the context length.