[MEDIUM] Better placement algorithm for pipeline parallelism using memory bandwidth + latency
This is our placement algorithm for pipeline parallelism: https://github.com/exo-explore/exo/blob/abaeb0323d4182f7bc4dd3775a8ba9209117d1cf/src/exo/master/placement_utils.py#L52-L100
It places a number of layers proportional to the memory available on each machine. This is not optimal.
In order to maximize TPS in the memory-bound regime (i.e. low batch_size), we should look at memory bandwidth and latency. The time spent computing on device i would be C_i = M * (N_i / N) / B_i, where M is the total memory used by the entire model, N is the number of layers in the model, N_i is the number of layers on device i, and B_i is the memory bandwidth of device i. The time spent communicating between device i and device i + 1 is L_i. The total time for one token across k devices is then the sum from i = 1 to k of C_i + L_i. This sum is what we want to minimize. I'll leave the details of the algorithm up to the implementer.
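To make the cost model concrete, here is a small sketch that evaluates it for a hypothetical setup (all numbers below are made-up illustrations, not real device specs or exo APIs):

```python
def token_time_s(model_bytes, total_layers, layers_per_device, bandwidths, latencies):
    """Per-token time under the memory-bound cost model.

    C_i = M * (N_i / N) / B_i  -- time to stream device i's layers from memory
    L_i = latency of the link from device i to device i + 1
    Total = sum over devices of C_i + L_i.
    """
    compute = sum(
        model_bytes * (n_i / total_layers) / b_i
        for n_i, b_i in zip(layers_per_device, bandwidths)
    )
    return compute + sum(latencies)

# Hypothetical 2-device pipeline: a 16 GB model with 32 layers.
t = token_time_s(
    model_bytes=16e9,
    total_layers=32,
    layers_per_device=[20, 12],   # N_i
    bandwidths=[400e9, 100e9],    # B_i in bytes/s (e.g. ~400 GB/s class vs ~100 GB/s class)
    latencies=[2e-3, 2e-3],       # L_i in seconds per hop
)
# ≈ 0.089 s per token, i.e. roughly 11 tokens/s
```

Note how the slower device dominates: it holds only 12 of 32 layers but contributes 0.06 s of the 0.085 s total compute time, which is exactly why a memory-proportional split is suboptimal.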
@AlexCheema I'll tackle this. Based on the issue description, the goal is to replace the current RAM-proportional split with a bandwidth-aware distribution to minimize total inference time.
Since the cycle (topology) is fixed during shard assignment, ∑L_i is constant. Minimizing the total time ∑(C_i + L_i) therefore reduces to minimizing ∑C_i.
My plan
Data Collection: I will add a lookup table mapping chip_id to memory bandwidth (e.g. 400GB/s for M3 Max) in NodePerformanceProfile, as this data is currently missing.
Algorithm Update: I will modify get_shard_assignments_for_pipeline_parallel in src/exo/master/placement_utils.py to use a greedy approach:
- Reserve: Assign 1 layer to every node in the cycle (required to prevent IndexError in auto_parallel.py where networking layers are injected).
- Sort: Sort nodes by bandwidth (B_i), fastest first.
- Fill: Calculate the "cost per layer" for each node, then saturate the fastest nodes with as many of the remaining layers (N_i) as their RAM permits.
- Finalize: Map these counts back to the original topology order to generate the final ShardAssignments.
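The reserve/sort/fill steps above could be sketched roughly like this (the `Node` dataclass and `bytes_per_layer` estimate are hypothetical stand-ins, not exo's actual `NodePerformanceProfile` API):

```python
from dataclasses import dataclass

@dataclass
class Node:
    node_id: str
    bandwidth: float   # B_i, bytes/s
    memory: float      # usable RAM in bytes

def greedy_layer_counts(nodes, total_layers, bytes_per_layer):
    """Assign per-node layer counts to minimize sum of C_i = M * (N_i / N) / B_i.

    Every node gets 1 reserved layer (networking-layer injection needs a
    non-empty shard); remaining layers go to the highest-bandwidth nodes
    first, capped by each node's RAM.
    """
    counts = {n.node_id: 1 for n in nodes}          # Reserve step
    remaining = total_layers - len(nodes)
    assert remaining >= 0, "more nodes in the cycle than layers"
    for n in sorted(nodes, key=lambda n: n.bandwidth, reverse=True):  # Sort step
        cap = int(n.memory // bytes_per_layer) - counts[n.node_id]    # RAM headroom
        take = min(max(cap, 0), remaining)                            # Fill step
        counts[n.node_id] += take
        remaining -= take
    assert remaining == 0, "cluster RAM cannot hold all layers"
    return counts

# Hypothetical: 32 layers at ~1 GB each across a fast 64 GB node and a slow 32 GB node.
counts = greedy_layer_counts(
    [Node("a", 400e9, 64e9), Node("b", 100e9, 32e9)],
    total_layers=32,
    bytes_per_layer=1e9,
)
# The fast node absorbs every layer its RAM allows; the slow node keeps only its reserve.
```

Because the objective ∑C_i is linear in the layer counts, filling strictly fastest-first (subject to RAM caps) is optimal for the fixed-topology case; the Finalize step would then just re-order these counts to match the cycle.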
I would start by adding the bandwidth specs for Apple Silicon. Let me know if you'd like any modifications to my approach to ensure a clean PR. :)
Happy for you to take this one.
Your approach looks good, except it doesn't take latency into account.
A few tweaks to account for latency should make it work.
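One way latency could enter the picture (a sketch only, under the assumption that the planner may also drop nodes from the pipeline): ∑L_i is constant only for a fixed participant set, so when choosing *which* nodes participate, each candidate subset pays its own link latencies against its aggregate bandwidth. A brute-force version over small clusters, with nodes as hypothetical `(bandwidth, memory)` tuples and a uniform per-hop latency for simplicity:

```python
from itertools import combinations

def subset_token_time(subset, total_layers, model_bytes, bytes_per_layer, hop_latency):
    """Per-token time if only `subset` participates: greedy fastest-first
    layer split within the subset, plus one latency term per cycle hop."""
    remaining = total_layers
    compute = 0.0
    for bw, mem in sorted(subset, reverse=True):          # fastest nodes first
        take = min(int(mem // bytes_per_layer), remaining)
        compute += model_bytes * (take / total_layers) / bw
        remaining -= take
    if remaining > 0:
        return None                                        # subset can't hold the model
    hops = len(subset) if len(subset) > 1 else 0           # a single node needs no links
    return compute + hop_latency * hops

def best_subset(all_nodes, total_layers, model_bytes, bytes_per_layer, hop_latency):
    """Exhaustively pick the participant subset minimizing sum(C_i) + sum(L_i)."""
    best = None
    for k in range(1, len(all_nodes) + 1):
        for subset in combinations(all_nodes, k):
            t = subset_token_time(subset, total_layers, model_bytes,
                                  bytes_per_layer, hop_latency)
            if t is not None and (best is None or t < best[0]):
                best = (t, subset)
    return best

# Hypothetical: a fast and a slow node, each with enough RAM for the whole model.
t, subset = best_subset(
    [(400e9, 20e9), (100e9, 20e9)],
    total_layers=32, model_bytes=16e9, bytes_per_layer=0.5e9, hop_latency=2e-3,
)
# Here the fast node alone wins: adding the slow node contributes no useful
# compute but adds two hops of latency.
```

This is exponential in node count, so a real implementation would want a smarter search, but it illustrates why latency can make a smaller pipeline faster than using every machine.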