oneflow
Add dropout1d/2d/3d api
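Resolves: https://github.com/Oneflow-Inc/libai/issues/342#issuecomment-1207565463
- [x] Register the nn.Dropout1d/2d/3d APIs.
- [x] Add documentation and API tests.
PyTorch implements Dropout1d/2d/3d by composing existing kernels (Bernoulli, broadcast multiplication, and Squeeze/Unsqueeze) rather than writing a dedicated fused kernel. As a first step, this PR follows PyTorch's approach and implements the three interfaces at the Functor layer. https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp#L14-L75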
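The composition described above (one Bernoulli draw per channel, broadcast across the remaining dimensions, then rescaling) can be sketched in plain Python. This is only an illustration of the idea, not the Functor code in this PR: the function name `channel_dropout_1d`, the nested-list `(N, C, L)` layout, and the `rng` parameter are assumptions made for the example.

```python
import random


def channel_dropout_1d(x, p=0.5, training=True, rng=None):
    """Channel-wise dropout on a nested list shaped (N, C, L).

    Hypothetical helper for illustration only. One Bernoulli(1 - p) value is
    drawn per (sample, channel) pair and broadcast over the length dimension;
    surviving channels are scaled by 1 / (1 - p).
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    if not training or p == 0.0:
        # Identity in eval mode: return a copy of the input.
        return [[list(ch) for ch in sample] for sample in x]
    if p == 1.0:
        # Everything is dropped; avoid dividing by zero in the scale factor.
        return [[[0.0] * len(ch) for ch in sample] for sample in x]

    rng = rng or random.Random()
    scale = 1.0 / (1.0 - p)
    out = []
    for sample in x:
        out_sample = []
        for ch in sample:
            # One draw per channel, reused for every element of that channel.
            keep = 1.0 if rng.random() >= p else 0.0
            out_sample.append([v * keep * scale for v in ch])
        out.append(out_sample)
    return out


x = [[[1.0, 2.0, 3.0] for _ in range(4)] for _ in range(2)]
y = channel_dropout_1d(x, p=0.5, rng=random.Random(0))
```

Each channel in `y` is either all zeros or the input channel multiplied by 2, which is the property that separates channel-wise dropout from element-wise dropout. The 2d and 3d variants would share the same per-channel mask and differ only in how many trailing dimensions it is broadcast over.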
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/8880/
Speed stats:
GPU Name: GeForce GTX 1080
✔️ OneFlow resnet50 time: 128.5ms (= 12847.1ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 142.2ms (= 14221.0ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.11 (= 142.2ms / 128.5ms)
OneFlow resnet50 time: 75.3ms (= 7531.2ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 84.0ms (= 8397.2ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.11 (= 84.0ms / 75.3ms)
OneFlow resnet50 time: 48.3ms (= 9668.6ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 61.6ms (= 12315.6ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.27 (= 61.6ms / 48.3ms)
OneFlow resnet50 time: 36.0ms (= 7203.9ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 45.1ms (= 9021.2ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.25 (= 45.1ms / 36.0ms)
OneFlow resnet50 time: 28.3ms (= 5666.9ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 45.0ms (= 9003.6ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.59 (= 45.0ms / 28.3ms)
OneFlow swin dataloader time: 0.275s (= 55.085s / 200, num_workers=1)
PyTorch swin dataloader time: 0.150s (= 30.017s / 200, num_workers=1)
Relative speed: 0.545 (= 0.150s / 0.275s)
OneFlow swin dataloader time: 0.070s (= 14.081s / 200, num_workers=4)
PyTorch swin dataloader time: 0.041s (= 8.136s / 200, num_workers=4)
Relative speed: 0.578 (= 0.041s / 0.070s)
OneFlow swin dataloader time: 0.040s (= 8.015s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.498s / 200, num_workers=8)
Relative speed: 0.561 (= 0.022s / 0.040s)
❌ OneFlow resnet50 time: 136.6ms (= 13658.3ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 167.8ms (= 16783.4ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.23 (= 167.8ms / 136.6ms)
OneFlow resnet50 time: 84.0ms (= 8403.1ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 102.5ms (= 10248.7ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.22 (= 102.5ms / 84.0ms)
OneFlow resnet50 time: 58.1ms (= 11627.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 78.4ms (= 15684.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.35 (= 78.4ms / 58.1ms)
OneFlow resnet50 time: 45.5ms (= 9099.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 77.8ms (= 15554.2ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.71 (= 77.8ms / 45.5ms)
OneFlow resnet50 time: 38.9ms (= 7788.0ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 67.2ms (= 13431.6ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.72 (= 67.2ms / 38.9ms)
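For readers checking the figures, the arithmetic in each group can be reproduced as follows. This is a small sketch inferred from the numbers printed above, not the benchmark script itself; the helper names are hypothetical. The per-iteration time is the total divided by the iteration count, and the relative speed is the PyTorch time divided by the OneFlow time, so values above 1 mean OneFlow was faster.

```python
def per_iteration(total, iterations):
    """Average time per iteration, in the same unit as `total`."""
    return total / iterations


def relative_speed(pytorch_time, oneflow_time):
    """Ratio reported in the stats: greater than 1 means OneFlow is faster."""
    return pytorch_time / oneflow_time


# First resnet50 group: input_shape=[16, 3, 224, 224], 100 iterations.
oneflow_ms = round(per_iteration(12847.1, 100), 1)
pytorch_ms = round(per_iteration(14221.0, 100), 1)
speed = round(relative_speed(pytorch_ms, oneflow_ms), 2)
print(oneflow_ms, pytorch_ms, speed)
```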