papers-notebook
papers-notebook copied to clipboard
A Generic Communication Scheduler for Distributed DNN Training Acceleration
SOSP'19
https://i.cs.hku.hk/~cwu/papers/yhpeng-sosp19.pdf
来源:https://github.com/bytedance/byteps
+1
在后向计算同步梯度时,大部分框架的做法都是利用 barrier 在所有梯度都同步好之后再更新参数做下一轮前向。
但是,通过更加细粒度的同步和 Tensor 分片,这个过程其实可以优化成不需要 barrier 的,提高效率。
架构如下
这里的 CommTask,是 a communication task associated with one tensor,e.g., all-reduce。可以理解为是 Horovod 的 TensorQueue 中的一个 Message。比 Horovod 更高一级的地方在于,字节为 CommTask 设定了一个优先级。基于的观察是在反向计算的时候,梯度的计算是从后往前的,但是下一次迭代的前向是从前往后的。所以我们想优先把后向最后计算出的,首先会被同步。这样就不会 block 前向的计算。因此只要把最前面层的梯度同步,参数更新后,第一层就可以首先去做后向。
因此对于基于计算图的实现,制定优先级依赖的就是图的拓扑序。像 PyTorch 这种动态图的实现,就与 Tape 类似,根据 Tensor 的创建顺序。
Each time a communication tensor arrives, the plugin wraps it as a CommTask and assigns priority before enqueuing it. For declarative engines (e.g., TensorFlow), it uses topological sort to obtain the priority. For imperative engines, it assigns a monotonic increasing ID to each (gradient) tensor based on the order they are created (same as the BP order).
除此之外,一个 CommTask 还会被 partition,分成许多个小的 SubCommTask,这样可以以更细的粒度调度 Tensor 的通信。
看了一下开源的实现,跟论文只是共享了一些 insight,实现的方式跟论文中提到的方式有较大不同。而且没有 Auto-tune 和优先级的支持。
不过还是一个非常好的工作,国内难得
看了一下开源的实现,跟论文只是共享了一些 insight,实现的方式跟论文中提到的方式有较大不同。而且没有 Auto-tune 和优先级的支持。
不过还是一个非常好的工作,国内难得
优先级是有的 https://github.com/bytedance/byteps/blob/master/byteps/common/scheduled_queue.cc#L78-L98
谢谢指正,看到了。这里的实现和论文提到的 Tape 的那种实现类似,没有用计算图的信息,我就没找到对应的逻辑。。
优先级是根据 https://github.com/bytedance/byteps/blob/master/byteps/common/global.cc#L382 确定的
具体的实现见 https://github.com/bytedance/byteps/blob/master/byteps/tensorflow/ops.cc#L158
不过我这里有个问题,它这个地方的注释是啥意思,因为是只支持按照 Tensor 声明顺序确定优先级,所以不能支持计算图的拓扑序么?
https://github.com/bytedance/byteps/blob/948c774c30f520d8c9e36931f257da2eda386a48/byteps/tensorflow/ops.cc#L155
是这样。mxnet也是这么做的,直接按照params的声明顺序,其实对于DNN应该和拓扑顺序是一样的。
https://github.com/bytedance/byteps/blob/948c774c30/byteps/mxnet/init.py#L191-L193 (按照名称排序) https://github.com/bytedance/byteps/blob/948c774c30/byteps/mxnet/init.py#L213 (传入优先级)