Jianing Ye
Same issue here. Are there any updates on this?
### script

This is a minimal reproducible script: https://colab.research.google.com/drive/1cq2C8VTPUja--PvbKq21M8ULglME2vrU?usp=sharing

The output shows

```
trained_steps: 207
if raw_batch is finite: False
raw_batch: (Array(True, dtype=bool), Array(True, dtype=bool), Array(False, dtype=bool))
```

which means...
Thanks for the quick update. If the goal is to align with the original sampling approach — i.e., sampling a query value $x \sim \text{Unif}[0, R)$ and traversing the tree...
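For reference, the sampling approach mentioned above can be sketched in plain Python. This is only an illustration under stated assumptions, not the library's implementation: it assumes $R$ is the sum of all priorities, that the tree is stored as a flat array with the root at index 1, and that the capacity is a power of two. The `SumTree` class and its `set`, `find`, and `sample` methods are hypothetical names chosen for this sketch.

```python
import random


class SumTree:
    """Array-backed binary sum tree (hypothetical sketch).

    Leaves hold priorities; each internal node holds the sum of its two children.
    """

    def __init__(self, capacity):
        # Assumption: capacity is a power of two, so all leaves sit on one level.
        assert capacity > 0 and capacity & (capacity - 1) == 0
        self.capacity = capacity
        # Node 1 is the root, node i has children 2i and 2i + 1,
        # and the leaves occupy indices capacity .. 2 * capacity - 1.
        self.nodes = [0.0] * (2 * capacity)

    @property
    def total(self):
        # R: the sum of all priorities, stored at the root.
        return self.nodes[1]

    def set(self, index, priority):
        # Write the leaf, then recompute every ancestor up to the root.
        i = index + self.capacity
        self.nodes[i] = priority
        i //= 2
        while i >= 1:
            self.nodes[i] = self.nodes[2 * i] + self.nodes[2 * i + 1]
            i //= 2

    def find(self, x):
        # Descend from the root: go left if x falls inside the left subtree's
        # mass, otherwise subtract that mass and go right.
        i = 1
        while i < self.capacity:
            left = 2 * i
            if x < self.nodes[left]:
                i = left
            else:
                x -= self.nodes[left]
                i = left + 1
        return i - self.capacity

    def sample(self, rng=random):
        # Draw x ~ Unif[0, R) and return the leaf whose interval contains it.
        return self.find(rng.random() * self.total)
```

With priorities `[1, 2, 3, 4]`, the leaves own the intervals $[0, 1)$, $[1, 3)$, $[3, 6)$, and $[6, 10)$, so each index is returned with probability proportional to its priority.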
@cgarciae OK, I have reported this to the [JAX repo](https://github.com/jax-ml/jax/issues/27228).
@sash-a Thanks for your reply. I believe my buffer donation is correct, since I am jitting the wrapper function `buf_add` rather than `buffer.add` itself, and `buf_add` takes only one argument.
I have also removed the wrapper, and the result is similar.

```Python
import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
from types import SimpleNamespace
import jax
import jax.numpy as jnp
import timeit
import...
```
@EdanToledo Thank you for the clarification! I have a few more questions. By the way, what is the main factor causing the speed difference in sum tree updates between CPU and GPU?...
@EdanToledo Thank you so much for the detailed explanation! That answers some of my questions. However, I think the CPU `add` code might still be buggy here according to my...
Hi there, any update on this issue?