raj-brown

15 comments by raj-brown

Hi @mattjj, do you have the code for `memory_ratio` as well? If so, could you please share it? Thanks!

Thanks @mattjj for pasting it here. I just have one question: which module is `pe` in `pe.JaxprEqnRecipe`? Thanks!

Thank you @patrick-kidger. I have another question. I want to make a `pytree` copy of the neural network, as I have to change the values of parameters in...

Thank you very much @patrick-kidger, I really appreciate it. A big thank you for creating Equinox. It is awesome. Thanks!

Hi @patrick-kidger, I have another question. I need to use `jvp` with the primal being the neural network parameters and the tangent having the same type and shape as the neural network parameters...

Hi @patrick-kidger, any suggestions or help would be great. Thank you!

Hi @patrick-kidger, sure. I will prepare one and post it here. On that note, do you have any suggestions for finding out how much memory a jitted function is using? Some...
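One possible approach, as a sketch rather than a definitive answer: JAX's ahead-of-time API lets you compile a jitted function and query XLA's static memory estimate with `memory_analysis()`. The function `f` and the input shape below are hypothetical, and the statistics are backend-dependent (the call may return `None` where unsupported). For measured device usage at runtime, `jax.profiler.save_device_memory_profile` is another route.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical workload with an intermediate (512, 512) buffer.
    h = jnp.tanh(x @ x.T)
    return (h @ x).sum()

x = jnp.ones((512, 256), dtype=jnp.float32)

# Lower and compile ahead of time to get the compiled executable.
compiled = jax.jit(f).lower(x).compile()

# XLA's static estimate of buffer sizes; may be None on some backends.
stats = compiled.memory_analysis()
if stats is not None:
    print("arguments:", stats.argument_size_in_bytes, "bytes")
    print("outputs:  ", stats.output_size_in_bytes, "bytes")
    print("temporary:", stats.temp_size_in_bytes, "bytes")

# The compiled executable can be called directly.
result = compiled(x)
```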

Hi @patrick-kidger, I want to take the JVP of the loss function with respect to the NN parameters along a random direction. Here is my code to do that:

```
@eqx.filter_jit
def train_step_fwg(network,...
```

@patrick-kidger Hi Patrick, any help on this issue would be really helpful. Thank you so much!

Thanks @patrick-kidger. In fact, I split them in my for loop in the driver script:

```
key = jax.random.PRNGKey(7)
sub_key = jr.split(key, N_EPOCHS)
print(sub_key)
sys.exit()
key_count = 0
counter = tqdm(np.arange(N_EPOCHS))...
```
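A sketch of a per-epoch key pattern, assuming `jr` is `jax.random` and a hypothetical `N_EPOCHS = 5`: split the root key once, give each epoch its own row, and split that row again for anything inside the epoch that needs randomness, so no key is consumed twice. `jax.random.fold_in(key, epoch)` is an alternative that avoids materializing all epoch keys up front.

```python
import jax.numpy as jnp
import jax.random as jr

N_EPOCHS = 5  # hypothetical

key = jr.PRNGKey(7)
# Split once, outside the loop: row i is an independent key for epoch i.
epoch_keys = jr.split(key, N_EPOCHS)

directions = []
for epoch in range(N_EPOCHS):
    # Derive per-use keys from this epoch's key instead of reusing it directly.
    dir_key, shuffle_key = jr.split(epoch_keys[epoch])
    directions.append(jr.normal(dir_key, (3,)))   # e.g. a random direction
    perm = jr.permutation(shuffle_key, 10)         # e.g. a data shuffle
```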