
Control flow support

Open HeydrichBeillschmidt opened this issue 3 years ago • 4 comments

[WIP] support tuple-shaped parameters for while instruction

HeydrichBeillschmidt avatar Jul 08 '22 10:07 HeydrichBeillschmidt

Hi @HeydrichBeillschmidt, when I merge your changes into my fork and try to call run_auto_sharding_pass on a simple MNIST model, I get this error:

  File "/workspaces/alpa/alpa/shard_parallel/auto_sharding.py", line 355, in run_auto_sharding_pass
    xe.run_auto_sharding(hlo_module, compile_options)
IndexError: absl::container_internal::raw_hash_map<>::at

The source of the error is in the CreateStrategyVector code: apparently a select operation has not been added to the strategy_map, so the lookup fails while iterating through the operands of the dot.268 instruction. Below is some HLO from an intermediate stage of compilation, after the spmd_simplify pipeline and before the spmd_pipeline that runs the auto-sharding pass:

  broadcast.6 = f32[2048,1600]{1,0} broadcast(constant.171), dimensions={}
  select = f32[2048,1600]{1,0} select(compare.183, reshape.29, broadcast.6), metadata={op_type="Mul" op_name="mnist/sequential/dropout/dropout/Mul_1" source_file="/home/vscode/.local/lib/python3.8/site-packages/keras/backend.py" source_line=1940}
  arg34.35 = f32[1600,10]{1,0} parameter(34), parameter_replication={false}, metadata={op_name="XLA_Args"}
  dot.268 = f32[2048,10]{1,0} dot(select, arg34.35), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="MatMul" op_name="mnist/sequential/dense/MatMul" source_file="/home/vscode/.local/lib/python3.8/site-packages/keras/layers/core/dense.py" source_line=221}

And finally, here is some logging output I've generated that shows the sequence of events leading up to this failed indexing into the strategy map:

HandleDot[0]: dot.268
CreateLeafStrategyVector: dot.268
Potential Failing operand instruction: %select = f32[2048,1600]{1,0} select(pred[2048,1600]{1,0} %compare.183, f32[2048,1600]{1,0} %reshape.29, f32[2048,1600]{1,0} %broadcast.6), metadata={op_type="Mul" op_name="mnist/sequential/dropout/dropout/Mul_1" source_file="/home/vscode/.local/lib/python3.8/site-packages/keras/backend.py" source_line=1940}
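
For what it's worth, the IndexError here is not a Python indexing bug: absl::flat_hash_map::at() throws std::out_of_range with exactly the message "absl::container_internal::raw_hash_map<>::at" when the key is missing, and the Python bindings (presumably via pybind11's default translation of std::out_of_range to IndexError) surface it the way we see in the trace. A standalone repro of the error signature, not alpa code:

  // Standalone repro of the error signature above (not alpa code).
  #include <iostream>
  #include <stdexcept>
  #include <string>

  #include "absl/container/flat_hash_map.h"

  int main() {
    absl::flat_hash_map<std::string, int> strategy_map;
    strategy_map["dot.268"] = 0;
    // "select" was never inserted, mirroring the missing operand entry.
    try {
      std::cout << strategy_map.at("select") << "\n";
    } catch (const std::out_of_range& e) {
      // Prints: lookup failed: absl::container_internal::raw_hash_map<>::at
      std::cerr << "lookup failed: " << e.what() << "\n";
    }
  }

Guarding the lookup with find() instead of at() would at least make the pass fail with a message naming the offending instruction.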

Do you have any idea what could be the problem?

tdietert avatar Aug 25 '22 22:08 tdietert

@HeydrichBeillschmidt I've solved this problem by undoing the part of the diff where you build an instruction sequence from the entry_computation->instructions() list. You passed this entry_sequence value to BuildStrategyAndCost instead of the sequence value constructed from the hlo_live_range, but entry_sequence doesn't actually contain all the instructions that the pass needs: https://github.com/alpa-projects/tensorflow-alpa/pull/124/files#diff-83aa23c5123bde398bcd2002e8bf5d5bdf79341e11f461715a127f9547357a13R2806

Is there a reason you did this? Passing sequence (built from the hlo_live_range value, as in the master branch) to BuildStrategyAndCost instead of entry_sequence solved my issue.
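
For reference, here is a sketch of the two sequences in question, assuming module and hlo_live_range are in scope as in the pass; these are the XLA calls I believe the code uses, but treat the details as approximate rather than a verbatim excerpt of the diff:

  // (a) What the PR passed: the entry computation's own instruction
  // list, in whatever order the computation happens to store it, and
  // without anything from nested computations.
  std::vector<HloInstruction*> entry_sequence;
  for (HloInstruction* inst : module->entry_computation()->instructions()) {
    entry_sequence.push_back(inst);
  }

  // (b) What master passes: the flattened, schedule-ordered sequence
  // from the live-range analysis, which contains every instruction the
  // pass later looks up in strategy_map, defs before uses.
  const HloInstructionSequence& sequence =
      hlo_live_range->flattened_instruction_sequence();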

tdietert avatar Aug 26 '22 22:08 tdietert

Hi @tdietert, thank you for reporting this. BuildStrategyAndCost is designed as a recursive structure, and entry_sequence is passed to avoid repeatedly constructing strategies for instructions in sub-computations such as the while body. However, simply setting entry_sequence = entry_computation->instructions() was incorrect. The problem is addressed in the latest commit.
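
To make the intent concrete, here is a minimal sketch of that recursive shape; the names StrategyVector and CreateStrategyVector are hypothetical stand-ins for the real helpers, and this is not the actual diff:

  // Minimal sketch of the recursive design (hypothetical helper names).
  #include <memory>
  #include <vector>

  #include "absl/container/flat_hash_map.h"
  #include "tensorflow/compiler/xla/service/hlo_instruction.h"

  struct StrategyVector;  // per-instruction sharding strategies (placeholder)
  using StrategyMap = absl::flat_hash_map<const xla::HloInstruction*,
                                          std::unique_ptr<StrategyVector>>;

  // Hypothetical leaf constructor (cf. CreateLeafStrategyVector above).
  std::unique_ptr<StrategyVector> CreateStrategyVector(
      const xla::HloInstruction* inst, const StrategyMap& strategy_map);

  void BuildStrategyAndCost(const std::vector<xla::HloInstruction*>& sequence,
                            StrategyMap& strategy_map) {
    for (xla::HloInstruction* inst : sequence) {
      if (strategy_map.contains(inst)) {
        continue;  // already built while recursing into a sub-computation
      }
      if (inst->opcode() == xla::HloOpcode::kWhile) {
        // Recurse first, so tuple-shaped while parameters and the body's
        // instructions get strategies before the while itself is handled.
        BuildStrategyAndCost(inst->while_body()->MakeInstructionPostOrder(),
                             strategy_map);
        BuildStrategyAndCost(
            inst->while_condition()->MakeInstructionPostOrder(), strategy_map);
      }
      strategy_map[inst] = CreateStrategyVector(inst, strategy_map);
    }
  }

The skip at the top of the loop is what makes the recursion safe, but whatever sequence is passed in must still contain, defs before uses, every instruction the solver will later look up, which is why the schedule-ordered sequence from hlo_live_range is the safe source.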

HeydrichBeillschmidt avatar Aug 27 '22 05:08 HeydrichBeillschmidt

@HeydrichBeillschmidt Thanks for your response! We have tried your latest changes and they work well for us, thank you. We have not validated that the while loops are parallelized correctly, but we no longer run into any of the issues we saw before.

tdietert avatar Sep 02 '22 17:09 tdietert