[BOUNTY - $500] Better PartitioningStrategy
Introduction
exo currently implements Pipeline Parallel inference. This splits up layers of a model over multiple devices and executes them sequentially, device-by-device.
There are different ways we can split up the model layers. For this purpose, exo defines something called a PartitioningStrategy:
https://github.com/exo-explore/exo/blob/5e0db20426b51678149b12bb7cd55aeb3d1935c1/exo/topology/partitioning_strategy.py#L16-L19
This takes a Topology and gives a List[Partition].
A `Partition` consists of `node_id`, `start` and `end`. The Partitions must be contiguous ranges `[start, end)`: the first `start` must be 0, and the last `end` must be 1.
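For illustration, here is a minimal sketch of a valid `List[Partition]` for three nodes (the node IDs are made up):

```python
# Three nodes splitting a model: the ranges tile [0, 1) contiguously.
partitions = [
  Partition("node1", 0.0, 0.4),   # first start must be 0
  Partition("node2", 0.4, 0.75),
  Partition("node3", 0.75, 1.0),  # last end must be 1
]
```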
There are two things going on here:
- It decides the order in which nodes execute and send messages to each other. For example, if you return `[node1, node2, node3]`, then `node1` will execute first, followed by `node2`, followed by `node3`, which will then send an output token to `node1` to continue the cycle. However, if you return `[node2, node1, node3]`, then `node2` will execute first, followed by `node1`, followed by `node3`, which will then send an output token to `node2` to continue the cycle.
- It decides how many layers each node gets. Each node gets a number of layers proportional to `end - start`. For example, if `start=0, end=1`, then that node gets all the layers. If `start=0, end=0.5` for `node1` and `start=0.5, end=1` for `node2`, then `node1` gets 50% of the layers and `node2` gets 50% of the layers (see the sketch after this list).
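As a rough illustration of that proportionality, here is a sketch (not exo's actual helper, just the idea) of how partition fractions map to concrete layer ranges:

```python
# Sketch: turn partition fractions into layer ranges for a model with
# n_layers layers. exo's real rounding/mapping logic may differ.
def layer_ranges(partitions, n_layers):
  ranges = []
  for p in partitions:
    start_layer = round(p.start * n_layers)
    end_layer = round(p.end * n_layers)
    ranges.append((p.node_id, start_layer, end_layer))
  return ranges

# With start=0, end=0.5 for node1 and start=0.5, end=1 for node2 on a
# 32-layer model, each node is assigned 16 layers.
```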
The default and only PartitioningStrategy right now is RingMemoryWeightedPartitioningStrategy:
https://github.com/exo-explore/exo/blob/5e0db20426b51678149b12bb7cd55aeb3d1935c1/exo/topology/ring_memory_weighted_partitioning_strategy.py#L7-L18
This sorts nodes primarily by memory and secondarily by node_id. The size of each partition is proportional to the memory of the device: if deviceA has 4GB of memory and deviceB has 6GB, deviceA will get 40% of the layers and deviceB will get 60% (modulo some rounding). Note that the secondary sort by node_id is important to ensure deterministic, consistent ordering when two devices have the same amount of memory.
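Concretely, for that 4GB/6GB example the resulting partitions would look roughly like this (device IDs are illustrative):

```python
# deviceB (6GB) sorts ahead of deviceA (4GB), so it takes the first 60% of the layers.
expected = [
  Partition("deviceB", 0.0, 0.6),
  Partition("deviceA", 0.6, 1.0),
]
```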
The task
The task is to implement a new, improved PartitioningStrategy that takes into account more than just memory. This may require augmenting the Topology class with more information than it currently has, which will require changes across the codebase. Some things you might want to consider here are: device FLOPS and inter-node latency. There are many other things you could take into account here which I will leave to you to decide.
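As one purely illustrative sketch of what that extra information could look like (these types and field names are assumptions, not exo's current API), per-node compute data could sit alongside per-edge latency:

```python
from dataclasses import dataclass

# Hypothetical additions for a richer Topology (illustrative only).
@dataclass
class NodeInfo:
  memory_mb: int         # device memory, as the current strategy already uses
  flops_fp16: float      # rough compute capability of the device

@dataclass
class EdgeInfo:
  latency_ms: float      # measured round-trip latency between two nodes
  bandwidth_mbps: float  # link bandwidth between two nodes
```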
I have some ideas for how to do this, and there are many potential approaches; however, I'm looking for out-of-the-box ideas here.
I'll leave it up to you to reason about how to lay this out, but there are two high-level metrics that would make sense to optimise for (should they be optimised separately or together?):
- Time-to-first token (latency)
- Tokens per second (throughput)
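If they are optimised together, one simple, purely illustrative option is a weighted score; the weight and the way the two estimates are produced are assumptions, not part of exo:

```python
# Sketch: combine time-to-first-token and throughput estimates into one score
# for comparing candidate partitionings. Lower is better; alpha sets the trade-off.
def combined_score(ttft_seconds: float, tokens_per_second: float, alpha: float = 0.5) -> float:
  return alpha * ttft_seconds + (1 - alpha) * (1.0 / tokens_per_second)
```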
Deliverables
- A set of unit tests for your new `PartitioningStrategy` that show it works in different cases.
- A set of unit tests that "simulate" different scenarios and show that this `PartitioningStrategy` achieves the optimal solution in each scenario.
- An option added to the main script to enable this `PartitioningStrategy` (you decide if other parameters should be added to configure the new `PartitioningStrategy`).
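A rough sketch of that last deliverable, with a hypothetical flag name and choices (these are not existing exo options):

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical CLI option for picking the partitioning strategy from the main script.
parser.add_argument(
  "--partitioning-strategy",
  choices=["ring_memory_weighted", "latency_flops_aware"],
  default="ring_memory_weighted",
  help="Which PartitioningStrategy to use when splitting model layers across nodes.",
)
```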
Will be testing the strategy over the weekend on my cluster of MacBooks.
Hi @AlexCheema, I've seen that you usually create bounties on issues; maybe you're interested in using Opire. You don't pay until someone claims the bounty with a PR.
PS: I'm the cofounder, so if you need anything, feel free to contact me.
The new partitioning strategy should consider not only memory but also other parameters such as:
- Device FLOPS (Floating Point Operations Per Second)
- Inter-node latency

Key Metrics for Optimization
- Time-to-first token (latency)
- Tokens per second (throughput)
Proposed Structure
```python
from typing import List
from exo.topology.topology import Topology
from exo.topology.partitioning_strategy import Partition, PartitioningStrategy


class FlopsLatencyAwarePartitioningStrategy(PartitioningStrategy):  # illustrative name
  def partition(self, topology: Topology) -> List[Partition]:
    nodes = list(topology.all_nodes())
    # Sort by highest FLOPS, then lowest latency, then highest memory, with
    # node_id as a deterministic tie-breaker (flops/latency are proposed fields).
    nodes.sort(key=lambda x: (-x[1].flops, x[1].latency, -x[1].memory, x[0]))
    total_memory = sum(node[1].memory for node in nodes)
    partitions = []
    start = 0
    for node in nodes:
      # Each partition's size stays proportional to the node's share of memory.
      end = round(start + (node[1].memory / total_memory), 5)
      partitions.append(Partition(node[0], start, end))
      start = end
    return partitions
```
Explanation
Node Sorting: Nodes are sorted by FLOPS (highest first), then by latency (lowest first), then by memory (highest first), and finally by node ID as a tie-breaker. This ensures that more powerful nodes with lower latency are placed first in the ring.
Proportional Distribution: As in the current implementation, the size of each partition will be proportional to the device's memory.
Additional Parameters
To configure the new partitioning strategy, the following parameters can be added to the main script:
- FLOPS Priority Coefficient: allows the user to adjust how heavily FLOPS influence the partitioning.
- Latency Priority Coefficient: allows the user to adjust the influence of latency on the partitioning.

Testing
Unit Tests for Functionality: Create tests that verify the correctness of partitioning based on various inputs.
Scenario Tests for Simulation: Develop tests that simulate different load scenarios and demonstrate that the new partitioning strategy achieves optimal solutions.
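A minimal sketch of such a scenario test, building on the illustrative strategy sketched under "Proposed Structure" above (the fake topology and field names are assumptions):

```python
from types import SimpleNamespace

class FakeTopology:
  """Minimal stand-in exposing all_nodes() like exo's Topology (illustration only)."""
  def __init__(self, nodes):
    self._nodes = nodes
  def all_nodes(self):
    return self._nodes

def test_fast_low_latency_node_goes_first():
  topology = FakeTopology([
    ("slow", SimpleNamespace(memory=8000, flops=1.0, latency=50.0)),
    ("fast", SimpleNamespace(memory=8000, flops=10.0, latency=5.0)),
  ])
  partitions = FlopsLatencyAwarePartitioningStrategy().partition(topology)
  # The higher-FLOPS, lower-latency node should lead the ring; equal memory
  # means each node covers half of [0, 1).
  assert [p.node_id for p in partitions] == ["fast", "slow"]
  assert partitions[0].end == 0.5
```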
Conclusion
This new partitioning strategy should enhance system performance by taking into account both the computational capabilities of devices and network latencies.