Profiling item
Optimize the performance of scalar_tensor.item, bringing its latency down from 2-3x PyTorch's to about 70%-80% of PyTorch's.
Optimization techniques
Optimize the Python-level code of scalar_tensor.item
Originally, scalar_tensor.item was implemented in Python as scalar_tensor.cpu().numpy()[0]. Whether it is tensor.cpu, tensor.numpy, or numpy's getitem, each step is far too heavyweight for scalar_tensor.item, because the whole operation is trivial to describe: fetch one scalar from the tensor's memory. Clearly, a dedicated C++ API is the right tool to speed this up.
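As a rough illustration (the names below are hypothetical, not the actual OneFlow API), the dedicated path boils down to one small synchronous read into a stack variable, with no intermediate CPU tensor or numpy array:

// Hypothetical sketch of a dedicated C++ item() path (illustrative names only):
// read sizeof(T) bytes out of a one-element tensor into a stack variable,
// instead of going through cpu()/numpy()/getitem.
#include <cstddef>

template<typename T>
T ScalarTensorItem(const void* tensor_mem,
                   // hypothetical hook that synchronously copies `size` bytes
                   void (*sync_read)(void* dst, const void* src, std::size_t size)) {
  T scalar{};                                  // destination on the main-thread stack
  sync_read(&scalar, tensor_mem, sizeof(T));   // one small synchronous read
  return scalar;
}

In this PR the actual entry point is the C++ EagerLocalTensorItem / GetItemInScalarTensor path that can be seen in the traceback further down.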
Optimize the main-thread C++ code of scalar_tensor.item
The old C++ code of scalar_tensor.item ultimately issued an AccessBlobByCallback instruction. Building this instruction costs roughly 4-7 us, mostly spent constructing a series of objects, yet that construction cost can be cut to the nanosecond level by reusing instruction objects. The reuse here does not reuse the AccessBlobByCallback instruction directly; it reuses the newly designed SyncRead instruction. Reuse is sound because a thread submits at most one SyncRead instruction to the vm at any moment (the read is synchronous, so by the time the main thread handles the next SyncRead, the previous one must already have finished). The SyncRead instruction can be reused through a simple thread_local variable, guaranteeing that from the moment the main thread enters C++ from Python until the instruction is submitted to the vm, no heap objects are created.
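A minimal sketch of that reuse pattern, with illustrative names (the actual instruction type and its fields differ):

// Reuse one instruction object per thread instead of heap-allocating a new one
// for every synchronous read. This is safe because a thread has at most one
// SyncRead in flight at a time.
#include <cstddef>

struct SyncReadInstruction {
  char* dst = nullptr;
  const void* src = nullptr;
  std::size_t size = 0;
  // ... stream, ref-count, etc. omitted
};

SyncReadInstruction* MutThreadLocalSyncRead() {
  // Constructed once per thread; later calls only mutate its fields, so the
  // Python -> C++ -> vm submission path allocates nothing on the heap.
  static thread_local SyncReadInstruction instruction;
  return &instruction;
}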
Speed up the execution of SyncRead
SyncRead shares essentially the same logic as AccessBlobByCallback, but since it is dedicated to copying small amounts of data, we can stage the copy through a temporary pinned-memory buffer for cudaMemcpyAsync. Compared with the pageable memory used by AccessBlobByCallback, this shortens the cudaMemcpyAsync operation by about 5 us, because the implicit synchronization is avoided.
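A sketch of the staging idea, assuming a cudaStream_t for the tensor's device is at hand. Buffer management is simplified here, and the real SyncRead signals completion through the vm's worker thread rather than a stream sync on the caller:

// Copy a few bytes device -> host through a reusable pinned staging buffer.
#include <cuda_runtime.h>
#include <cstring>

void SyncReadSmall(void* host_dst, const void* device_src, size_t size, cudaStream_t stream) {
  static thread_local void* pinned = nullptr;   // small reusable pinned staging buffer
  static thread_local size_t pinned_size = 0;
  if (pinned_size < size) {
    if (pinned) cudaFreeHost(pinned);
    cudaMallocHost(&pinned, size);              // page-locked memory: the copy is truly async
    pinned_size = size;
  }
  cudaMemcpyAsync(pinned, device_src, size, cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);                // wait only for this small copy
  std::memcpy(host_dst, pinned, size);
}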
Speed up waking the main thread
The old AccessBlobByCallback used BlockingThenBusy to let the worker thread wake the main thread. Concretely, before the cudaMemcpyAsync the worker wakes the main thread through BlockingThenBusy.blocking_cnt, putting the main thread into a busy wait on BlockingThenBusy.spin_cnt; after the cudaMemcpyAsync the worker decrements BlockingThenBusy.spin_cnt, letting the main thread run the code that follows the spin wait. The blocking wait avoids burning CPU for nothing, while the busy wait shortens the main thread's response time.
In that flow, the worker thread calls the POSIX function pthread_cond_broadcast to wake the main thread, and pthread_cond_broadcast itself costs about 1-4 us. That cost can also be removed: if the main thread sees that very few instructions are pending on the vm (say, fewer than 3), it can trust that they will finish shortly and choose to skip the blocking wait altogether. The main thread then never goes to sleep, and the worker thread no longer spends 1-4 us on pthread_cond_broadcast. Notably, the original BlockingThenBusy could not let the worker thread skip pthread_cond_broadcast; this PR refactors its internals to make that possible.
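The sketch below captures the refactored idea with single-shot flags instead of the real counters: the notifier pays for the condition-variable broadcast only if the waiter actually parked, and the waiter may skip parking entirely when the vm queue is short. Names and structure are illustrative, not the actual BlockingThenBusy:

// "Blocking, then busy" wait with an optional fast path on both sides.
#include <atomic>
#include <condition_variable>
#include <mutex>

class BlockingThenBusySketch {
 public:
  // Waiter (main thread). blocking=false means: trust the vm to finish soon,
  // skip the park and go straight to the busy wait.
  void Wait(bool blocking) {
    if (blocking) {
      std::unique_lock<std::mutex> lock(mu_);
      parked_ = true;
      cv_.wait(lock, [this] { return blocking_done_; });
      parked_ = false;
    }
    while (!spin_done_.load(std::memory_order_acquire)) { /* busy wait */ }
  }

  // Notifier (worker thread), called before the copy starts.
  void NotifyBlockingDone() {
    std::lock_guard<std::mutex> lock(mu_);
    blocking_done_ = true;
    if (parked_) cv_.notify_all();   // pay the pthread_cond_broadcast only if someone sleeps
  }

  // Notifier, called after the copy finishes.
  void NotifySpinDone() { spin_done_.store(true, std::memory_order_release); }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool parked_ = false;
  bool blocking_done_ = false;
  std::atomic<bool> spin_done_{false};
};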
Speed up the worker thread's response
The old worker thread waited for instructions with a blocking wait, and after executing an instruction went back to a blocking wait for the next iteration. In this mode, handing an instruction from the scheduler thread to the worker thread costs roughly 2-7 us of thread-switch overhead, which hurts eager code with weak pipelining badly. We can make the worker thread less eager to go to sleep: after finishing each instruction it stays online (via std::this_thread::yield) for about 200 us. That window is long enough for the main thread to prepare the following eager ops and push them through the scheduler thread; the worker, still online, picks up the work immediately, saving the 2-7 us of thread switching.
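A self-contained sketch of such a spin-then-block worker loop (illustrative, not the actual vm worker):

// After draining its queue, the worker yields for ~200 us before falling back
// to a blocking wait, so an instruction that arrives shortly afterwards avoids
// the wake-up cost of a sleeping thread.
#include <chrono>
#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>

class Worker {
 public:
  void Push(std::function<void()> task) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back(std::move(task));
    }
    cv_.notify_one();  // matters only when the worker has fallen back to the blocking wait
  }

  void Loop() {
    using Clock = std::chrono::steady_clock;
    while (true) {
      std::function<void()> task;
      auto spin_until = Clock::now() + std::chrono::microseconds(200);
      for (;;) {
        {
          std::lock_guard<std::mutex> lock(mu_);
          if (!queue_.empty()) {
            task = std::move(queue_.front());
            queue_.pop_front();
            break;
          }
        }
        if (Clock::now() < spin_until) {
          std::this_thread::yield();      // stay online for ~200 us
        } else {
          std::unique_lock<std::mutex> lock(mu_);
          cv_.wait(lock, [this] { return !queue_.empty(); });  // then block
        }
      }
      task();  // run the instruction, then start the next ~200 us online window
    }
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<std::function<void()>> queue_;
};

The trade-off is up to ~200 us of extra busy-waiting per idle period in exchange for removing the 2-7 us wake-up latency on every instruction.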
Luyang Zhao 21 hours ago
1. profiler_off build: https://oneflow-static.oss-cn-beijing.aliyuncs.com/disco_diffusion/nvidia-profile/resize_profile/s[…]ofile_off_flow_221107_master%40252ccea.nsys-rep https://oneflow-static.oss-cn-beijing.aliyuncs.com/disco_diffusion/nvidia-profile/resize_profile/s[…]f_flow_221107_profiling_item%4021763eb.nsys-rep
2. profiler_on build: https://oneflow-static.oss-cn-beijing.aliyuncs.com/disco_diffusion/nvidia-profile/resize_profile/s[…]rofile_on_flow_221107_master%40252ccea.nsys-rep https://oneflow-static.oss-cn-beijing.aliyuncs.com/disco_diffusion/nvidia-profile/resize_profile/s[…]n_flow_221107_profiling_item%4021763eb.nsys-rep
Luyang Zhao 20 hours ago At night, when the machine was relatively stable, I ran the full disco test several more times: profiling_item 3min28s, 3min33s, 3min32s; master 3min52s, 3min51s, 3min51s. There is a clear speedup @lixinqi0703106, roughly 20 s faster.
Xinqi Li 18 hours ago That is already quite good: 20 s gained from a single PR. @jinhui
Xinqi Li 18 hours ago For any further speedup, the only option left is to optimize the main-thread path of the commonly used APIs one by one, e.g. +, -, *, /.
Xinqi Li 18 hours ago Look at two consecutive tensor.item calls.
Xinqi Li 18 hours ago The profiling_item branch: [screenshot: image.png]
Xinqi Li 18 hours ago The master branch: [screenshot: image.png]
Xinqi Li 18 hours ago With two consecutive tensor.item calls, the gap between them is the total time tensor.item spends going main thread -> scheduler thread -> worker thread. The profiling_item branch is 4.6x faster here than master.
Xinqi Li 18 hours ago Also, the other ops are packed more densely now, because the worker thread busy-waits for 200 us before blocking even when it has nothing to do, and 200 us is enough for the main thread to prepare the next op.
CI failed when running job: cuda-module. PR label automerge has been removed
Static analysis with clang failed. PR label automerge has been removed
Testing disco on this PR at commit https://github.com/Oneflow-Inc/oneflow/pull/9394/commits/b4b43ebfe5fc98b82ad658a332ce9cabf291c22f, there is a small (<10%) chance that the program hangs. After adding export ONEFLOW_TIMEOUT_SECONDS=60, the hang produces the following error:
Steps: 42%|███████████████████████████████████████ | 102/240 [02:22<03:12, 1.40s/it]
Seed used: 2954719760
Traceback (most recent call last):
File "disco.py", line 2617, in <module>
do_run()
File "disco.py", line 1239, in do_run
for j, sample in enumerate(samples):
File "/home/zhaoluyang/Oneflow/disco-diffusion/guided-diffusion/guided_diffusion/gaussian_diffusion.py", line 932, in ddim_sample_loop_progressive
out = sample_fn(
File "/home/zhaoluyang/Oneflow/disco-diffusion/guided-diffusion/guided_diffusion/gaussian_diffusion.py", line 688, in ddim_sample
out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
File "/home/zhaoluyang/Oneflow/disco-diffusion/guided-diffusion/guided_diffusion/respace.py", line 102, in condition_score
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
File "/home/zhaoluyang/Oneflow/disco-diffusion/guided-diffusion/guided_diffusion/gaussian_diffusion.py", line 400, in condition_score
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
File "/home/zhaoluyang/Oneflow/disco-diffusion/guided-diffusion/guided_diffusion/respace.py", line 128, in __call__
return self.model(x, new_ts, **kwargs)
File "disco.py", line 1155, in cond_fn
clip_in = normalize(cuts(x_in.add(1).div(2)))
File "/home/zhaoluyang/Oneflow/oneflow/python/oneflow/nn/module.py", line 158, in __call__
res = self.forward(*args, **kwargs)
File "disco.py", line 853, in forward
cutout = resize(pad_input, out_shape=output_shape)
File "/home/zhaoluyang/Oneflow/disco-diffusion/ResizeRight/resize_right.py", line 100, in resize
pad_sz, projected_grid, field_of_view = calc_pad_sz(in_sz, out_sz,
File "/home/zhaoluyang/Oneflow/disco-diffusion/ResizeRight/resize_right.py", line 158, in calc_pad_sz
pad_sz = [-field_of_view[0, 0].item(),
RuntimeError: (60 vs 60)
File "/home/zhaoluyang/Oneflow/oneflow/oneflow/api/python/utils/tensor_utils.h", line 133, in EagerLocalTensorItem
GetItemInScalarTensor<T>(tensor)
File "/home/zhaoluyang/Oneflow/oneflow/oneflow/core/framework/tensor_util.h", line 54, in GetItemInScalarTensor
GetItemInScalarTensor(scalar_tensor, &scalar, sizeof(T))
File "/home/zhaoluyang/Oneflow/oneflow/oneflow/core/framework/tensor_util.cpp", line 102, in GetItemInScalarTensor
SyncReadSmallMem(reinterpret_cast<char*>(scalar_ptr), size, local_tensor)
File "/home/zhaoluyang/Oneflow/oneflow/oneflow/core/framework/instructions_builder.cpp", line 817, in SyncAccessSmallMem
MutThreadLocalInstruction<InstructionPolicyT>(stream)
File "/home/zhaoluyang/Oneflow/oneflow/oneflow/core/framework/instructions_builder.cpp", line 777, in WaitRefCntToOne
Error Type: oneflow.ErrorProto.timeout_error
Addressed by f819d05ada1a413706f08cdd2234f19ee037b091.
Speed stats:
GPU Name: GeForce GTX 1080
❌ OneFlow resnet50 time: 140.3ms (= 14025.8ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 166.9ms (= 16691.8ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.19 (= 166.9ms / 140.3ms)
OneFlow resnet50 time: 86.1ms (= 8611.3ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 101.0ms (= 10098.2ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.17 (= 101.0ms / 86.1ms)
OneFlow resnet50 time: 58.0ms (= 11596.4ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 86.1ms (= 17211.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.48 (= 86.1ms / 58.0ms)
OneFlow resnet50 time: 44.1ms (= 8812.5ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 70.3ms (= 14054.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.59 (= 70.3ms / 44.1ms)
OneFlow resnet50 time: 41.7ms (= 8338.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 66.7ms (= 13332.5ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.60 (= 66.7ms / 41.7ms)
CI failed when running job: cuda-speed-test. PR label automerge has been removed
CI failed when running job: cuda-module. PR label automerge has been removed
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.
Speed stats:
GPU Name: GeForce GTX 1080
❌ OneFlow resnet50 time: 140.8ms (= 14075.5ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 164.5ms (= 16453.6ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.17 (= 164.5ms / 140.8ms)
OneFlow resnet50 time: 85.6ms (= 8561.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 101.8ms (= 10184.4ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.19 (= 101.8ms / 85.6ms)
OneFlow resnet50 time: 57.8ms (= 11562.1ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 88.4ms (= 17688.9ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.53 (= 88.4ms / 57.8ms)
OneFlow resnet50 time: 44.6ms (= 8915.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 80.0ms (= 16004.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.80 (= 80.0ms / 44.6ms)
OneFlow resnet50 time: 40.5ms (= 8108.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 76.9ms (= 15376.5ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.90 (= 76.9ms / 40.5ms)
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9394/
Speed stats:
GPU Name: GeForce GTX 1080
❌ OneFlow resnet50 time: 141.0ms (= 14095.5ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 164.4ms (= 16444.4ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.17 (= 164.4ms / 141.0ms)
OneFlow resnet50 time: 86.1ms (= 8610.4ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 109.6ms (= 10956.6ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.27 (= 109.6ms / 86.1ms)
OneFlow resnet50 time: 58.2ms (= 11635.2ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 78.6ms (= 15712.1ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.35 (= 78.6ms / 58.2ms)
OneFlow resnet50 time: 45.2ms (= 9047.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 70.7ms (= 14143.1ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.56 (= 70.7ms / 45.2ms)
OneFlow resnet50 time: 40.5ms (= 8106.2ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 76.6ms (= 15323.5ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.89 (= 76.6ms / 40.5ms)
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9394/