[FEATURE] Reduce congestion of sending on one mesh
This can be a starting point to learn `runtime_emitter` and `cross_mesh_resharding`.
Background
In Pipeshard Parallel, when a tensor needs to be received from a mesh, we always choose the mesh that originally produced it, as happens here. However, when the tensor is consumed by multiple pipeline stages, a better solution MIGHT be for a later consumer to receive the tensor from one of the prior consumers. For example, when stage 0 sends a tensor to stages 1, 2, and 3, but stages 1-3 don't have much communication between them, then stage 2 can receive from stage 1, and stage 3 from stage 2.
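As a rough illustration of the congestion this avoids, here is a minimal, hypothetical accounting sketch (plain Python, not Alpa API) that counts how many tensors each mesh has to send under the fan-out plan versus the chained plan:

```python
# Count sends per mesh for two communication plans; each transfer of the
# tensor is assumed to move the same number of bytes.
def sends_per_mesh(plan):
    counts = {}
    for src, _dst in plan:
        counts[src] = counts.get(src, 0) + 1
    return counts

fan_out = [(0, 1), (0, 2), (0, 3)]   # stage 0 sends the tensor to stages 1, 2, and 3
chained = [(0, 1), (1, 2), (2, 3)]   # stage 1 relays to 2, stage 2 relays to 3

print(sends_per_mesh(fan_out))   # {0: 3}              -- all traffic leaves mesh 0
print(sends_per_mesh(chained))   # {0: 1, 1: 1, 2: 1}  -- same total, spread across meshes
```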
TODO
- [ ] Add a pass to optimize the case above. Given a pipeline schedule and mesh allocation, the pass decides which mesh a tensor should be received from.
- [ ] In `CrossMeshCommunicator`, consume that decision.
- [ ] In the runtime emitter, choose that decision instead of using the src mesh or the first mesh (using the first is even worse; it can be a bug).
- [ ] (Optional) To do the same thing for the backward semantics, we need to do more: if stage 0 outputs `t` to stages 1, 2, and 3, the backward will contain:
x = gradient of t at stage 3
x' = pipeline_end(x)
...
y = gradient of t at stage 2
y' = pipeline_end(y)
...
z = gradient of t at stage 1
z' = pipeline_end(z)
a = x' + y'
b = a + z'
grad_of_t = pipeline_start(b)
consume grad_of_t in layer 0's backward
We need to merge `a = x' + y'` into the backward of stage 2 and `b = a + z'` into the backward of stage 1 by modifying the code here.
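A minimal sketch of the intended rewrite, with purely illustrative names (this is not Alpa code): each backward stage adds its local gradient of `t` to the running sum received from the downstream stage, so a single accumulated tensor travels along the chain instead of three partial gradients converging on one mesh.

```python
x, y, z = 3.0, 2.0, 1.0   # placeholders for the per-stage gradients of t

def accumulate_grad(local_grad, running_sum=None):
    # running_sum is the accumulated gradient received from the downstream stage
    # (None for the last stage in the chain).
    return local_grad if running_sum is None else running_sum + local_grad

# Chain: stage 3 -> stage 2 -> stage 1 -> stage 0's backward
acc = accumulate_grad(x)        # stage 3 computes x (gradient of t at stage 3)
acc = accumulate_grad(y, acc)   # stage 2 computes a = x' + y'
acc = accumulate_grad(z, acc)   # stage 1 computes b = a + z'
grad_of_t = acc                 # consumed by stage 0's backward
```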
Hello, please assign this to me :)
To my understanding, Ray's object store is not used for activation/param transfer, and Ray is only used for task orchestration, correct?
Right. The compilation flow is: wrapped `Jaxpr`s of each pipeline stage --(by `CrossMeshCommunicator`)--> `SymbolicReshardingTask` --(by `PipelineInstEmitter`)--> `PipelineInstruction` (`SEND`/`RECV`/`BROADCAST`).
Each pipeline instruction is orchestrated by code in the `collective` folder, which finally calls NCCL. The congestion issue may be solved by only considering the compilation part.
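To make that flow concrete, here is a heavily simplified, hypothetical sketch of what the last step does for a single resharding task; the class names below are stand-ins, not the real Alpa classes, which carry much more state.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class ReshardingTaskSketch:        # stand-in for a SymbolicReshardingTask
    var_name: str
    src_mesh: int
    dst_mesh: int

@dataclass
class PipelineInstructionSketch:   # stand-in for a PipelineInstruction
    opcode: str                    # "SEND", "RECV", or "BROADCAST"
    var_name: str
    peer_mesh: int

def emit(task: ReshardingTaskSketch) -> List[PipelineInstructionSketch]:
    """Roughly what the emitter produces for one cross-mesh transfer:
    a SEND scheduled on the source mesh and a RECV on the destination mesh."""
    return [
        PipelineInstructionSketch("SEND", task.var_name, task.dst_mesh),
        PipelineInstructionSketch("RECV", task.var_name, task.src_mesh),
    ]
```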
For the first task, writing the pass, I will simply write a test to show that the desired transform is applied to the jaxpr.
As for scheduling, I guess the tensor should be queued instantly upon being received.
One complication: are all vars uniquely named in the HLO module, i.e., is it SSA?
For `Jaxpr`, each `Var` has its own id (an int) unless it's a `DropVar` (a placeholder); I'd expect most of the work to be at this level. For HLO, each var corresponds to a specific `HloInstruction`.
Hello, another question: are we guaranteed that the stages are sequentially dependent, meaning that we have a chain, not a DAG? It doesn't affect things too much, but presumably, for a DAG structure:
stage 0 -> stage 1
        -> stage 2
Where there is no functional dependence of stage 2 on stage 1, we should indeed broadcast to stage 1 and stage 2 from stage 0 to prevent any stalls.
However, perhaps we can ignore it for now.
It's mainly sequential, but there will be some skip connections (e.g. stage 0 -> stage 2, stage 0 -> stage 3, etc.); otherwise we wouldn't have this issue. Besides, there are both forward and backward stages, so stage 0 and stage -1 are on the same mesh.
> each Var has its own id

I presume this is unique, so I will use it as the var's UUID to act as a lookup key.
You can just use the `Var` itself; it wraps such an id.
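For example, a minimal JAX snippet (module and attribute layouts can differ across JAX versions) showing that the `Var` objects of a jaxpr are hashable and can be used directly as dictionary keys:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return x * y + x

jaxpr = jax.make_jaxpr(f)(jnp.ones(3), jnp.ones(3)).jaxpr

# Var objects are hashable, so they can serve directly as lookup keys,
# e.g. mapping each var to the equation (or pipeline stage) that produces it.
producer = {}
for i, eqn in enumerate(jaxpr.eqns):
    for v in eqn.outvars:
        producer[v] = i

print(producer)  # {Var: equation index, ...}
```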
Another question: given var `x` produced on stage 0 and consumed by stage 1, but not output by stage 1, do we need to add var `x` to the outvars of stage 1 so it can be consumed from stage 1 by a downstream stage 2?
Further, is there any cost to adding all invars to the outvars of every stage by default (apart from messiness)?
It depends on your algo. I think the first principle is to not increase the total comm size. E.g. if originally we send 0>2, I cannot see any advantage in making it 0>1>2. The case in the issue is that 0>1>2 is better than (0>1 and 0>2). In addition, if 2 sends `x` to 0 and 1 sends `y` to 0, but 0 only uses `x + y`, we can instead make 2 send `x` to 1 and 1 send `x + y` to 0.
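A toy illustration of that last case (plain Python, hypothetical helper; assuming `x`, `y`, and `x + y` are all the same size), counting how many tensors each mesh receives under the two plans:

```python
def recvs_per_mesh(plan):
    """plan: list of (src_mesh, dst_mesh) transfers."""
    counts = {}
    for _src, dst in plan:
        counts[dst] = counts.get(dst, 0) + 1
    return counts

original    = [(2, 0), (1, 0)]   # 2 sends x to 0, 1 sends y to 0, 0 computes x + y
pushed_down = [(2, 1), (1, 0)]   # 2 sends x to 1, 1 computes x + y and sends it to 0

print(recvs_per_mesh(original))      # {0: 2}        -- two tensors converge on mesh 0
print(recvs_per_mesh(pushed_down))   # {1: 1, 0: 1}  -- same total, less pressure on mesh 0
```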
Adding invars to outvars makes them live longer, and some invars are designed to donate their memory to corresponding outvars. Besides, the messiness itself might influence later passes so we'd hope to avoid it.
The algo is a simple one; it is "take from the last seen stage":
last_seen = {}  # var -> id of the last stage that saw it

# Sequentially walk the stages
for stage in stages:
    for (src, var) in cross_deps[stage.id]:
        # If we have already seen this var, add it to the outvars of the
        # stage that saw it last and fetch it from that stage's mesh.
        if var in last_seen:
            src_mesh = meshes[last_seen[var]]
            upsert(src_mesh.outvars, var)
        else:
            src_mesh = src
        last_seen[var] = stage.id
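A self-contained, runnable version of that sketch with a tiny test (the data structures here are hypothetical stand-ins, not Alpa's; the real pass would operate on the stages and meshes from the pipeline schedule). It rewrites the fan-out plan from the issue into the chained plan:

```python
from dataclasses import dataclass, field

@dataclass
class StageSketch:
    id: int
    # (producer_stage_id, var_name) pairs this stage needs from other stages
    cross_deps: list = field(default_factory=list)
    extra_outvars: set = field(default_factory=set)  # vars we force into outvars

def plan_receives(stages):
    """Return {(consumer_stage, var): source_stage} using 'take from last seen stage'."""
    last_seen = {}
    plan = {}
    for stage in stages:
        for (src, var) in stage.cross_deps:
            if var in last_seen:
                src_stage = last_seen[var]
                stages[src_stage].extra_outvars.add(var)  # relay stage must re-export it
            else:
                src_stage = src
            plan[(stage.id, var)] = src_stage
            last_seen[var] = stage.id
    return plan

# Stage 0 produces t; stages 1, 2, and 3 all consume it.
stages = [
    StageSketch(0),
    StageSketch(1, cross_deps=[(0, "t")]),
    StageSketch(2, cross_deps=[(0, "t")]),
    StageSketch(3, cross_deps=[(0, "t")]),
]
print(plan_receives(stages))
# {(1, 't'): 0, (2, 't'): 1, (3, 't'): 2}  -- the fan-out becomes a chain
```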
Is adding to outvars necessary? It seems that in our case we don't need to add to the outvars; we should be able to fetch from the invars instead?
invars -> [model] -> outvars ==cross-shard==>
|======================cross-shard=>
This would mean that the cross-shard invars can begin to be re-sharded prior to the model invocation.
However, I'm not sure whether outvars is merely logical and we can begin the same async transfer as soon as one of the outvars is ready, as marked by the runtime.
> Adding invars to outvars makes them live longer, and some invars are designed to donate their memory to corresponding outvars.
I will avoid this then.
The heuristic works for the first scenario. In the above (0>1 & 0>2) case, we don't need to add it to 1's outvars. You can read the `PipelineInstEmitter` for more details on how we actually launch send/recv and free. We've already done the async transfer with a CUDA event and some kernels injected to record the event.
Btw, as far as the producer goes, every var corresponds to e.g. a single tensor sharded across the entire submesh, correct?
Anyway, adding an invar to the outvars is non-trivial. One has to deal with donation, and we might also need to recompile the jaxpr. I'd prefer if the transfer took a different pathway rather than piggybacking on outvars. Any suggestions?
> In addition, if 2 sends `x` to 0 and 1 sends `y` to 0, but 0 only uses `x + y`, we can instead make 2 send `x` to 1 and 1 send `x + y` to 0.
This seems more complicated. For an SQL database person, it sounds like expression pushdown.
Sounds like we really do want to reconfigure the jaxpr after pipelining but before sharding. So at the pipelining pass, we should use our "last seen stage" heuristic to force the relevant invars to become outvars. I'm not sure whether the behaviour should be to skip an invar if it is donated before this pass.
> We've already done the async transfer with a CUDA event and some kernels injected to record the event.
However, I don't understand how the async queueing would occur in this case. Will every jaxpr outvar be evaluated and queued asynchronously and concurrently?
A var corresponds to a logical tensor including all its shards.
In the "pipeline pass", we only decide how the computational graph is divided into pipeline stages, but not the communication between pipeline stages.
Instead, in `PipelineInstEmitter`, we create a schedule for each device mesh, where we manage the execution of each communication, computation, and memory deallocation (tensors are allocated only with computation, and communication reuses those allocated tensors to receive). There, we store the live tensors of each mesh in `PipelineInstEmitterHelper` at each time tick.
So the way I think about the forward part is: we only modify code in `PipelineInstEmitter` to emit send/recv from the mesh with the least traffic among all meshes that have the tensor. For the backward part, things are more complicated; there might be something related to the "pipeline pass".
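A minimal, hypothetical sketch of that bookkeeping (the real `PipelineInstEmitterHelper` tracks far more state); it only shows how one could know, at a given tick, which meshes hold a tensor and thus which one is the least-loaded candidate sender:

```python
from collections import defaultdict

class LiveTensorTrackerSketch:
    """Toy stand-in: tracks which vars are live on which mesh."""

    def __init__(self):
        self.live = defaultdict(set)    # mesh_id -> set of live var names
        self.sends = defaultdict(int)   # mesh_id -> number of sends emitted so far

    def on_compute(self, mesh_id, produced_vars):
        self.live[mesh_id].update(produced_vars)

    def on_recv(self, mesh_id, var):
        self.live[mesh_id].add(var)     # var becomes live on the receiving mesh too

    def on_free(self, mesh_id, var):
        self.live[mesh_id].discard(var)

    def pick_sender(self, var):
        """Among meshes holding `var`, pick the one with the least traffic so far."""
        holders = [m for m, vs in self.live.items() if var in vs]
        src = min(holders, key=lambda m: self.sends[m])
        self.sends[src] += 1
        return src
```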
I'd suggest you read the following for more details: the architecture section in the project's doc and https://github.com/alpa-projects/alpa/blob/1ddb2dc30575b218a38937326682e020881dbc8e/alpa/pipeline_parallel/runtime_emitter.py#L545-L591.
For overlapping communication and computation, please refer to https://github.com/alpa-projects/alpa/pull/773 and https://github.com/alpa-projects/tensorflow-alpa/pull/127