tensor2tensor

RFC: What do you think about TRAX?

Open lukaszkaiser opened this issue 5 years ago • 12 comments

We're thinking about how to make the next T2T much better. One idea that came up is using JAX and gin config, and we've prototyped TRAX: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/trax

If you're interested, please take a look. Run one of the examples, try to change something, train on your data, make your own model, tweak things. If you had trouble doing things in T2T before, let us know if that looks like it'd help!

TRAX is at a very early stage: it lacks features and has bugs. We know that, and we'll be fixing the small things as we go. But we'd love to think about higher-level issues that may be easier to address at this stage, before the design stabilizes.

lukaszkaiser, Mar 07 '19 02:03

Hello Lukasz,

Current problem #1 (TF1.0): I simply do not like the graph API of TF1, and I think any move away from it is a good thing, so both a TF2.0-based T2T and a JAX-based T2T should do the trick. T2T is fairly complex and many questions can only be answered by reading the code, and the hard-to-follow graph execution did not help.

Current problem #2 (Estimators): Loading and using a T2T model directly within an application, without a complex TF Serving setup, was also really hard. The Estimator interface is the main problem here, as it reloads the whole model for each predict call. I really hope a TF2.0-based or JAX-based T2T will make it easier to work without Estimators.

TRAX, what I like: The whole code looks really clean and easy to follow. In theory, it should solve all my issues with T2T.

TRAX, what I dislike: I do not like the idea of using yet another framework for T2T. While JAX seems cool and is gaining traction, wouldn't it be better to base the rewrite on TF2.0? I understand that a seamless transition from the current T2T to a TF2.0-backed version is not possible, so rewriting big parts of the framework might be needed, perhaps even starting from scratch. But if you plan to rewrite a lot of the code anyway (for the TRAX version), why not pool the resources, start from scratch, and use an eager TF2.0 approach?

I think in the long term T2T will not benefit from multiple branches that need to be maintained separately. Especially if one of them is still heavily relying on TF1.0 and graph mode, and the other is relying on a quite small framework like JAX.

Thoughts?

f-lng, Mar 07 '19 14:03

I've used Autograd in the past for some research and found it really inefficient; I ended up walking away from it in favor of PyTorch, which gave huge speedups.

As a more concrete example, at one point I ran into Autograd's backward pass consuming 20 GB of memory, which was eventually fixed by a PR.

I mention this to support @f-lng's point that adding another framework to T2T would make maintaining T2T harder, since bugs stemming from the underlying framework would inevitably arise.

OTOH, I think it's important to note that T2T's mission in part is to make deep learning more accessible. I'm not familiar enough with TF2.0 to understand if migrating to it would help this goal. What are your thoughts on that @lukaszkaiser / @f-lng?

etragas-fathom, Mar 08 '19 21:03

@etragas-fathom I do think that TF2.0 would serve the mission of making it more accessible, because model execution and code flow are simplified by the eager programming style, and model creation is simplified by the Keras API.

I might be biased here, as I only touch TF if I can not get around it (e.g. for T2T), and always found Keras to be a pretty good alternative, especially if you are doing engineering, not research.

I think a proper TF2.0-based T2T would be perfect, but if TRAX is the way to go, it might be better to simply drop the TF backend altogether and focus on the TRAX version.

f-lng, Mar 09 '19 11:03

The problem with TF 2.0, at least for now, is that when you want speed (using @tf.function or functional Keras mode) you're back in TF 1.0 graph-mode land: shape bugs, large stack traces and all. It feels as hard to debug as TF 1.0, or harder.
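The graph-mode point can be seen in a small sketch, assuming TensorFlow 2.x is installed with default settings; `square_sum` and `trace_log` are illustrative names. The Python body of a `tf.function` runs only while it is being traced into a graph, and `tf.executing_eagerly()` reports `False` during that trace.

```python
import tensorflow as tf

# Records one entry each time the Python body runs, i.e. on every trace.
trace_log = []

@tf.function
def square_sum(x):
    # Python side effect: executes only while tracing, not on every call.
    trace_log.append(tf.executing_eagerly())
    return tf.reduce_sum(x * x)

first = square_sum(tf.constant([1.0, 2.0]))
second = square_sum(tf.constant([3.0, 4.0]))  # same shape/dtype: cached graph reused
after_same_shape = list(trace_log)
third = square_sum(tf.constant([1.0, 2.0, 3.0]))  # new shape: triggers a retrace
```

Inside the function, `x` is a symbolic tensor rather than a concrete value, so debugging there means reasoning about graph construction instead of ordinary Python execution.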

With JAX the speed problem of Autograd is gone (I'm training a Transformer LM as fast as in T2T, and a ResNet only a little slower). But other bugs may resurface with more use; we'll need to see, I guess.
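A minimal sketch of the JAX workflow behind that claim, assuming `jax` is installed; the least-squares problem and all names are made up for illustration. `jax.grad` provides Autograd-style differentiation of plain NumPy-like code, and `jax.jit` compiles the whole update step with XLA rather than executing it op by op from Python.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1

def loss(w, x, y):
    # Mean squared error of a linear model x @ w.
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
def sgd_step(w, x, y):
    # jax.grad differentiates `loss` with respect to its first argument, w.
    return w - LEARNING_RATE * jax.grad(loss)(w, x, y)

x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

# Analytic gradient for comparison: (2 / n) * x.T @ (x @ w - y).
initial_grad = jax.grad(loss)(w, x, y)

for _ in range(200):
    w = sgd_step(w, x, y)  # first call traces and compiles; later calls reuse it
```

Since `y` is fit exactly by `w = [1, 2]`, plain gradient descent converges there. As with `tf.function`, the jitted step is traced once per input shape and dtype.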

Please keep adding comments so we know what to look out for!

lukaszkaiser, Mar 22 '19 02:03

Well, if you guys say TF 2.0 is not a good fit for T2T (yet?), then I guess TRAX is a good alternative, as the code looks very clean and the rewrite is easy to follow :-)

By the way, is beam search decoding already implemented and documented? I would love to give it a try.

f-lng, Mar 22 '19 12:03

Perhaps not a hugely insightful comment, but it's a shame that this makes installing recent versions of T2T non-trivial on Windows (see https://github.com/tensorflow/tensor2tensor/issues/1507).

JosephRedfern, Mar 28 '19 18:03

I would like TRAX to be an independent repo/package. I don't want to install TensorFlow if possible, but T2T depends on it.

moskomule, May 09 '19 02:05

As of commit 6c7c601b8c4429dcc81ab3ec828daddea5ff2b67, it seems TRAX has been moved to its own repo: https://github.com/google/trax

bzz, Oct 28 '19 14:10

https://github.com/google/trax


afrozenator, Oct 28 '19 19:10

@afrozenator @lukaszkaiser

Could you tell me the advantages of JAX over PyTorch? Does JAX provide any tools that make Transformer and Reformer faster, for reasons other than JAX probably having better TPU support?

AranKomat, Dec 24 '19 18:12

@lukaszkaiser I am not sure whether I am allowed to post this here, but I feel what you said is totally correct. To me, Keras and TF are really complicated, and their graph API (and more) is a mess. I hope Trax will be as readable as PyTorch when building more complicated models.

Here is an example of a U-Net in PyTorch, which is readable and beautiful. I am not a researcher or an engineer, just an ordinary person, but I am able to build a GAN. I am also able to build a self-supervised learning model to colorize my favorite Japanese anime pictures without running into GAN training instability. But I can only do it in PyTorch because it is so easy, readable, and beautiful. It is so friendly to open-source people.

https://gist.github.com/Hsankesara/e3b064ff47d538052e059084b8d4df9f#file-unet-py

JonathanSum, Aug 11 '20 10:08

@lukaszkaiser Trax takes an extremely long time to import, which makes debugging very uncomfortable.

lkluo, Feb 03 '21 04:02