
[FEATURE] Reduce congestion of sending on one mesh

Open · ZYHowell opened this issue 2 years ago • 16 comments

This can be a starting point to learn runtime_emitter and cross_mesh_resharding.

Background

In Pipeshard Parallel, when a tensor needs to be received from another mesh, we always choose the mesh that generates it, as happens here. However, when the tensor is consumed by multiple pipeline stages, a better solution MIGHT be for a later consumer to receive the tensor from one of the earlier consumers. For example, when stage 0 sends a tensor to stages 1, 2, and 3, but stages 1-3 don't otherwise communicate much, stage 2 can receive it from stage 1, and stage 3 from stage 2.
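A minimal sketch of the two routing options, assuming one pipeline stage per mesh and equal-size transfers (the function names are hypothetical, not Alpa APIs). Both plans move the same number of tensors, but the chained plan spreads the sends across meshes instead of making mesh 0 push all three.

```python
from collections import Counter


def direct_plan(producer, consumers):
    """Every consumer receives the tensor straight from the producing stage."""
    return [(producer, c) for c in consumers]


def chain_plan(producer, consumers):
    """Each consumer receives from the previous one (stage 0 -> 1 -> 2 -> 3)."""
    plan, src = [], producer
    for c in consumers:
        plan.append((src, c))
        src = c
    return plan


def sends_per_stage(plan):
    """Count how many transfers each stage has to push out."""
    return Counter(src for src, _ in plan)


direct = direct_plan(0, [1, 2, 3])
chain = chain_plan(0, [1, 2, 3])
print(len(direct), dict(sends_per_stage(direct)))  # 3 {0: 3}
print(len(chain), dict(sends_per_stage(chain)))    # 3 {0: 1, 1: 1, 2: 1}
```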

TODO

  • [ ] Add a pass to optimize the case above. Given a pipeline schedule and mesh allocation, the pass decides which mesh a tensor should be received from.
  • [ ] In CrossMeshCommunicator, consume that decision.
  • [ ] In the runtime emitter, choose that decision instead of using the src or the first one (using the first one is even worse and may be a bug).
  • [ ] (Optional) To do the same for the backward pass, we need to do more: if stage 0 outputs t to stages 1, 2, and 3, the backward pass will contain:
x = gradient of t at stage 3
x' = pipeline_end(x)
...
y = gradient of t at stage 2
y' = pipeline_end(y)
...
z = gradient of t at stage 1
z' = pipeline_end(z)
a = x' + y'
b = a + z'
grad_of_t = pipeline_start(b)
consume grad_of_t in layer 0's backward

We need to merge a = x' + y' into the backward of stage 2 and b = a + z' into the backward of stage 1, by modifying the code here.
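To illustrate the optional backward rewrite, here is a toy sketch (hypothetical helpers, with plain integers standing in for gradient tensors). It compares the current pattern, where every partial gradient of t lands on stage 0's mesh, with the merged one, where stages 2 and 1 accumulate and forward a single partial sum. The result is identical, and stage 0's mesh receives one tensor instead of three.

```python
def centralized(grads):
    """Current: each consumer stage sends its partial gradient of t to stage 0's mesh,
    which sums them. grads maps consumer stage id -> partial gradient."""
    order = sorted(grads, reverse=True)          # backward runs stage 3, then 2, then 1
    transfers = [(stage, 0) for stage in order]  # every partial gradient goes to stage 0
    total = 0
    for stage in order:
        total = total + grads[stage]
    return total, transfers


def chained(grads):
    """Merged: each backward stage adds the incoming partial sum to its own gradient
    (a = x' + y' on stage 2, b = a + z' on stage 1) and forwards only the sum."""
    order = sorted(grads, reverse=True)
    acc, transfers = None, []
    for i, stage in enumerate(order):
        acc = grads[stage] if acc is None else acc + grads[stage]
        nxt = order[i + 1] if i + 1 < len(order) else 0  # last hop goes to stage 0
        transfers.append((stage, nxt))
    return acc, transfers


grads = {3: 4, 2: 2, 1: 1}  # x, y, z as plain numbers standing in for tensors
for plan in (centralized, chained):
    value, transfers = plan(grads)
    print(plan.__name__, value, sum(dst == 0 for _, dst in transfers))
# centralized 7 3
# chained 7 1
```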

ZYHowell, Dec 06 '22 06:12

Hello, please assign this to me :)

To my understanding, Ray's object store is not used for activation/param transfer, and Ray is only used for task orchestration, correct?

jon-chuang, Apr 04 '23 15:04

Right. The compilation flow is: wrapped Jaxprs of each pipeline stage --by CrossMeshCommunicator--> SymbolicReshardingTask --by PipelineInstEmitter--> PipelineInstruction (SEND/RECV/BROADCAST). Each pipeline instruction is orchestrated by code in the collective folder, which ultimately calls NCCL. The congestion issue may be solved by considering only the compilation part.
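A toy model of the last lowering step, assuming each resharding task reduces to a (var, src_mesh, dst_mesh) triple. `ReshardTask` and `emit` are hypothetical stand-ins, not Alpa's SymbolicReshardingTask or PipelineInstEmitter. The sketch shows why the fix can stay in compilation: congestion is already visible as the number of SEND instructions per mesh, and it is decided entirely by the src_mesh chosen for each task.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class ReshardTask:
    """Hypothetical stand-in: one logical tensor moving between two meshes."""
    var: str
    src_mesh: int
    dst_mesh: int


def emit(tasks):
    """Lower each task into a SEND on its source mesh and a RECV on its destination."""
    insts = defaultdict(list)
    for t in tasks:
        insts[t.src_mesh].append(("SEND", t.var, t.dst_mesh))
        insts[t.dst_mesh].append(("RECV", t.var, t.src_mesh))
    return dict(insts)


fan_out = emit([ReshardTask("t", 0, d) for d in (1, 2, 3)])
relay = emit([ReshardTask("t", s, s + 1) for s in (0, 1, 2)])
print(len(fan_out[0]), len(relay[0]))  # 3 1
```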

ZYHowell, Apr 04 '23 16:04

For the first task, writing the pass, I will simply write a test showing that the desired transform is applied to the jaxpr.

As for scheduling, I guess the tensor should be queued instantly upon being received.

jon-chuang, Apr 04 '23 16:04

One complication: are all vars uniquely named in the HLO module, i.e., is it in SSA form?

jon-chuang, Apr 04 '23 17:04

In a Jaxpr, each Var has its own id (an int) unless it's a DropVar (a placeholder); I'd expect most of the work to be at this level. In HLO, each var corresponds to a specific HloInstruction.

ZYHowell, Apr 04 '23 19:04

Hello, another question: are the stages guaranteed to be sequentially dependent, i.e., do we have a chain rather than a DAG? It doesn't matter too much, but presumably, for a DAG structure such as:

stage 0 -> stage 1
        -> stage 2

where stage 2 has no functional dependence on stage 1, we should indeed broadcast from stage 0 to both stage 1 and stage 2 to prevent any stalls.

However, perhaps we can ignore it for now.
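One way to handle the DAG case is sketched below, assuming the stage dependency graph is available as an adjacency dict (a hypothetical input; the helpers are illustrative, not Alpa code). The idea is to relay from an earlier consumer only when the current consumer already depends on it, so the relay cannot add a stall. Otherwise, it falls back to the producer.

```python
def depends_on(edges, later, earlier):
    """True if stage `later` (transitively) depends on stage `earlier` in the stage DAG.
    edges maps a stage id to the stage ids that consume its outputs."""
    stack, seen = [earlier], set()
    while stack:
        node = stack.pop()
        if node == later:
            return True
        if node in seen:
            continue
        seen.add(node)
        stack.extend(edges.get(node, ()))
    return False


def pick_source(edges, producer, earlier_consumers, consumer):
    """Relay from the most recent earlier consumer that `consumer` already waits on;
    otherwise receive straight from the producer so independent branches don't stall."""
    for prev in reversed(earlier_consumers):
        if depends_on(edges, consumer, prev):
            return prev
    return producer


chain = {0: [1], 1: [2]}
dag = {0: [1, 2]}  # stage 2 does not depend on stage 1
print(pick_source(chain, 0, [1], 2), pick_source(dag, 0, [1], 2))  # 1 0
```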

jon-chuang, Apr 05 '23 00:04

It's mainly sequential, but there will be some skip connections (e.g., stage 0 -> stage 2, stage 0 -> stage 3, etc.); otherwise we wouldn't have this issue. Besides, there are both forward and backward stages, so stage 0 and stage -1 are on the same mesh.

ZYHowell, Apr 05 '23 00:04

each Var has its own id

I presume this is unique, so I will use it as the var's UUID, i.e., as a lookup key.

jon-chuang, Apr 05 '23 00:04

You can just use the var itself; it wraps such an id.

ZYHowell, Apr 05 '23 00:04

Another question:

Given a var x produced on stage 0 and consumed by stage 1, but not output by stage 1, do we now need to add x to stage 1's outvars so that a downstream stage 2 can consume it from stage 1?

Further, is there any cost to adding all invars to the outvars of every stage by default (besides messiness)?

jon-chuang, Apr 05 '23 01:04

It depends on your algorithm. I think the first principle is to not increase the total communication size. E.g., if originally we send 0>2, I don't see any advantage in making it 0>1>2. The case in the issue is that 0>1>2 is better than (0>1 and 0>2). In addition, if 2 sends x to 0 and 1 sends y to 0, but 0 only uses x + y, we can instead have 2 send x to 1 and 1 send x + y to 0.
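The principle above can be written as an acceptance test for a rewrite. This sketch assumes each transfer is a (src, dst, size) triple and uses the busiest mesh's send or receive volume as a stand-in for congestion. Both the metric and the helper names are assumptions, not something Alpa defines. It also assumes non-empty plans and that x, y, and x + y have the same size, as with an elementwise add.

```python
from collections import Counter


def cost(plan):
    """plan: list of (src_mesh, dst_mesh, size). Returns (total volume, busiest mesh's load)."""
    sent, recv = Counter(), Counter()
    for src, dst, size in plan:
        sent[src] += size
        recv[dst] += size
    total = sum(size for _, _, size in plan)
    peak = max(max(sent.values()), max(recv.values()))
    return total, peak


def better(old, new):
    """Accept a rewrite only if total communication does not grow and the busiest mesh gets lighter."""
    (t0, p0), (t1, p1) = cost(old), cost(new)
    return t1 <= t0 and p1 < p0


print(better([(0, 2, 1)], [(0, 1, 1), (1, 2, 1)]))            # False: 0>2 becomes 0>1>2
print(better([(0, 1, 1), (0, 2, 1)], [(0, 1, 1), (1, 2, 1)]))  # True: fan-out becomes a chain
print(better([(2, 0, 1), (1, 0, 1)], [(2, 1, 1), (1, 0, 1)]))  # True: x, y -> 0 becomes x -> 1, x+y -> 0
```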

Adding invars to outvars makes them live longer, and some invars are designed to donate their memory to corresponding outvars. Besides, the messiness itself might influence later passes, so we'd like to avoid it.

ZYHowell, Apr 05 '23 01:04

The algorithm is a simple one: fetch each var from the last stage that has seen it:

last_seen = {}  # var -> id of the latest stage that has received it
# Sequentially walk the stages
for stage in stages:
  for (src, var) in cross_deps[stage.id]:
    # If var is a dep, check whether an earlier stage has already received it.
    # If so, add it to that stage's outvars and fetch from the latest such stage.
    if var in last_seen:
      src_mesh = meshes[last_seen[var]]
      upsert(src_mesh.outvars, var)
    else:
      # First consumer: receive directly from the producer.
      src_mesh = src
    last_seen[var] = stage.id
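A self-contained version of the last-seen heuristic, runnable on toy data. `Stage`, the `cross_deps` layout, and returning the decisions as a dict are assumptions for illustration. The outvars bookkeeping mirrors the `upsert` step in the loop, which the replies in this thread suggest avoiding.

```python
from dataclasses import dataclass, field


@dataclass
class Stage:
    id: int
    outvars: list = field(default_factory=list)


def choose_sources(stages, cross_deps):
    """cross_deps[stage_id] lists the (producer_stage_id, var) pairs a stage receives.
    Returns {(stage_id, var): stage id to receive from}."""
    by_id = {s.id: s for s in stages}
    last_seen, decisions = {}, {}
    for stage in stages:  # walk stages in schedule order
        for producer, var in cross_deps.get(stage.id, []):
            if var in last_seen:
                # An earlier stage already holds var: relay from the most recent one,
                # which then has to expose var as an outvar.
                src = last_seen[var]
                if var not in by_id[src].outvars:
                    by_id[src].outvars.append(var)
            else:
                src = producer  # first consumer reads from the producer
            decisions[(stage.id, var)] = src
            last_seen[var] = stage.id
    return decisions


stages = [Stage(i) for i in range(4)]
deps = {1: [(0, "t")], 2: [(0, "t")], 3: [(0, "t")]}
print(choose_sources(stages, deps))
# {(1, 't'): 0, (2, 't'): 1, (3, 't'): 2}
print([s.outvars for s in stages])
# [[], ['t'], ['t'], []]
```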

Is adding to outvars necessary? It seems that in our case we don't need to add to outvars; shouldn't we be able to fetch from the invars?

invars -> [model] -> outvars ==cross-shard==>
 |======================cross-shard=>

This would mean that the cross-shard invars can begin to be re-sharded prior to the model invocation.

However, I'm not sure whether outvars are merely logical, i.e., whether we can start the same async transfer as soon as any one of the outvars is ready, as marked by the runtime.

Adding invars to outvars makes them live longer, and some invars are designed to donate their memory to corresponding outvars.

I will avoid this then.

jon-chuang, Apr 05 '23 01:04

The heuristic works for the first scenario. In the (0>1 & 0>2) case above, we don't need to add it to stage 1's outvars. You can read PipelineInstEmitter for more details on how we actually launch send/recv and free. We already do the transfer asynchronously, using CUDA events and some injected kernels that record the events.

ZYHowell, Apr 05 '23 01:04

Btw, on the producer side, does every var correspond to, e.g., a single tensor sharded across the entire submesh?

Anyway, adding an invar to the outvars is non-trivial: one has to deal with donation and might also need to recompile the jaxpr. I'd prefer the transfer to take a different pathway rather than piggybacking on outvars. Any suggestions?

jon-chuang, Apr 05 '23 05:04

In addition, if 2 sends x to 0 and 1 sends y to 0, but 0 only uses x + y, we can make it be 2 sends x to 1 and 1 sends x+y to 0.

This seems more complicated. To a SQL database person, it sounds like expression pushdown.

It sounds like we really do want to rewrite the jaxpr after pipelining but before sharding. So in the pipelining pass, we should use the last-seen-stage heuristic to force the relevant invars to become outvars. I'm not sure whether we should skip an invar that was donated before this pass.

We've already done the async transfer with cuda event and some kernel injected to record the event.

However, I don't understand how the async queueing would happen in this case. Will every jaxpr outvar be evaluated and queued asynchronously and concurrently?

jon-chuang, Apr 05 '23 05:04

A var corresponds to a logical tensor including all its shards.

In the "pipeline pass", we only decide how the computational graph is divided into pipeline stages, but not the communication between pipeline stages.

Instead, in PipelineInstEmitter, we create a schedule for each device mesh, which manages the execution of every communication, computation, and memory deallocation (tensors are allocated only by computation, and communication reuses those allocated tensors as receive buffers). There, we store the live tensors of each mesh in PipelineInstEmitterHelper at each time tick.

So here is how I think about it. For the forward part, we only need to modify the code in PipelineInstEmitter to emit send/recv from the mesh with the least traffic among all meshes holding the tensor. The backward part is more complicated; it may involve changes to the "pipeline pass".
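A sketch of that forward-side policy, assuming the emitter can see which meshes hold each var at the current tick and how many bytes each mesh has already been scheduled to send. `live`, `traffic`, and the helper names are hypothetical, not PipelineInstEmitterHelper's API. It also assumes each receive completes before the next request is considered, which the real emitter would have to check per time tick.

```python
def pick_least_loaded(var, live, traffic):
    """live: mesh id -> set of vars alive on that mesh at the current tick.
    traffic: mesh id -> bytes already scheduled to send from that mesh.
    Returns the mesh holding var with the least outgoing traffic (ties -> lowest id)."""
    holders = [m for m, vars_ in live.items() if var in vars_]
    if not holders:
        raise KeyError(f"{var!r} is not live on any mesh")
    return min(holders, key=lambda m: (traffic.get(m, 0), m))


def schedule_recvs(requests, live, traffic):
    """requests: list of (dst_mesh, var, nbytes) in emission order.
    After each receive the destination also holds var, so later requests may fetch from it."""
    plan = []
    for dst, var, nbytes in requests:
        src = pick_least_loaded(var, live, traffic)
        plan.append((src, dst, var))
        traffic[src] = traffic.get(src, 0) + nbytes
        live.setdefault(dst, set()).add(var)
    return plan


live = {0: {"t"}}
plan = schedule_recvs([(1, "t", 4), (2, "t", 4), (3, "t", 4)], live, {})
print(plan)  # [(0, 1, 't'), (1, 2, 't'), (2, 3, 't')]
```

With uniform request sizes, the least-traffic rule reproduces the chain from the issue description; with a heavily loaded relay mesh, it would fall back to another holder.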

I'd suggest you read the following for more details: the architecture section in the project's doc and https://github.com/alpa-projects/alpa/blob/1ddb2dc30575b218a38937326682e020881dbc8e/alpa/pipeline_parallel/runtime_emitter.py#L545-L591.

For overlapping communication and computation, please refer to https://github.com/alpa-projects/alpa/pull/773 and https://github.com/alpa-projects/tensorflow-alpa/pull/127

ZYHowell, Apr 05 '23 19:04