onnx-tensorflow

GRU support with linear_before_reset

Open iiSeymour opened this issue 5 years ago • 10 comments

Similar to #227 but without adding the backend; this uses the CudnnCompatibleGRUCell implementation in TensorFlow instead.

iiSeymour avatar Jul 03 '19 14:07 iiSeymour
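
For readers unfamiliar with the attribute in the title, here is a minimal sketch of what linear_before_reset changes, following the candidate-state formulas in the ONNX GRU operator definition. It is plain Python for illustration only, not code from this patch; the helper names (matvec, candidate_state) and the argument layout are assumptions made up for the example.

```python
import math


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def candidate_state(x, h_prev, r, Wh, Rh, Wbh, Rbh, linear_before_reset):
    """Candidate hidden state of one GRU step (hypothetical helper).

    x: input vector, h_prev: previous hidden state, r: reset gate values,
    Wh/Rh: input and recurrent weights, Wbh/Rbh: input and recurrent biases.
    """
    wx = matvec(Wh, x)
    if linear_before_reset:
        # Recurrent linear transform (including its bias) first, reset gate after.
        rec = [r_i * (a + b) for r_i, a, b in zip(r, matvec(Rh, h_prev), Rbh)]
    else:
        # ONNX default: reset gate applied to h_prev, linear transform after.
        gated = [r_i * h_i for r_i, h_i in zip(r, h_prev)]
        rec = [a + b for a, b in zip(matvec(Rh, gated), Rbh)]
    return [math.tanh(a + b + c) for a, b, c in zip(wx, rec, Wbh)]
```

The two variants agree when the reset gate is fully open (all ones) and diverge otherwise, which is why a model exported with linear_before_reset=1 cannot be run with a cell that only implements the default ordering.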

CLA assistant check
All committers have signed the CLA.

CLAassistant avatar Jul 24 '19 00:07 CLAassistant

Looks like https://www.tensorflow.org/api_docs/python/tf/contrib/cudnn_rnn/CudnnCompatibleGRUCell is deprecated. In general, we should avoid using APIs in contrib. We would need to replace it soon anyway, so could you find an alternative?

chinhuang007 avatar Sep 06 '19 20:09 chinhuang007

It looks like the new GRU layer is CuDNN compatible and supports reset_after, so I think it should just be a case of swapping tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell for tf.keras.layers.GRU. I will have a proper look tomorrow.

iiSeymour avatar Sep 06 '19 20:09 iiSeymour
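
As a rough illustration of the swap suggested above, the Keras layer can be constructed with reset_after=True, which applies the reset gate after the recurrent matrix multiplication (the behaviour linear_before_reset=1 asks for). This is a sketch assuming TensorFlow 2.x, not the code from the patch; the sizes and variable names are made up for the example.

```python
import numpy as np
import tensorflow as tf

# Illustrative sizes only (assumptions, not taken from any real model).
units, features, steps, batch = 4, 3, 5, 2

# reset_after=True applies the reset gate after the recurrent matmul.
gru = tf.keras.layers.GRU(
    units,
    reset_after=True,
    recurrent_activation="sigmoid",
    return_sequences=True,
)

x = np.random.rand(batch, steps, features).astype("float32")
y = gru(x)  # output shape: (batch, steps, units)

# With reset_after=True the layer keeps separate input and recurrent biases,
# so the bias is stored as a 2-row array rather than a single vector.
kernel, recurrent_kernel, bias = gru.get_weights()
```

The separate recurrent bias is what makes this variant representable: in the reset-after form the recurrent bias sits inside the gated term, so it cannot be folded into the input bias.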

Needs #460 fixing first.

iiSeymour avatar Sep 09 '19 12:09 iiSeymour

Yeah, we will need to revisit this once we get to TF 2.0

chinhuang007 avatar Sep 12 '19 17:09 chinhuang007

Hi, is this feature still in progress? I ran into this problem when converting my ONNX model. I'd like to fix it if there is no implementation yet.

dianyo avatar Jul 31 '20 09:07 dianyo

Any update on this? I am also facing the same issue while converting ONNX to TensorFlow.

alamnasim avatar Oct 20 '20 11:10 alamnasim

The patch I submitted worked for me at the time of submission, but I don't have time to (re)evaluate it currently. @alamnasim @dianyo, give it a try.

iiSeymour avatar Oct 20 '20 11:10 iiSeymour

Any update on this?

tylerweitzman avatar Jul 11 '21 07:07 tylerweitzman

@iiSeymour, when I use your gru-reset-after branch, I find that the value of the 'name' argument of the function _custom_getter(), which is defined in gru.py, is 'GRU_9c8f1236/rnn/multi_rnn_cell/cell_0/cudnn_compatible_gru_cell/candidate/hidden_projection/bias'. Can you tell me how this name is generated? I have not found it in the context. Thanks.

TingfengTang avatar Dec 20 '21 09:12 TingfengTang