tensorflow-alpa
Control flow support
[WIP] support tuple-shaped parameters for while instruction
Hi @HeydrichBeillschmidt, when I merge your changes into my fork and try to call run_auto_sharding_pass on a simple MNIST model, I get this error:
File "/workspaces/alpa/alpa/shard_parallel/auto_sharding.py", line 355, in run_auto_sharding_pass
xe.run_auto_sharding(hlo_module, compile_options)
IndexError: absl::container_internal::raw_hash_map<>::at
The source of the error is the CreateLeafStrategyVector code: a select operation has apparently not been added to the strategy_map, which causes the failure when iterating through the operands of the dot.268 instruction. Below is some HLO from an intermediate stage of compilation, after the spmd_simplify pipeline and before the spmd_pipeline that runs the auto-sharding pass:
broadcast.6 = f32[2048,1600]{1,0} broadcast(constant.171), dimensions={}
select = f32[2048,1600]{1,0} select(compare.183, reshape.29, broadcast.6), metadata={op_type="Mul" op_name="mnist/sequential/dropout/dropout/Mul_1" source_file="/home/vscode/.local/lib/python3.8/site-packages/keras/backend.py" source_line=1940}
arg34.35 = f32[1600,10]{1,0} parameter(34), parameter_replication={false}, metadata={op_name="XLA_Args"}
dot.268 = f32[2048,10]{1,0} dot(select, arg34.35), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="MatMul" op_name="mnist/sequential/dense/MatMul" source_file="/home/vscode/.local/lib/python3.8/site-packages/keras/layers/core/dense.py" source_line=221}
And finally, here is some logging output I've generated that shows the sequence of events leading up to this failed indexing into the strategy map:
HandleDot[0]: dot.268
CreateLeafStrategyVector: dot.268
Potential Failing operand instruction: %select = f32[2048,1600]{1,0} select(pred[2048,1600]{1,0} %compare.183, f32[2048,1600]{1,0} %reshape.29, f32[2048,1600]{1,0} %broadcast.6), metadata={op_type="Mul" op_name="mnist/sequential/dropout/dropout/Mul_1" source_file="/home/vscode/.local/lib/python3.8/site-packages/keras/backend.py" source_line=1940}
Do you have any idea what could be the problem?
@HeydrichBeillschmidt I've solved this problem by undoing the part of the diff where you build an instruction sequence from the entry_computation->instructions() list. You passed this entry_sequence value to BuildStrategyAndCost, instead of the sequence value constructed from the hlo_live_range, but entry_sequence doesn't actually contain all the instructions in the computation: https://github.com/alpa-projects/tensorflow-alpa/pull/124/files#diff-83aa23c5123bde398bcd2002e8bf5d5bdf79341e11f461715a127f9547357a13R2806
Is there a reason you did this? Replacing entry_sequence with sequence (from the hlo_live_range value, like in the master branch) passed to BuildStrategyAndCost solved my issue.
Hi @tdietert, thank you for reporting this issue. BuildStrategyAndCost is designed as a recursive structure, and entry_sequence is passed here to avoid repeatedly constructing strategies for instructions in sub-computations such as the while body. However, simply setting entry_sequence = entry_computation->instructions() was incorrect. The problem is addressed in the latest commit.
@HeydrichBeillschmidt Thanks for your response! We have tried your latest changes and they work well for us, thank you. We have not validated that the while loops are parallelized "correctly", but we no longer see any of the issues we reported before.