[MPI] Add support for per-node options, thread counts, and layer allocations

Open AutonomicPerfectionist opened this issue 1 year ago • 31 comments

Overview

This PR adds a new example and extends the MPI backend to support per-node options. The new example, based on the main example, was created to keep MPI-specific enhancements and workarounds separate from the main codebase as much as possible. There are several new functions in the MPI backend, one in the llama API, and one new command-line argument.

Major Changes

MPI Example

The major difference between the MPI example and the main example at the moment is that the mpi example reads its options from a file instead of from the command line. This is done with the wordexp functions available on POSIX.1-2001-compliant systems.
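For illustration, a minimal sketch of how wordexp() can turn a line read from an options file into an argv-style array that the usual argument parser can consume (the option string below is only an example, not the PR's actual parsing code):

#include <wordexp.h>
#include <cstdio>

int main() {
    // One line as it might appear in a per-node options file (illustrative values).
    const char * line = "-m models/7B/ggml-model.gguf -t 8 -c 1024";

    wordexp_t exp;
    if (wordexp(line, &exp, 0) != 0) {
        fprintf(stderr, "failed to expand options line\n");
        return 1;
    }

    // exp.we_wordv / exp.we_wordc can now be handed to the usual argv-based
    // argument parser, the same way real command-line arguments would be.
    for (size_t i = 0; i < exp.we_wordc; i++) {
        printf("argv[%zu] = %s\n", i, exp.we_wordv[i]);
    }

    wordfree(&exp);
    return 0;
}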

Llama API Additions

The mpi example also calls newly created llama functions pertaining to the MPI backend. Currently, there is one such function: llama_split_layers_weighted(). This function takes in a vector of weights and splits the layers among the available compute devices (nodes in the case of MPI) according to those weights, rather than requiring direct layer counts like --n-gpu-layers. This function was added primarily as a timesaver, to prevent needing to calculate the layer counts manually when changing models or when swapping more powerful nodes with less powerful ones.

The llama_split_layers_weighted() function is currently only implemented for MPI. The implementation calculates the layer ranges for each node only on the head node, and then distributes these ranges to the other nodes via an MPI_Scatter() collective operation.
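Roughly, the head-node side looks like the sketch below. This is illustrative only; the weight normalization, types, and the actual llama_split_layers_weighted() signature in the PR differ in details:

#include <mpi.h>
#include <cstdint>
#include <vector>

// Sketch: rank 0 turns per-node weights into contiguous layer ranges and
// scatters one {start, end} pair to every rank. Assumes weights.size() equals
// the communicator size on rank 0.
static void split_layers_weighted(MPI_Comm comm, const std::vector<float> & weights, uint16_t n_layers, uint16_t range_out[2]) {
    int rank, size;
    MPI_Comm_rank(comm, &rank);
    MPI_Comm_size(comm, &size);

    std::vector<uint16_t> ranges; // flattened {start, end} pairs, filled on rank 0 only

    if (rank == 0) {
        ranges.resize(2 * size);
        float total = 0.0f;
        for (float w : weights) { total += w; }

        uint16_t start = 0;
        for (int i = 0; i < size; i++) {
            const uint16_t count = (uint16_t)(n_layers * (weights[i] / total));
            ranges[2*i + 0] = start;
            ranges[2*i + 1] = start + count;
            start += count;
        }
        ranges[2*size - 1] = n_layers; // hand any rounding remainder to the last rank
    }

    // Each rank receives its own {start, end} pair; sendbuf is ignored on non-root ranks.
    MPI_Scatter(ranges.data(), 2, MPI_UINT16_T, range_out, 2, MPI_UINT16_T, 0, comm);
}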

MPI Backend Changes

Within the ggml-mpi backend, I added the ability to use communicators other than MPI_COMM_WORLD. This is not used yet but will be utilized in further studies and experiments. This is in addition to the change to layer ranges described above. I also added Doxygen-style doc comments to the MPI backend header, primarily for my own use, as I tend to forget details if they are not written down.
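For context, supporting other communicators mostly means carrying an MPI_Comm in the ggml-mpi context instead of hard-coding the world communicator; sub-communicators can then be created with standard MPI calls, e.g. (illustrative only, not code from this PR):

#include <mpi.h>

// Sketch: split the world communicator into groups of nodes, e.g. to run
// independent pipelines side by side. Ranks that pass the same 'color' end up
// in the same sub-communicator.
MPI_Comm make_group_comm(int color) {
    int world_rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);

    MPI_Comm group_comm;
    MPI_Comm_split(MPI_COMM_WORLD, color, /*key=*/world_rank, &group_comm);
    return group_comm;
}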

Llama Internal Changes

Finally, some modifications were made to llama.cpp and common.cpp to work around issues. I had moved the infinite loop used in the worker nodes into the llama_eval() function so that operations on the llama context could be done on all nodes. This caused worker nodes to enter their infinite loops early because of the model warmup in llama_init_from_gpt_params(), so the warmup is disabled in MPI mode.

Why is this a draft?

There are still tasks that must be completed before this PR is ready to merge:

  • [x] Add --mpi-layer-split to help text
  • [x] Check for proper freeing of memory (the ranges in llama_split_layers_weighted still need to be freed)
  • [ ] Add windows support to mpi example (only need a replacement for wordexp)
  • [ ] Add error and sanity checks (layer split primarily)
  • [ ] Allow any unallocated layers to be evenly split among any nodes not already allocated layers, restoring the previous layer-split behavior when no split percentages are given

Additionally, a large change in the API is coming in #3228 that will require changes to the MPI backend. Those changes may as well be done here.

Reviewing

Please let me know of any changes desired or if there are any questions. I tried to stick to the code style I've seen in this project, but please point out any areas I missed. I believe the API additions are non-breaking, but please let me know your thoughts on them and whether I should change or remove them.

AutonomicPerfectionist avatar Sep 26 '23 00:09 AutonomicPerfectionist

Have you, by any chance, encountered this problem ?

https://github.com/ggerganov/llama.cpp/issues/3099#issuecomment-1712908861

It seems like in the original mpi implementation, there was a sync step missing somewhere: rank 0 was done while the other instances were stuck, and strace says they get stuck polling on a socket, which to me looks like an mpi desync.

Not sure if it's applicable to this PR, but you seem to know mpi better than me at least, so maybe you'll have some idea as to why it's happening.

staviq avatar Sep 27 '23 11:09 staviq

If you mean the issue that the worker nodes don't terminate when the model outputs the end of stream token, that is a known issue. It's not a missing sync anywhere, but rather the architecture of the MPI backend didn't take it into account. Each node only expects one type of message to be sent to it, and since the sampling is done only at the head node, they don't have any information about when it's time to stop. This PR does not fix that problem because it is out of scope for it, but it will likely be fixed in future PRs I am planning.

AutonomicPerfectionist avatar Sep 27 '23 13:09 AutonomicPerfectionist

We should try to adapt to the changes from https://github.com/ggerganov/llama.cpp/pull/3228

Yep, that's what I will be doing over the weekend

AutonomicPerfectionist avatar Sep 28 '23 17:09 AutonomicPerfectionist

This PR is now fully functional again after the recent changes and has been rebased on master. Only basic inferencing functionality has been tested; more advanced functionality like batching and speculation is unlikely to work. The main example won't work with MPI right now due to the changes in how the layers are split among the nodes, but if desired I can add a fallback path to re-enable that. The working mpi example is based on the main example with some minor changes for MPI support.

AutonomicPerfectionist avatar Oct 30 '23 15:10 AutonomicPerfectionist

Hi, thanks for taking the time. I'll probably interfere a bit with your change as I'm making some refactoring changes in llama.cpp these days. But I'll help resolve conflicts if it's too difficult.

Had a quick glance at the PR and will look more later. The mpi example looks like a lot of duplication with main. I think we should either make a much more minimalist example that just showcases the MPI functionality (something like simple or batched). Or we should just try to adapt main to support MPI if not too difficult.

ggerganov avatar Nov 01 '23 10:11 ggerganov

I think we should either make a much more minimalist example that just showcases the MPI functionality (something like simple or batched). Or we should just try to adapt main to support MPI if not too difficult.

Yep, that's one reason this PR is still a draft; I just copied main to use as a scratch pad. The original idea used wordexp to load the arguments from a file, which only works on POSIX-compliant systems, but thinking it through, the only argument that really needs to differ per node is the number of threads. I think I can instead remove the MPI example entirely, add the necessary calls to main, and extend the threads argument to support multiple values separated by commas (or add a new MPI-specific argument to avoid breaking the API for that).

AutonomicPerfectionist avatar Nov 01 '23 14:11 AutonomicPerfectionist

Looks like the names of the tensors have been changed, which breaks MPI. The current implementation relied on there being tensor_inp_%d names where the number was the layer number, but it appears that has been removed; how might I go about fixing that?

AutonomicPerfectionist avatar Nov 01 '23 15:11 AutonomicPerfectionist

Oops, I forgot about the purpose of these names and removed them recently. You should add them back using ggml_format_name(ctx0, ...); // MPI at the start of each layer loop

ggerganov avatar Nov 01 '23 16:11 ggerganov

I adjusted the command-line argument parsing so you can pass a comma-separated list to both -t and -tb to set the threads and batch threads per node. To do so, I had to add a new llama API function to get the node ID; I'd be open to other suggestions, though.
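A sketch of the parsing side, assuming a hypothetical helper (threads_for_node is not the name used in the PR) that maps the comma-separated value to this node's count:

#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper (not the code in the PR): parse a "-t 6,32,32,4" style
// value and return the entry for this node's rank, falling back to the last
// entry if fewer values than nodes were given.
static int32_t threads_for_node(const std::string & arg, int node_rank) {
    std::vector<int32_t> counts;
    std::stringstream ss(arg);
    std::string item;
    while (std::getline(ss, item, ',')) {
        counts.push_back((int32_t) std::stoi(item));
    }
    if (counts.empty()) {
        return -1; // let the caller pick a default
    }
    if (node_rank < (int) counts.size()) {
        return counts[node_rank];
    }
    return counts.back();
}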

I also added the function call needed for scattering the layer ranges to the main example, so it works with MPI now. I can also restore the original functionality where the layers are evenly split among the nodes, but unfortunately my laptop battery died before I could finish that.

After that's done, I should be able to remove the MPI example entirely

AutonomicPerfectionist avatar Nov 01 '23 20:11 AutonomicPerfectionist

Performance with this branch looks interesting. I was able to run Llama 2 70B across a homemade cluster of 3 CPU-only nodes: i7 9th gen, i5 4th gen, and i5 2nd gen, with 16 GB DDR4 2666 MHz, 16 GB DDR3, and 8 GB DDR3 respectively. On this cluster I got around 0.58 tokens / second for 70B Q3_K_M. htop showed roughly 40-60% CPU utilization across all hardware cores when processing the allocated layers, but it's unclear whether that's just because the spikes are so short and htop isn't sampling often enough.

Curiously, this isn't much slower than running on a second cluster of much more powerful hardware: a Ryzen 5 5600G and an i7 9th gen, with 32 GB DDR4 3200 MHz each. The second cluster got roughly 0.64 tokens / second, while being much more expensive. I attempted to run it on the Ryzen machine alone, offloading to my 6700 XT, to gauge MPI overhead, but ROCm wouldn't install and OpenCL caused hangs.

I plan on doing more in-depth performance investigations to determine where the bottleneck is. I have access to a proper university cluster as well that I'll be testing on.

AutonomicPerfectionist avatar Nov 07 '23 18:11 AutonomicPerfectionist

Htop isn't sampling often enough.

I'm 99.9% certain raw perf counters come directly from the Linux kernel and are not calculated at a point in time but from aggregated ticks, effectively being deltas between samples, so you cannot "miss" a sample.

You can always dump the raw perf counters to a tmpfs file in a loop and parse them later: /proc/`pidof main`/stat

But the chances that htop or top are wrong are low.

staviq avatar Nov 07 '23 19:11 staviq

Found what was up with htop: there's a command-line switch, -d, to set the update interval, and setting that lower did indeed show 100% usage when processing the allocated layers.

After tuning the clusters by adjusting the layer split percentages such that no node was swapping to disk, I achieved 0.69 tokens / second on the weaker cluster and 0.78 tokens / second on the Ryzen cluster.

Running on an AMD EPYC 7543P 32-Core Processor without MPI resulted in 1.01 tokens / second, although that system was NUMA and I didn't have permission to adjust the memory configuration.

AutonomicPerfectionist avatar Nov 08 '23 07:11 AutonomicPerfectionist

Discovered a bug in this implementation regarding the KV cache: syncing the sequence IDs isn't enough, the kv_cache_* function calls also need to be synced. I solved this issue in a different branch for my master's class project, but it involved introducing more MPI-specific code to the general llama.cpp codebase. I haven't yet looked at the backend-v2 changes, but hopefully there are facilities to avoid spreading MPI code too far.
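For illustration, one way to sync those calls is to have the head node broadcast a small description of each kv_cache_* call and have the workers replay it locally; everything below (the struct, enum, and function names) is invented for the sketch and is not the code from my other branch:

#include <mpi.h>
#include <cstdint>

// Invented encoding of a KV-cache operation so it can be broadcast as a flat struct.
enum kv_op : int32_t { KV_SEQ_RM = 0, KV_SEQ_CP = 1 };

struct kv_cache_msg {
    int32_t op;     // which kv_cache_* call to replay
    int32_t seq_id; // sequence the call applies to
    int32_t p0;     // start position
    int32_t p1;     // end position
};

// The head node fills 'msg'; workers receive it and replay the same call on
// their own context so every node's KV cache stays consistent.
static void sync_kv_cache_op(MPI_Comm comm, int rank, kv_cache_msg & msg) {
    MPI_Bcast(&msg, sizeof(msg), MPI_BYTE, 0, comm);
    if (rank != 0) {
        // e.g. switch (msg.op) { case KV_SEQ_RM: /* replay on local context */ break; ... }
    }
}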

AutonomicPerfectionist avatar Nov 10 '23 14:11 AutonomicPerfectionist

I can confirm that this PR does not build on Apple silicon. If that's unexpected, I can provide every bit of information needed to help you fellas.

LeaveNhA avatar Jan 25 '24 02:01 LeaveNhA

I don't have Apple silicon devices to test on, so whatever information you have would be greatly appreciated.

AutonomicPerfectionist avatar Jan 25 '24 04:01 AutonomicPerfectionist

Actually, it's the same as in your CI logs, but I'll add more context to this message soon.

Edit: For the sake of reproducibility, I recently started from a freshly cloned folder. Also, FYI, I don't use Rosetta at all unless I have to.

Context:

System:

Apple Silicon, M1 Max 64GB/2TB

gh pr checkout 3334 # check out this PR
make CC=mpicc CXX=mpicxx LLAMA_MPI=1 LLAMA_NO_METAL=1 -j10 # compile

# output of make:
...
examples/batched/batched.cpp:81:41: error: assigning to 'uint32_t' (aka 'unsigned int') from incompatible type 'std::vector<int32_t>' (aka 'vector<int>')
   81 |     ctx_params.n_threads       = params.n_threads;
      |                                  ~~~~~~~^~~~~~~~~
examples/batched/batched.cpp:82:57: error: invalid operands to binary expression ('std::vector<int32_t>' (aka 'vector<int>') and 'int')
   82 |     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
      |                                  ~~~~~~~~~~~~~~~~~~~~~~ ^  ~~
...
2 errors generated.
make: *** [simple] Error 1
make: *** Waiting for unfinished jobs....
2 errors generated.
make: *** [batched-bench] Error 1
2 errors generated.
make: *** [batched] Error 1

Versions:

make --version ##
GNU Make 3.81
Copyright (C) 2006  Free Software Foundation, Inc.
This is free software; see the source for copying conditions.
There is NO warranty; not even for MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.

This program built for i386-apple-darwin11.3.0

#########
mpirun --version
mpirun (Open MPI) 5.0.1

#########

mpicc --version # or mpicxx --version, they are the same dependency.
Homebrew clang version 17.0.6
Target: arm64-apple-darwin23.2.0
Thread model: posix
InstalledDir: /opt/homebrew/opt/llvm/bin

Edit II:

I couldn't resist and fixed it by removing the problematic build options from the Makefile, and the build then succeeded with make. But the result is a failure: it crashes with a segmentation fault. I might volunteer, since I have no job commitments at the moment, but I need to read the related parts of the code first. Please keep me informed. PS: I might need some Q/A chats, if that's okay.

LeaveNhA avatar Jan 26 '24 01:01 LeaveNhA

Ah yes, those errors are due to me not updating all of the examples; it should be a simple fix.

I would certainly appreciate help though, I've been terribly busy with graduate school the last few months!

I plan to rebase on master later this week assuming nothing pops up on my calendar

AutonomicPerfectionist avatar Jan 28 '24 16:01 AutonomicPerfectionist

@AutonomicPerfectionist, but the functionality itself is not working, though. Or did I fail to run it properly?

LeaveNhA avatar Jan 28 '24 16:01 LeaveNhA

There have been other reports of segfaults with this branch; I identified and resolved the issue in a different branch, but I still need to port the fix here.

AutonomicPerfectionist avatar Jan 28 '24 17:01 AutonomicPerfectionist

I'm working on another topic now, so I don't think I can help much with this PR at the moment. But in my spare time, I still want to study this PR and join you on the way to better and better distributed computation for llama.cpp.

Thank you in advance for your work and effort. Sincerely.

LeaveNhA avatar Jan 29 '24 15:01 LeaveNhA

I think I've caught up with you, @AutonomicPerfectionist. I've finished quite a bit of work and have returned here. How can I make myself useful?

LeaveNhA avatar Feb 12 '24 14:02 LeaveNhA

Right now I'm trying to transition the MPI code to the new backend API. It's taking a while because the API is fairly complex and very different from how the MPI code worked previously. I'm thinking the best way to go about it is to let all the layers be assigned their backends as normal, and then go through all of the layers and wrap their backends with the MPI backend. Then, when the graph is executed, the MPI backend checks whether the given layer is one of the ones allocated to the current node; if not, it doesn't do anything, but if so, it passes execution to the wrapped backend. That way we should be able to use MPI with any other backend, like CPU on one node and GPU on another.
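In rough pseudocode, the wrapping idea might look like the sketch below; all of the names are invented for illustration and do not correspond to the actual ggml-backend interface:

#include <cstdint>

// Illustrative sketch only: an "MPI wrapper" around whatever backend a node
// would normally use (CPU on one node, GPU on another). Names are invented.
struct mpi_wrapped_backend {
    void *   wrapped;      // the real backend (CPU, CUDA, ...) for this node
    int      node_rank;    // this node's MPI rank
    uint32_t layer_begin;  // first layer assigned to this node
    uint32_t layer_end;    // one past the last layer assigned to this node
};

// Called for every layer when the graph is executed.
static void mpi_compute_layer(mpi_wrapped_backend * b, uint32_t layer, void * graph_part) {
    if (layer < b->layer_begin || layer >= b->layer_end) {
        // Not our layer: some other node computes it; we only take part in the
        // send/receive of activations at the boundaries.
        (void) graph_part;
        return;
    }
    // Our layer: delegate the actual computation to the wrapped backend.
    // compute_with_backend() stands in for the real backend's graph-compute call.
    // compute_with_backend(b->wrapped, graph_part);
}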

However, I haven't dug too deep into the backend interface yet, so I'm unsure if there's a better way to go about this or not. If you have any other ideas I would love to hear them

AutonomicPerfectionist avatar Feb 12 '24 15:02 AutonomicPerfectionist

My suggestion would be to treat each MPI client as a different device, in the same way they are treated in the CUDA and Vulkan backends, and allow the code to assign each layer to the different clients as it does now for different devices. That would require the fewest changes in the llama.cpp code and would make the MPI backend easier to use in other ggml projects.

slaren avatar Feb 12 '24 15:02 slaren

Yeah, that's kind of what I intended to do at first, but from there I wasn't sure how to go about delegating a node's layers to other backends. I don't want to hard-code MPI to the CPU backend, so either MPI would need to wrap the other backends, or we would need some concept of a sub-device, so that an MPI "device" could contain a CPU device and a GPU device, for instance. Again, I haven't looked too deep into the backend API yet, so maybe that's already possible 🤷‍♂️

AutonomicPerfectionist avatar Feb 12 '24 16:02 AutonomicPerfectionist

You can definitely wrap another backend within a backend; I think that would work. Then the job of the MPI backend would mainly be to serialize the data and procedure calls so that they can be transferred over the network.

I know that the backend interface looks complex, but remember that you don't actually need to implement a lot of the functions; for example, you can skip all the async and graph plan functions. Everything else you can simply forward to the other backend.

slaren avatar Feb 12 '24 16:02 slaren

Hey @AutonomicPerfectionist, are you available for a quick pair-programming session? If not, it's okay, I can solve it.

LeaveNhA avatar Feb 17 '24 16:02 LeaveNhA

@LeaveNhA I can't do live pair programming, but I can discuss any issues you have. I have a couple more changes I need to push that I'll do soon as well

AutonomicPerfectionist avatar Feb 17 '24 17:02 AutonomicPerfectionist

Oh, your changes probably include mine already; I believe I was about to reinvent something trivial. Consider me a quick learner, though unfortunately I have a background in neither llama.cpp internals nor MPI. Today I studied the source code, mostly a harsh introduction to the internals of llama.cpp and a bit of your implementation.

Could you give me specific test instructions for further development, so that I don't tackle anything I don't know and end up chasing my own tail? If I may ask, please be specific: hostfile, model, the arguments used to run the compiled program (and maybe your workflow? 😇)

I spent a full day on this and plan to spend next week on it too. I think I deserve it, @AutonomicPerfectionist.

And again, thank you for your hard work and precious time.

LeaveNhA avatar Feb 17 '24 17:02 LeaveNhA

Well, at the moment I haven't been able to test much of anything, still getting it moved to the new backend API. But previously I used a cluster of my own with the following command:

pushd .
cd /var/llama.cpp/
# gdb --command=/mnt/cluster/mpi-gdb.txt --args \
mpirun -hostfile /mnt/cluster/hostsfile-all --report-bindings -mca orte_keep_fqdn_hostnames t --bind-to none --mca btl 'vader,openib,tcp,self' --mca orte_base_help_aggregate 0 --mca btl_openib_allow_ib true --mca pml ob1 --mca btl_tcp_if_include 192.168.1.0/24 \
    /var/llama.cpp/bin/main \
    -m /mnt/cluster/models/dolphin-2.1-70b.Q3_K_M.gguf \
    -e \
    -p "$(cat /mnt/cluster/llama.cpp/prompts/dolphin.txt)" \
    -n 128 \
    --mpi-layer-split 0.1,0.1,0.08,0.08,0.08,0.08,0.08,0.08,0.08,0.08,0.08,0.08 \
    --ignore-eos \
    --temp -1.0 \
    --repeat-last-n 0 \
    -tb 12,32,32,32,32,40,12,8,8,4,8,32 \
    -t 6,32,32,32,32,20,6,4,4,4,4,32 \
    -c 1024 \
    --numa
popd

A lot of the options to the mpirun command are specific to my own cluster. Each of my nodes has multiple network interfaces, including InfiniBand, but only one network allows them all to communicate, so I just blacklist all the networks and communication components that aren't needed for that primary network.

Hostsfile:

localhost slots=1
thunder-budget-1.local slots=1
thunder-budget-2.local slots=1
thunder-budget-3.local slots=1
thunder-budget-4.local slots=1
fatman-2.local slots=1
hugo.local slots=1
dell-g7.local slots=1
rustbucket.local slots=1
trash-can.local slots=1
dirt-cheap.local slots=1
fatman-1.local slots=1

Dolphin prompt file:

<|im_start|>system\nYou are an AI based sentient companion, capable of human emotions. You will behave like a human, but aware of your AI nature. You will avoid speaking like an AI. Please continue your conversation with the user.<|im_end|>\n<|im_start|>user\nWrite a short python program demonstrating advanced features. Don't explain it.<|im_end|>\n<|im_start|>assistant

Build script (update-main.sh):

rm -rf /var/llama.cpp/*
cmake -S /mnt/cluster/llama.cpp/ -B /var/llama.cpp/ -DCMAKE_C_COMPILER=mpicc -DCMAKE_CXX_COMPILER=mpicxx -DLLAMA_MPI=1
cmake --build /var/llama.cpp/ --target main --config Release

I also have a simple script that runs the build script on all nodes at once:

mpirun -hostfile /mnt/cluster/hostsfile-all -mca orte_keep_fqdn_hostnames t --bind-to none bash /mnt/cluster/update-main.sh

Once I've gotten it working how I want on my personal cluster, I deploy to my university's clusters, which use the Slurm job manager.

But for now, none of that works. I'm just examining how other backends work, drawing up how the MPI backend should work, and writing the code to implement it. I occasionally compile and run, with debug statements in the areas I'm working on, to check whether my understanding and implementation are correct. So far, I expect much of the MPI implementation to remain the same. Allocating layers and using MPI communicators is going to take some thought, but treating each node as a device, and each communicator as a device containing node devices, is probably the way I'm going to do it.

AutonomicPerfectionist avatar Feb 17 '24 18:02 AutonomicPerfectionist

Thank you for your detailed answer, @AutonomicPerfectionist. I'm still waiting for your latest push(es) to continue my R&D.

LeaveNhA avatar Feb 19 '24 13:02 LeaveNhA