Map experimental C (actually C++) API for gradient tape

Open saudet opened this issue 4 years ago • 12 comments
I've only tested the build on Linux for now, but it should also work on Mac and Windows.

Note that the API has already changed with 2.5.0, so we should probably upgrade to that version before looking at this too closely.

saudet avatar Apr 09 '21 16:04 saudet

@saudet, is this PR still just a draft, or do you think it is ready to be reviewed and merged?

karllessard avatar Apr 24 '21 19:04 karllessard

Since it doesn't look like we're going to do anything for this with TF 2.4.x, I think I'll upgrade this PR to 2.5.0-rc1 and then we can merge after its release? I don't think it makes sense to start doing something with the API for 2.4.x.

saudet avatar Apr 25 '21 13:04 saudet

I've finally rebased this on master and upgraded for TF 2.5.0! I've also undone the unreadable reformatting of presets/tensorflow.java, but feel free to redo if necessary. I'd still consider this a WIP, but if it doesn't break any builds, it should be fine to merge and start getting people playing with it, as long as we're ready to maintain an unstable experimental API....

saudet avatar Jun 25 '21 03:06 saudet

Hey, is this still being worked on, @saudet? This is the missing ingredient for me to implement an RL model in KotlinDL. If it is not being worked on anymore, what is left to do? Maybe I could help with it.

dosier avatar Nov 16 '21 06:11 dosier

As far as I know we're waiting on full support in TensorFlow. The RFC is here and there seems to be work happening gradually, but I don't think it's at the point where we can use it as a full solution for gradients. The actual gradients are here, and as you can see there are quite a few missing, and there's currently no registration method or anything like that.

rnett avatar Nov 16 '21 06:11 rnett

@dosier FWIW, you may be better off with PyTorch. Its C++ API is apparently rich enough to get anything RL going: https://github.com/Omegastick/pytorch-cpp-rl https://github.com/navneet-nmk/Pytorch-RL-CPP https://github.com/mhubii/ppo_libtorch

And the JavaCPP Presets for PyTorch provides full access to that C++ API: https://github.com/bytedeco/javacpp-presets/tree/master/pytorch Please do let me know if there's anything missing though!

saudet avatar Nov 16 '21 06:11 saudet

Cheers! Been hoping to see the GradientTape integrated into the TF Java API for a while, mainly so that I can contribute RL stuff in KotlinDL :D. But I also need to implement an RL model for my studies this block, so the PyTorch Java wrapper is a pleasant surprise (I love my static types too much).

dosier avatar Nov 16 '21 07:11 dosier

BTW, it looks like the author of KotlinDL would be open to integrating PyTorch as well, see https://github.com/pytorch/pytorch/issues/58973#issuecomment-855191456. However, I'm guessing he would like to get Facebook and/or Microsoft to cooperate a bit before doing anything with it.

/cc @zaleslaw

saudet avatar Nov 16 '21 07:11 saudet

Yeah, there are a few ways to integrate Torch in Kotlin

  1. JNI
  2. JavaCPP
  3. Pure PyTorch Java API with IValue for inference only

Hope that in 2022 KotlinDL will be able to support the training of Torch models via JNI (or via JavaCPP)

Good luck, @dosier with your experiments with RL and hope to see you in the future with the running RL models

zaleslaw avatar Nov 16 '21 09:11 zaleslaw

@zaleslaw May I ask why you're considering writing JNI code manually? What's missing from JavaCPP?

saudet avatar Nov 16 '21 09:11 saudet

@rnett, are you saying that with the custom gradient support you added not long ago plus this new internal API, we are still not able to register our own gradients in eager mode?

karllessard avatar Nov 17 '21 14:11 karllessard

I'm saying that for this new API, there's no built-in registries (i.e. global, or Graph/EagerSession based), so we would have to create and manage our own. Once we do that, it would be easy enough to add custom Java-side gradients.

I haven't seen confirmation anywhere that some sort of registry and auto-registration is planned, but I would expect it. I'm not sure how Python does it, or if it's using this setup at all.

rnett avatar Nov 17 '21 17:11 rnett
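
For readers following along, the "create and manage our own" registry idea from the last comment could look roughly like the sketch below. This is only an illustration under loud assumptions: `GradientFunction`, `GradientRegistry`, and the op name `"Square"` are hypothetical names invented here, not part of TensorFlow Java or the experimental C++ gradient API, and plain `Double` values stand in for tensors.

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical gradient function: maps op inputs and upstream gradients to input gradients. */
@FunctionalInterface
interface GradientFunction<T> {
    List<T> apply(List<T> inputs, List<T> upstreamGradients);
}

/** Hypothetical Java-side registry keyed by op type name (an assumption, not a TensorFlow API). */
final class GradientRegistry<T> {
    private final Map<String, GradientFunction<T>> gradients = new ConcurrentHashMap<>();

    /** Registers a gradient for an op type; rejects a second registration for the same op. */
    void register(String opType, GradientFunction<T> fn) {
        if (gradients.putIfAbsent(opType, fn) != null) {
            throw new IllegalStateException("Gradient already registered for op: " + opType);
        }
    }

    /** Returns the gradient function for an op type, or empty if none is registered. */
    Optional<GradientFunction<T>> lookup(String opType) {
        return Optional.ofNullable(gradients.get(opType));
    }
}

class GradientRegistryDemo {
    public static void main(String[] args) {
        GradientRegistry<Double> registry = new GradientRegistry<>();

        // d(x * x)/dx = 2x, scaled by the upstream gradient.
        registry.register("Square",
                (inputs, upstream) -> List.of(2.0 * inputs.get(0) * upstream.get(0)));

        List<Double> grads = registry.lookup("Square")
                .orElseThrow()
                .apply(List.of(3.0), List.of(1.0));
        System.out.println(grads); // prints [6.0]
    }
}
```

Whether such a registry should be global or scoped per Graph/EagerSession is left open here, as it is in the discussion above; the sketch only shows the lookup-by-op-type shape.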