[Feature Request]: Compile nn.Graph in parallel to speed up the whole inference pipeline
Background and motivation
Currently, OneFlow compiles a graph with only a single thread, which slows down graph construction. In some cases (for example, when tuning parameters such as BATCH_SIZE to avoid running out of GPU memory), single-threaded compilation is very slow: it takes 5 minutes to compile a deep LSTM net. Is it possible to enable parallel compilation?
API Proposal
class oneflow.nn.Graph
def set_compile_threads(n: int = 1):
    '''
    Set the number of threads used while compiling the graph.
    Parameters:
        n (int, optional) - The default value is 1. Set to -1 to use all available threads, or to 0 to use all physical cores; otherwise OneFlow uses `n` threads to compile this graph.
    '''
API Usage
# suppose `graph` is an oneflow.nn.Graph
# call this before the graph is compiled
graph.set_compile_threads()
# the graph is compiled when it is called for the first time
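To make the intent of `n` concrete, here is a sketch of how its special values might be resolved; `resolve_compile_threads` is a hypothetical helper written only for illustration, and `set_compile_threads` itself does not exist yet:

import os

# Hypothetical helper mirroring the proposed meaning of `n`; nothing here
# exists in OneFlow today. os.cpu_count() reports logical CPUs, so the
# "all physical cores" case (n == 0) is only approximated by halving it.
def resolve_compile_threads(n: int = 1) -> int:
    if n == -1:
        return os.cpu_count() or 1
    if n == 0:
        return max(1, (os.cpu_count() or 2) // 2)
    return n

# graph.set_compile_threads(resolve_compile_threads(-1))  # proposed API, not yet implemented
# loss = graph(...)  # compilation would run on the first call, using the requested threads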
Alternatives
Cache compilation information to speed up future compilation (currently only the state_dict can be saved).
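For context, a minimal sketch of what can be persisted today, assuming a plain nn.Module named `model`; only the parameters can be saved, not the compiled plan:

import oneflow as flow

# Only the weights can be cached across processes today, not the compiled plan.
flow.save(model.state_dict(), "model_params")      # persist parameters
model.load_state_dict(flow.load("model_params"))   # restore parameters in a new process
# The nn.Graph wrapping `model` still has to be recompiled from scratch,
# which is exactly the cost that caching compile information would remove.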
Risks
I have no idea whether multi-threaded compilation would generate sub-optimal results, i.e. whether it would degrade the runtime efficiency of the compiled graph.
it takes 5 minutes to compile a deep LSTM net
Is this a single device task or multi-device task?
The compilation time cost of each compilation stage can be shown as follows:
- Enable this environment variable:
  export ONEFLOW_DEBUG_MODE=1
- Run the graph
- After the first call of the graph, compilation has finished and the compilation time costs are logged to
  log/local_rank_0/machine-xxx/oneflow.INFO
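For example, the variable can also be set from Python (a sketch; I assume it has to be set before oneflow is imported, and the machine-xxx part of the log path depends on the host):

import os
os.environ["ONEFLOW_DEBUG_MODE"] = "1"  # assumption: set before importing oneflow

import oneflow as flow
# ...build the nn.Graph and call it once as usual; after that first call the
# per-stage compile costs appear in log/local_rank_0/machine-*/oneflow.INFO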
Please send the oneflow.INFO log here, so we can identify where the time is spent.
is it possible to enable parallel compiling?
If this is a multi-device task, there is a parallel compilation mode that has been tested and will be merged into master soon: https://github.com/Oneflow-Inc/oneflow/pull/9920
Is this a single device task or multi-device task?
It is a single device task.
I put 12 LSTMs in the net, using eval/exec to generate the full network:
def __init__(...):
    ...
    self.layer = 6
    for i in range(self.layer):
        exec(f"self.{'layer'+str(i)+'_i'}=nn.LSTM({self.io},{self.h},bias=False)")
        exec(f"self.{'layer'+str(i)+'_o'}=nn.LSTM({self.h},{self.io},bias=True)")
    ...

def forward(...):
    ...
    for i in range(self.layer):
        (lstm_in, lstm_out) = eval(f"""(self.{'layer'+str(i)+'_i'},self.{'layer'+str(i)+'_o'})""")
        inner, lst[i*2] = lstm_in(L1, lst[i*2])
        output, lst[i*2+1] = lstm_out(inner, lst[i*2+1])
        L1 = L1 + output
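For reference, the same construction can also be written without exec/eval, e.g. with setattr/getattr; this is only a sketch, with placeholder sizes io=16 and h=32 standing in for self.io and self.h:

import oneflow as flow
from oneflow import nn

class StackedLSTM(nn.Module):
    """Sketch: the same residual LSTM stack registered without exec/eval."""

    def __init__(self, io=16, h=32, layers=6):
        super().__init__()
        self.io, self.h, self.layer = io, h, layers
        for i in range(self.layer):
            # setattr registers each LSTM as a proper submodule, like the exec version
            setattr(self, f"layer{i}_i", nn.LSTM(self.io, self.h, bias=False))
            setattr(self, f"layer{i}_o", nn.LSTM(self.h, self.io, bias=True))

    def forward(self, L1, lst):
        for i in range(self.layer):
            lstm_in = getattr(self, f"layer{i}_i")
            lstm_out = getattr(self, f"layer{i}_o")
            inner, lst[i * 2] = lstm_in(L1, lst[i * 2])
            output, lst[i * 2 + 1] = lstm_out(inner, lst[i * 2 + 1])
            L1 = L1 + output
        return L1, lst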
It took 06:55 to finish the first epoch, and only 03:18 to finish the second epoch.
The log may be sent later, since I am training the net and I don't want to interrupt the training procedure.
The problem might be the sequence length (input.shape[0] for nn.LSTM(batch_first=False)). When I change the sequence length from 150 to 300, it takes about 20 minutes to finish compiling the graph, with unexpectedly high memory usage (67.4 GB; the same structure with seq_length=150 can be trained directly on an NVIDIA RTX 3060).
The full log is below (to avoid interrupting the GPU training thread, I used the CPU to execute the graph):
Log file created at: 2023/04/25 00:39:33
Running on machine: 3060
Running duration (h:mm:ss): 0:00:00
Log line format: [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg
I20230425 00:39:33.407505 20443 env_global_objects_scope.cpp:165] Using rpc backend: local
I20230425 00:39:33.599052 20443 epoll_comm_network.cpp:63] CommNet:Epoll listening on 0.0.0.0:42027
I20230425 00:39:34.416039 20443 version.cpp:22] OneFlow git version: 99b71264
I20230425 00:39:34.416098 20443 cuda_device_manager_factory.cpp:63] CUDA runtime version: 11.8
I20230425 00:39:34.416129 20443 cuda_device_manager_factory.cpp:72] cuDNN version: 8.9.0
I20230425 00:39:34.416134 20443 cuda_device_manager_factory.cpp:85] NCCL version: 2.15.1
I20230425 00:48:47.266355 20443 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 OptimizationLogicalGraph","mem_rss":"11621.000000 MB","time_cost":"433 seconds"}
I20230425 00:49:50.931952 20443 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 AlignStates","mem_rss":"11623.000000 MB","time_cost":"47 seconds"}
I20230425 00:54:51.325140 20443 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 CompleteJob","mem_rss":"11606.000000 MB","time_cost":"300 seconds"}
I20230425 00:56:40.212065 20443 plan_util.cpp:1132]
Graph name Graph_0 in Rank: -1, Device: -1 needs to allocate [ 0 MiB ] device memory.
In general, Chunk id: -1 memory is [ -1e-06 MiB ] with mem_block_num = 0
Unreused memory not eager var is [ 0 MiB ] with mem_block_num = 0
Eager Variable Tensor total memory is [ 0 MiB ] with mem_block_num = 0
I20230425 00:56:44.909133 20443 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 CompilePlan","mem_rss":"17494.000000 MB","time_cost":"98 seconds"}
I20230425 00:56:44.910989 20443 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 SyncPlan","mem_rss":"17494.000000 MB","time_cost":"0 seconds"}
Furthermore, a >200 MiB forward_graph0 file is generated.
Part of it is shown below (... is used to omit some duplicates):
module_name2module_conf {
key: "model.layer5_i"
value {
name: "model.layer5_i"
ops: "model.layer5_i.weight_ih_l0"
ops: "model.layer5_i.weight_hh_l0"
ops: "model.layer5_i-narrow-62321"
ops: "model.layer5_i-squeeze-62322"
ops: "model.layer5_i-narrow-62323"
ops: "model.layer5_i-squeeze-62324"
ops: "model.layer5_i-narrow-62325"
ops: "model.layer5_i-narrow-62326"
ops: "model.layer5_i-narrow-62327"
ops: "model.layer5_i-narrow-62328"
ops: "model.layer5_i-narrow-62329"
ops: "model.layer5_i-narrow-62330"
ops: "model.layer5_i-narrow-62331"
ops: "model.layer5_i-narrow-62332"
...
ops: "model.layer5_i-narrow-62621"
ops: "model.layer5_i-narrow-62622"
ops: "model.layer5_i-narrow-62623"
ops: "model.layer5_i-squeeze-62624"
ops: "model.layer5_i-squeeze-62625"
ops: "model.layer5_i-squeeze-62626"
ops: "model.layer5_i-squeeze-62627"
...
ops: "model.layer5_i-squeeze-62918"
ops: "model.layer5_i-squeeze-62919"
ops: "model.layer5_i-squeeze-62920"
ops: "model.layer5_i-squeeze-62921"
ops: "model.layer5_i-squeeze-62922"
ops: "model.layer5_i-matmul-62923"
ops: "model.layer5_i-matmul-62924"
ops: "model.layer5_i-add_n-62925"
ops: "model.layer5_i-narrow-62926"
ops: "model.layer5_i-narrow-62927"
ops: "model.layer5_i-narrow-62928"
ops: "model.layer5_i-narrow-62929"
ops: "model.layer5_i-sigmoid-62930"
ops: "model.layer5_i-sigmoid-62931"
ops: "model.layer5_i-tanh-62932"
ops: "model.layer5_i-sigmoid-62933"
ops: "model.layer5_i-broadcast_mul-62934"
ops: "model.layer5_i-broadcast_mul-62935"
ops: "model.layer5_i-add_n-62936"
ops: "model.layer5_i-tanh-62937"
ops: "model.layer5_i-broadcast_mul-62938"
ops: "model.layer5_i-matmul-62939"
ops: "model.layer5_i-matmul-62940"
ops: "model.layer5_i-add_n-62941"
ops: "model.layer5_i-narrow-62942"
ops: "model.layer5_i-narrow-62943"
ops: "model.layer5_i-narrow-62944"
ops: "model.layer5_i-narrow-62945"
ops: "model.layer5_i-sigmoid-62946"
ops: "model.layer5_i-sigmoid-62947"
ops: "model.layer5_i-tanh-62948"
ops: "model.layer5_i-sigmoid-62949"
ops: "model.layer5_i-broadcast_mul-62950"
ops: "model.layer5_i-broadcast_mul-62951"
ops: "model.layer5_i-add_n-62952"
ops: "model.layer5_i-tanh-62953"
ops: "model.layer5_i-broadcast_mul-62954"
...
ops: "model.layer5_i-matmul-67691"
ops: "model.layer5_i-matmul-67692"
ops: "model.layer5_i-add_n-67693"
ops: "model.layer5_i-narrow-67694"
ops: "model.layer5_i-narrow-67695"
ops: "model.layer5_i-narrow-67696"
ops: "model.layer5_i-narrow-67697"
ops: "model.layer5_i-sigmoid-67698"
ops: "model.layer5_i-sigmoid-67699"
ops: "model.layer5_i-tanh-67700"
ops: "model.layer5_i-sigmoid-67701"
ops: "model.layer5_i-broadcast_mul-67702"
ops: "model.layer5_i-broadcast_mul-67703"
ops: "model.layer5_i-add_n-67704"
ops: "model.layer5_i-tanh-67705"
ops: "model.layer5_i-broadcast_mul-67706"
ops: "model.layer5_i-stack-67707"
ops: "model.layer5_i-stack-67708"
ops: "model.layer5_i-stack-67709"
ops: "model.layer5_i-cat-67710"
ops: "model.layer5_i-expand_dims-67711"
ops: "model.layer5_i-expand_dims-67712"
ops: "model.layer5_i-split_like-83631"
ops: "model.layer5_i-stack_grad-83632"
ops: "model.layer5_i-stack_grad-83633"
ops: "model.layer5_i-stack_grad-83634"
ops: "model.layer5_i-broadcast_mul-83635"
ops: "model.layer5_i-broadcast_mul-83636"
ops: "model.layer5_i-sigmoid_grad-83637"
ops: "model.layer5_i-tanh_grad-83638"
ops: "model.layer5_i-narrow_grad-83639"
ops: "model.layer5_i-broadcast_mul-83640"
ops: "model.layer5_i-broadcast_mul-83641"
ops: "model.layer5_i-broadcast_mul-83642"
ops: "model.layer5_i-broadcast_mul-83643"
ops: "model.layer5_i-sigmoid_grad-83644"
ops: "model.layer5_i-sigmoid_grad-83645"
ops: "model.layer5_i-tanh_grad-83646"
ops: "model.layer5_i-narrow_grad-83647"
ops: "model.layer5_i-add_n-83648"
ops: "model.layer5_i-narrow_grad-83649"
ops: "model.layer5_i-add_n-83650"
ops: "model.layer5_i-narrow_grad-83651"
ops: "model.layer5_i-add_n-83652"
ops: "model.layer5_i-matmul-83653"
ops: "model.layer5_i-matmul-83654"
ops: "model.layer5_i-add_n-83655"
ops: "model.layer5_i-matmul-83656"
ops: "model.layer5_i-matmul-83657"
ops: "model.layer5_i-broadcast_mul-83658"
ops: "model.layer5_i-broadcast_mul-83659"
ops: "model.layer5_i-reshape_like-83660"
ops: "model.layer5_i-sigmoid_grad-83661"
ops: "model.layer5_i-tanh_grad-83662"
ops: "model.layer5_i-add_n-83663"
ops: "model.layer5_i-narrow_grad-83664"
ops: "model.layer5_i-add_n-83665"
ops: "model.layer5_i-narrow_grad-83666"
ops: "model.layer5_i-broadcast_mul-83667"
ops: "model.layer5_i-broadcast_mul-83668"
ops: "model.layer5_i-broadcast_mul-83669"
ops: "model.layer5_i-broadcast_mul-83670"
ops: "model.layer5_i-sigmoid_grad-83671"
ops: "model.layer5_i-sigmoid_grad-83672"
ops: "model.layer5_i-tanh_grad-83673"
ops: "model.layer5_i-narrow_grad-83674"
ops: "model.layer5_i-add_n-83675"
ops: "model.layer5_i-narrow_grad-83676"
ops: "model.layer5_i-add_n-83677"
ops: "model.layer5_i-narrow_grad-83678"
ops: "model.layer5_i-add_n-83679"
ops: "model.layer5_i-matmul-83680"
ops: "model.layer5_i-matmul-83681"
ops: "model.layer5_i-add_n-83682"
ops: "model.layer5_i-add_n-83683"
ops: "model.layer5_i-matmul-83684"
ops: "model.layer5_i-matmul-83685"
ops: "model.layer5_i-add_n-83686"
ops: "model.layer5_i-broadcast_mul-83687"
ops: "model.layer5_i-broadcast_mul-83688"
ops: "model.layer5_i-reshape_like-83689"
ops: "model.layer5_i-sigmoid_grad-83690"
ops: "model.layer5_i-tanh_grad-83691"
ops: "model.layer5_i-add_n-83692"
ops: "model.layer5_i-narrow_grad-83693"
ops: "model.layer5_i-add_n-83694"
ops: "model.layer5_i-narrow_grad-83695"
ops: "model.layer5_i-broadcast_mul-83696"
ops: "model.layer5_i-broadcast_mul-83697"
ops: "model.layer5_i-broadcast_mul-83698"
ops: "model.layer5_i-broadcast_mul-83699"
ops: "model.layer5_i-sigmoid_grad-83700"
ops: "model.layer5_i-sigmoid_grad-83701"
ops: "model.layer5_i-tanh_grad-83702"
ops: "model.layer5_i-narrow_grad-83703"
ops: "model.layer5_i-add_n-83704"
ops: "model.layer5_i-narrow_grad-83705"
ops: "model.layer5_i-add_n-83706"
ops: "model.layer5_i-narrow_grad-83707"
ops: "model.layer5_i-add_n-83708"
ops: "model.layer5_i-matmul-83709"
ops: "model.layer5_i-matmul-83710"
ops: "model.layer5_i-add_n-83711"
ops: "model.layer5_i-add_n-83712"
ops: "model.layer5_i-matmul-83713"
ops: "model.layer5_i-matmul-83714"
ops: "model.layer5_i-add_n-83715"
...
ops: "model.layer5_i-matmul-92295"
ops: "model.layer5_i-add_n-92296"
ops: "model.layer5_i-reshape_like-92297"
ops: "model.layer5_i-narrow_grad-92298"
ops: "model.layer5_i-add_n-92299"
}
}
minimal test script:
input = 16
output = 8
seq_len = 150  # compiling time is O(seq_len)
batch_size = 100

import oneflow as flow
from oneflow import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm1 = nn.LSTM(input, output, bias=False)

    def initial(self, batch_size=1):
        return [flow.zeros((1, batch_size, output)), flow.zeros((1, batch_size, output))]

    def forward(self, x, status):
        return self.lstm1(x, status)

class Graph(nn.Graph):
    def __init__(self, model, optimizer):
        super().__init__()
        self.model = model
        self.add_optimizer(optimizer)

    def build(self, y, x, states):
        pred, states = self.model(x, states)
        loss = (pred - y) ** 2
        loss.mean().backward()
        return loss

net = Model()
x = flow.zeros((seq_len, batch_size, input))
y = flow.zeros((seq_len, batch_size, output))
optimizer = flow.optim.SGD(net.parameters(), lr=1e-3, momentum=0.99, weight_decay=1e-5, nesterov=True)
graph = Graph(net, optimizer)

from time import time
try:
    t = time()
    loss = graph(y, x, net.initial(batch_size))
finally:
    print(time() - t)
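To separate the one-time compilation cost from the steady-state step time, the timed call above can be replaced with two timed calls (a sketch reusing graph, net, x, y and batch_size from the script):

from time import time

t0 = time()
graph(y, x, net.initial(batch_size))  # first call: graph compilation + one training step
t1 = time()
graph(y, x, net.initial(batch_size))  # second call: runs the already-compiled plan
t2 = time()
print(f"first call (compile + run): {t1 - t0:.1f} s, second call (run only): {t2 - t1:.1f} s")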
result with seq_len=150
Log file created at: 2023/04/25 12:08:25
Running on machine: 3060
Running duration (h:mm:ss): 0:00:00
Log line format: [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg
I20230425 12:08:25.886257 27815 env_global_objects_scope.cpp:165] Using rpc backend: local
I20230425 12:08:26.076486 27815 epoll_comm_network.cpp:63] CommNet:Epoll listening on 0.0.0.0:38649
I20230425 12:08:26.256600 27815 version.cpp:22] OneFlow git version: 99b71264
I20230425 12:08:26.256635 27815 cuda_device_manager_factory.cpp:63] CUDA runtime version: 11.8
I20230425 12:08:26.256659 27815 cuda_device_manager_factory.cpp:72] cuDNN version: 8.9.0
I20230425 12:08:26.256662 27815 cuda_device_manager_factory.cpp:85] NCCL version: 2.15.1
I20230425 12:08:37.071957 27815 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 OptimizationLogicalGraph","mem_rss":"856.000000 MB","time_cost":"7 seconds"}
I20230425 12:08:38.973798 27815 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 AlignStates","mem_rss":"957.000000 MB","time_cost":"1 seconds"}
I20230425 12:08:47.366788 27815 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 CompleteJob","mem_rss":"974.000000 MB","time_cost":"8 seconds"}
I20230425 12:08:50.684219 27815 plan_util.cpp:1132]
Graph name Graph_0 in Rank: -1, Device: -1 needs to allocate [ 0 MiB ] device memory.
In general, Chunk id: -1 memory is [ -1e-06 MiB ] with mem_block_num = 0
Unreused memory not eager var is [ 0 MiB ] with mem_block_num = 0
Eager Variable Tensor total memory is [ 0 MiB ] with mem_block_num = 0
I20230425 12:08:50.797312 27815 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 CompilePlan","mem_rss":"1128.000000 MB","time_cost":"3 seconds"}
I20230425 12:08:50.797963 27815 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 SyncPlan","mem_rss":"1128.000000 MB","time_cost":"0 seconds"}
I20230425 12:08:51.753986 27815 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 InitRuntime","mem_rss":"1541.000000 MB","time_cost":"0 seconds"}
I20230425 12:09:08.308854 27815 env.cpp:126] forced eviction num: 0
I20230425 12:09:08.308881 27815 env.cpp:127] eager eviction num: 0
I20230425 12:09:08.308884 27815 env.cpp:128] recomputation num: 0
I20230425 12:09:08.308887 27815 env.cpp:129] duration: 0
result with seq_len=300
Log file created at: 2023/04/25 12:09:23
Running on machine: 3060
Running duration (h:mm:ss): 0:00:00
Log line format: [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg
I20230425 12:09:23.052439 27975 env_global_objects_scope.cpp:165] Using rpc backend: local
I20230425 12:09:23.247009 27975 epoll_comm_network.cpp:63] CommNet:Epoll listening on 0.0.0.0:34267
I20230425 12:09:23.437595 27975 version.cpp:22] OneFlow git version: 99b71264
I20230425 12:09:23.437631 27975 cuda_device_manager_factory.cpp:63] CUDA runtime version: 11.8
I20230425 12:09:23.437652 27975 cuda_device_manager_factory.cpp:72] cuDNN version: 8.9.0
I20230425 12:09:23.437656 27975 cuda_device_manager_factory.cpp:85] NCCL version: 2.15.1
I20230425 12:09:44.322456 27975 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 OptimizationLogicalGraph","mem_rss":"1072.000000 MB","time_cost":"17 seconds"}
I20230425 12:09:48.301986 27975 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 AlignStates","mem_rss":"1265.000000 MB","time_cost":"2 seconds"}
I20230425 12:10:06.281481 27975 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 CompleteJob","mem_rss":"1298.000000 MB","time_cost":"17 seconds"}
I20230425 12:10:13.410984 27975 plan_util.cpp:1132]
Graph name Graph_0 in Rank: -1, Device: -1 needs to allocate [ 0 MiB ] device memory.
In general, Chunk id: -1 memory is [ -1e-06 MiB ] with mem_block_num = 0
Unreused memory not eager var is [ 0 MiB ] with mem_block_num = 0
Eager Variable Tensor total memory is [ 0 MiB ] with mem_block_num = 0
I20230425 12:10:13.691105 27975 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 CompilePlan","mem_rss":"1582.000000 MB","time_cost":"6 seconds"}
I20230425 12:10:13.692200 27975 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 SyncPlan","mem_rss":"1582.000000 MB","time_cost":"0 seconds"}
I20230425 12:10:15.646508 27975 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 InitRuntime","mem_rss":"2409.000000 MB","time_cost":"1 seconds"}
system:
$ uname -a
Linux 3060 6.2.12-1-MANJARO #1 SMP PREEMPT_DYNAMIC Thu Apr 20 14:17:37 UTC 2023 x86_64 GNU/Linux
$ python -c "import oneflow as flow;print(flow.__version__)"
libibverbs not available, ibv_fork_init skipped
0.9.1.dev20230419+cu118
@strint
It seems that compiling an RNN with a long sequence length is time-consuming.
Is there any plan to fix this disappointing behavior?
I20230425 00:48:47.266355 20443 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 OptimizationLogicalGraph","mem_rss":"11621.000000 MB","time_cost":"433 seconds"}
I20230425 00:49:50.931952 20443 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 AlignStates","mem_rss":"11623.000000 MB","time_cost":"47 seconds"}
I20230425 00:54:51.325140 20443 cost_util.h:98] [count log]{"loc":"[GraphCompile]Graph_0 CompleteJob","mem_rss":"11606.000000 MB","time_cost":"300 seconds"}
It seems that logical graph compilation did take a lot of time.
It reminds me of this issue: https://github.com/Oneflow-Inc/oneflow/issues/9286. I will try to reduce the op graph initialization cost.