darts-tensorflow
tf.control_dependencies seems extremely time-consuming
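I tried running it several times and it is very sluggish. Every time execution reaches tf.control_dependencies it stalls for a long time, up to about 15 minutes. What is the cause? Is the network too large? The variable list passed to tf.control_dependencies has more than 2,300 entries. Is the slowdown caused by having that many variables?
After the network is built, it also stalls when it is sent to the GPU for execution: more than ten minutes pass and it still has not started running.
Is there any way to speed it up? Thanks for any pointers!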
I ran train_search.py multiple times. The program always hangs at the tf.control_dependencies block in the compute_unrolled_step() function. It just stops there and prints nothing. Specifically, it hangs at the line
unrolled_optimizer=unrolled_optimizer.minimize(unrolled_train_loss,var_list=unrolled_w_var)
It hangs for around 15 minutes before it continues.
Is there any way to speed it up? I'm using a server with a Tesla V100, so I don't think the delay is because my server is too slow.
Thanks!
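It may be because every gradient update has to assign values to unroll_model and then compute the gradients again, which makes the whole graph especially complex. I did not consider efficiency when I wrote the implementation, so there is probably a lot of room to optimize the logic. I have been busy recently and have not had time to look into this problem. If you really need it, you can use the original PyTorch version for now. Sorry about that.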
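For context on why the graph gets heavy: assuming this repository follows the second-order approximation from the original DARTS paper (not confirmed against train_search.py here), the architecture gradient, written g_α below as a label introduced only for this note, needs the weights after one virtual SGD step (w'), the validation gradients at w', and a Hessian-vector product that DARTS approximates with a central finite difference. Each term is its own set of gradient ops over every network weight, which may help explain the long graph-construction time with 2,300+ variables.

```latex
\begin{aligned}
w' &= w - \xi\, \nabla_w \mathcal{L}_{\text{train}}(w, \alpha) \\
g_\alpha &= \nabla_\alpha \mathcal{L}_{\text{val}}(w', \alpha)
  - \xi\, \nabla^2_{\alpha, w} \mathcal{L}_{\text{train}}(w, \alpha)\, \nabla_{w'} \mathcal{L}_{\text{val}}(w', \alpha) \\
\nabla^2_{\alpha, w} \mathcal{L}_{\text{train}}(w, \alpha)\, \nabla_{w'} \mathcal{L}_{\text{val}}(w', \alpha)
  &\approx \frac{\nabla_\alpha \mathcal{L}_{\text{train}}(w^{+}, \alpha) - \nabla_\alpha \mathcal{L}_{\text{train}}(w^{-}, \alpha)}{2\epsilon},
  \qquad w^{\pm} = w \pm \epsilon\, \nabla_{w'} \mathcal{L}_{\text{val}}(w', \alpha)
\end{aligned}
```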
@NeroLoh Thanks a lot! Could we do without tf.control_dependencies here? Would that affect the results?
Its main purpose is to run the assignment to unroll_model (https://github.com/NeroLoh/darts-tensorflow/blob/05e6228af144d7d09400a42e21af3e25c1ded862/cnn/train_search.py#L137) before the gradients are computed, so it probably cannot be skipped.
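The ordering the maintainer describes can be illustrated with a toy, pure-Python sketch. In TensorFlow 1.x graph mode, the order in which Python statements create ops does not by itself fix the order in which they run; tf.control_dependencies makes the ops created inside its block run only after the listed ops. Here plain sequential code plays that role. Everything below is hypothetical and not the repository's code: a scalar weight `w`, a scalar architecture parameter `alpha`, quadratic toy losses, `UnrolledModel`, `unrolled_alpha_grad`, and the `skip_assign` flag, which mimics leaving out the copy into the unrolled model.

```python
# Toy, pure-Python sketch of a DARTS-style unrolled step.
# All losses and names here are hypothetical stand-ins, not darts-tensorflow code.


def grad_w_train(w, alpha):
    # dL_train/dw for the toy training loss L_train = (w - alpha) ** 2
    return 2.0 * (w - alpha)


def grad_alpha_train(w, alpha):
    # dL_train/dalpha for the same toy training loss
    return -2.0 * (w - alpha)


def grad_w_val(w, alpha):
    # dL_val/dw for the toy validation loss L_val = (w - 2 * alpha) ** 2
    return 2.0 * (w - 2.0 * alpha)


def grad_alpha_val(w, alpha):
    # dL_val/dalpha for the same toy validation loss
    return -4.0 * (w - 2.0 * alpha)


class UnrolledModel:
    """Hypothetical stand-in for the unrolled copy of the network."""

    def __init__(self):
        self.w = 0.0  # stale value until the real weights are copied in


def unrolled_alpha_grad(w, alpha, xi, eps, unrolled, skip_assign=False):
    # Step 1: copy the current weights into the unrolled model
    # (the assignment that, per the thread, tf.control_dependencies forces to run first).
    if not skip_assign:
        unrolled.w = w
    # Step 2: one virtual SGD step on the training loss, w' = w - xi * dL_train/dw.
    unrolled.w -= xi * grad_w_train(unrolled.w, alpha)
    w_prime = unrolled.w
    # Step 3: validation gradients evaluated at w'.
    g_alpha = grad_alpha_val(w_prime, alpha)
    v = grad_w_val(w_prime, alpha)
    # Step 4: central finite difference for the Hessian-vector product term.
    w_plus, w_minus = w + eps * v, w - eps * v
    hvp = (grad_alpha_train(w_plus, alpha) - grad_alpha_train(w_minus, alpha)) / (2.0 * eps)
    return g_alpha - xi * hvp


if __name__ == "__main__":
    fresh = unrolled_alpha_grad(1.0, 0.25, xi=0.1, eps=0.01, unrolled=UnrolledModel())
    stale = unrolled_alpha_grad(1.0, 0.25, xi=0.1, eps=0.01,
                                unrolled=UnrolledModel(), skip_assign=True)
    print(fresh, stale)
```

For these quadratic toy losses the finite difference is exact, so the fresh run returns approximately -1.26, matching the chain-rule gradient of L_val(w'(alpha), alpha). The run that skips the copy starts its virtual step from stale weights and returns approximately 1.62, a gradient with the opposite sign, which is the sense in which dropping the assignment would change the results.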