oneflow

Add dropout1d/2d/3d api

Open • BBuf opened this issue 1 year ago • 1 comment

Resolves: https://github.com/Oneflow-Inc/libai/issues/342#issuecomment-1207565463

  • [x] Register the nn.Dropout1d/2d/3d APIs.
  • [x] Add documentation and API tests.

BBuf, Aug 08 '22 15:08

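For Dropout1d/2d/3d, PyTorch does not use a dedicated fused kernel; it composes each one from Bernoulli, broadcast-multiply, and Squeeze/Unsqueeze kernels. This PR follows PyTorch's approach for now and implements the three interfaces at the Functor layer. https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp#L14-L75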
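To illustrate the kernel composition described above, here is a minimal NumPy sketch of channel-wise ("feature") dropout, assuming inputs shaped (N, C, *spatial) or unbatched (C, *spatial). It is not OneFlow's Functor code or PyTorch's ATen code: the function names are hypothetical, NumPy's random generator stands in for the Bernoulli kernel, and edge-case handling in the real APIs may differ. One keep/drop draw is made per (sample, channel), the mask is broadcast over the spatial positions, and a leading axis is added and then removed for unbatched input to mirror the Unsqueeze/Squeeze step.

```python
import numpy as np


def feature_dropout(x, p, training=True, rng=None):
    """Hypothetical sketch: zero whole channels of an (N, C, *spatial) array."""
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    if not training or p == 0.0:
        return x  # eval mode or p == 0: identity
    if p == 1.0:
        return np.zeros_like(x)  # every channel dropped
    rng = np.random.default_rng() if rng is None else rng
    # One draw per (sample, channel); trailing size-1 dims broadcast over space.
    noise_shape = x.shape[:2] + (1,) * (x.ndim - 2)
    keep = rng.random(noise_shape) >= p  # True with probability 1 - p (Bernoulli stand-in)
    scale = keep.astype(x.dtype) / (1.0 - p)  # inverted-dropout rescaling
    return x * scale  # broadcast multiply


def _feature_dropout_nd(x, p, training, rng, spatial_dims):
    """Accept batched (N, C, *spatial) or unbatched (C, *spatial) input."""
    batched_ndim = spatial_dims + 2
    if x.ndim not in (batched_ndim - 1, batched_ndim):
        raise ValueError(
            f"expected a {batched_ndim - 1}D or {batched_ndim}D input, got {x.ndim}D"
        )
    batched = x.ndim == batched_ndim
    if not batched:
        x = x[np.newaxis]  # Unsqueeze(0)
    out = feature_dropout(x, p, training, rng)
    return out if batched else out[0]  # Squeeze(0)


def dropout1d(x, p=0.5, training=True, rng=None):
    return _feature_dropout_nd(x, p, training, rng, spatial_dims=1)


def dropout2d(x, p=0.5, training=True, rng=None):
    return _feature_dropout_nd(x, p, training, rng, spatial_dims=2)


def dropout3d(x, p=0.5, training=True, rng=None):
    return _feature_dropout_nd(x, p, training, rng, spatial_dims=3)
```

Kept channels are scaled by 1 / (1 - p) so the expected activation is unchanged, and with `training=False` the input passes through untouched. Because the mask has one entry per channel rather than per element, a dropped channel is zero at every spatial position.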

BBuf, Aug 10 '22 01:08

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/8880/

github-actions[bot], Aug 19 '22 09:08

Speed stats:
GPU Name: GeForce GTX 1080 

✔️ OneFlow resnet50 time: 128.5ms (= 12847.1ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 142.2ms (= 14221.0ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.11 (= 142.2ms / 128.5ms)

OneFlow resnet50 time: 75.3ms (= 7531.2ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 84.0ms (= 8397.2ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.11 (= 84.0ms / 75.3ms)

OneFlow resnet50 time: 48.3ms (= 9668.6ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 61.6ms (= 12315.6ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.27 (= 61.6ms / 48.3ms)

OneFlow resnet50 time: 36.0ms (= 7203.9ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 45.1ms (= 9021.2ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.25 (= 45.1ms / 36.0ms)

OneFlow resnet50 time: 28.3ms (= 5666.9ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 45.0ms (= 9003.6ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.59 (= 45.0ms / 28.3ms)

OneFlow swin dataloader time: 0.275s (= 55.085s / 200, num_workers=1)
PyTorch swin dataloader time: 0.150s (= 30.017s / 200, num_workers=1)
Relative speed: 0.545 (= 0.150s / 0.275s)

OneFlow swin dataloader time: 0.070s (= 14.081s / 200, num_workers=4)
PyTorch swin dataloader time: 0.041s (= 8.136s / 200, num_workers=4)
Relative speed: 0.578 (= 0.041s / 0.070s)

OneFlow swin dataloader time: 0.040s (= 8.015s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.498s / 200, num_workers=8)
Relative speed: 0.561 (= 0.022s / 0.040s)

❌ OneFlow resnet50 time: 136.6ms (= 13658.3ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 167.8ms (= 16783.4ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.23 (= 167.8ms / 136.6ms)

OneFlow resnet50 time: 84.0ms (= 8403.1ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 102.5ms (= 10248.7ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.22 (= 102.5ms / 84.0ms)

OneFlow resnet50 time: 58.1ms (= 11627.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 78.4ms (= 15684.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.35 (= 78.4ms / 58.1ms)

OneFlow resnet50 time: 45.5ms (= 9099.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 77.8ms (= 15554.2ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.71 (= 77.8ms / 45.5ms)

OneFlow resnet50 time: 38.9ms (= 7788.0ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 67.2ms (= 13431.6ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.72 (= 67.2ms / 38.9ms)

github-actions[bot], Aug 19 '22 09:08