
tensorflow.push_pull CANNOT calculate AVERAGE, it always gives SUM

Open HugoZHL opened this issue 5 years ago • 3 comments

Describe the bug

In the TensorFlow push_pull function, op can never equal Average, because op is a member of an Enum class (ReduceOps) while Average is just a str. As a result, push_pull in TensorFlow always computes the SUM of the tensors instead of the AVERAGE.
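The root cause can be shown without TensorFlow or BytePS. This is a minimal sketch, and the ReduceOps and Average definitions in it are hypothetical stand-ins, not copies of the byteps source. It assumes the Average member's value is the same string as the module-level Average, which is what the op.value fix proposed under Expected behavior relies on. Members of a plain Enum compare by identity, so they never equal a str.

```python
from enum import Enum

# Hypothetical stand-ins for the byteps definitions (assumed, not copied).
Average = "Average"  # module-level constant: a plain str


class ReduceOps(Enum):  # a plain Enum subclass, like ReduceOps in tensorflow.ops
    Average = "Average"
    Sum = "Sum"


op = ReduceOps.Average
print(op == Average)        # False: a plain Enum member never equals a non-member value
print(op.value == Average)  # True: the member's underlying value is the str
```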

To Reproduce

Steps to reproduce the behavior:

  1. Add a log statement print('Type of op and Average: ', type(op), type(Average)) right after the line op = handle_average_backwards_compatibility(op, average) in the tensorflow.push_pull function.
  2. Run any program that uses tensorflow.push_pull, e.g., code that uses DistributedOptimizer.
  3. Check the types of op and Average, which are printed while the TensorFlow graph is being set up.
  4. The log shows: Type of op and Average: <enum 'ReduceOps'> <class 'str'>
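Steps 1 to 4 can be mimicked in plain Python. In this sketch, ReduceOps and handle_average_backwards_compatibility are hypothetical stand-ins (the real function's logic is not shown in this issue); the stand-in only needs to return a ReduceOps member. The print call is the one from step 1.

```python
from enum import Enum

Average = "Average"  # hypothetical module-level constant (a plain str)


class ReduceOps(Enum):  # hypothetical stand-in for the byteps ReduceOps
    Average = "Average"
    Sum = "Sum"


def handle_average_backwards_compatibility(op, average):
    """Hypothetical stand-in: map the legacy `average` flag onto a ReduceOps member."""
    if op is not None:
        return op
    return ReduceOps.Average if average else ReduceOps.Sum


op = handle_average_backwards_compatibility(None, average=True)
print('Type of op and Average: ', type(op), type(Average))
# Type of op and Average:  <enum 'ReduceOps'> <class 'str'>
```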

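Expected behavior

When op is ReduceOps.Average, push_pull should compute the average. However, comparing an Enum member with a str never yields True (no error is raised; the result is simply False). There are two ways to fix this: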

  1. Compare op.value with Average in the tensorflow.push_pull function.
  2. Change ReduceOps from an Enum class to a plain class.
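Both proposed fixes can be sketched as follows, again with hypothetical stand-in definitions. Solution 1 only works if ReduceOps.Average.value equals the Average string, which is assumed here; Solution 2 turns the class attributes into bare strings, so the existing op == Average check works unchanged. The class name PlainReduceOps is hypothetical and only used to show both options side by side.

```python
from enum import Enum

Average = "Average"  # hypothetical module-level constant


class ReduceOps(Enum):
    """Solution 1 keeps the Enum (member values are assumed)."""
    Average = "Average"
    Sum = "Sum"


def is_average_by_value(op):
    # Solution 1: compare the member's value with the str, not the member itself.
    return op.value == Average


class PlainReduceOps:
    """Solution 2: a normal class whose attributes are plain strs."""
    Average = "Average"
    Sum = "Sum"


def is_average_plain(op):
    # With plain str attributes, the original `op == Average` check works as-is.
    return op == Average


print(is_average_by_value(ReduceOps.Average))    # True
print(is_average_by_value(ReduceOps.Sum))        # False
print(is_average_plain(PlainReduceOps.Average))  # True
```

Solution 1 keeps the Enum type but requires every comparison site to remember .value; Solution 2 repairs all existing op == Average checks at once, at the cost of no longer having an Enum type for op.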

Screenshots

In tensorflow.ops, ReduceOps inherits from Enum: [screenshot]

In tensorflow.__init__, the code compares op (which is ReduceOps.Average) directly with Average (which is a str): [screenshot]

The function uses op == Average to decide whether to compute the average or the sum of the tensors. Because that comparison is always False, it can currently only compute the sum.
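The effect on the result can be sketched with plain Python lists standing in for tensors. This is hypothetical: it models push_pull as an element-wise sum across workers followed, for Average, by division by the number of workers. That is an assumption about the reduction, not the byteps implementation, and reduce_buggy / reduce_fixed are illustrative names.

```python
from enum import Enum

Average = "Average"  # hypothetical module-level constant


class ReduceOps(Enum):  # hypothetical stand-in
    Average = "Average"
    Sum = "Sum"


def elementwise_sum(tensors):
    # Sum corresponding elements across all workers' tensors.
    return [sum(values) for values in zip(*tensors)]


def reduce_buggy(tensors, op):
    summed = elementwise_sum(tensors)
    if op == Average:  # never True for an Enum member, so this branch is dead
        return [s / len(tensors) for s in summed]
    return summed


def reduce_fixed(tensors, op):
    summed = elementwise_sum(tensors)
    if op.value == Average:  # Solution 1: compare the underlying value
        return [s / len(tensors) for s in summed]
    return summed


workers = [[1.0, 2.0], [3.0, 4.0]]  # one "tensor" per worker
print(reduce_buggy(workers, ReduceOps.Average))  # [4.0, 6.0], the sum
print(reduce_fixed(workers, ReduceOps.Average))  # [2.0, 3.0], the average
```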

Environment (please complete the following information): any; the bug does not depend on the environment.

  • OS:
  • GCC version:
  • CUDA and NCCL version:
  • Framework (TF, PyTorch, MXNet):

HugoZHL, Nov 20 '20 16:11

@pleasantrabbit I think we can support Average easily

bobzhuyb, Nov 20 '20 18:11

Yes, we'll add Average support.

pleasantrabbit, Nov 20 '20 19:11

Here's the pull request that fixes this bug: https://github.com/bytedance/byteps/pull/324

HugoZHL, Nov 21 '20 13:11