Jianing Ye

9 comments by Jianing Ye

Same issue here. Are there any updates on this?

### script

This is a minimal reproducible script: https://colab.research.google.com/drive/1cq2C8VTPUja--PvbKq21M8ULglME2vrU?usp=sharing

The output shows

```
trained_steps: 207
if raw_batch is finite: False
raw_batch: (Array(True, dtype=bool), Array(True, dtype=bool), Array(False, dtype=bool))
```

which means...
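A per-field finiteness check like the one in that output can be written as a reduction over the batch's pytree leaves. This is a minimal sketch that assumes the batch is a tuple (or any pytree) of numeric JAX arrays. The helper names `leaf_is_finite` and `batch_is_finite` and the toy arrays are hypothetical and do not come from the linked Colab script. Here the third field contains a NaN, so its flag and the overall flag are both False.

```python
import jax
import jax.numpy as jnp


def leaf_is_finite(x):
    # True only if every element of this array is neither NaN nor +/-inf.
    return jnp.all(jnp.isfinite(x))


def batch_is_finite(batch):
    # Per-leaf flags keep the batch's pytree structure (here a tuple of
    # scalar bools), which shows which field holds the non-finite values.
    per_leaf = jax.tree_util.tree_map(leaf_is_finite, batch)
    # Reduce across all leaves to get one overall flag.
    overall = jnp.all(jnp.stack(jax.tree_util.tree_leaves(per_leaf)))
    return overall, per_leaf


# Hypothetical toy batch: the third field contains a NaN.
obs = jnp.ones((4, 3))
reward = jnp.zeros((4,))
discount = jnp.array([1.0, jnp.nan, 0.0, 1.0])

overall, per_leaf = batch_is_finite((obs, reward, discount))
print(overall)
print(per_leaf)
```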

Thanks for the quick update. If the goal is to align with the original sampling approach — i.e., sampling a query value $x \sim \text{Unif}[0, R)$ and traversing the tree...
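For reference, the standard sum-tree lookup draws one query value and descends from the root, going left when $x$ falls inside the left subtree's total priority and otherwise subtracting that total and going right. This is a minimal NumPy sketch under stated assumptions: the tree is array-backed with the root at index 1 and leaves at indices `capacity` to `2 * capacity - 1`, `capacity` is a power of two, and $R$ is taken to be the root's total priority. `build_sum_tree` and `sample_index` are hypothetical names, not the library's API.

```python
import numpy as np


def build_sum_tree(priorities):
    # Array-backed binary tree: node i has children 2i and 2i + 1,
    # the root is at index 1, and leaves live at [capacity, 2 * capacity).
    capacity = len(priorities)
    assert capacity & (capacity - 1) == 0, "capacity must be a power of two"
    tree = np.zeros(2 * capacity, dtype=np.float64)
    tree[capacity:] = priorities
    # Fill internal nodes bottom-up: each stores the sum of its two children.
    for i in range(capacity - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree


def sample_index(tree, x):
    # Descend from the root with x in [0, R), where R = tree[1].
    # Returns the index of the selected item (leaf).
    capacity = len(tree) // 2
    node = 1
    while node < capacity:
        left = 2 * node
        if x < tree[left]:
            node = left          # x falls inside the left subtree's mass
        else:
            x -= tree[left]      # skip past the left subtree's mass
            node = left + 1
    return node - capacity


rng = np.random.default_rng(0)
tree = build_sum_tree(np.array([1.0, 0.0, 3.0, 4.0]))
total = tree[1]                  # R, the total priority
x = rng.uniform(0.0, total)      # x ~ Unif[0, R)
print(sample_index(tree, x))
```

With priorities `[1, 0, 3, 4]`, item 0 owns $[0, 1)$, item 2 owns $[1, 4)$ and item 3 owns $[4, 8)$, so each item is drawn with probability proportional to its priority and the zero-priority item is never selected.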

@cgarciae OK, I have reported this to the [JAX repo](https://github.com/jax-ml/jax/issues/27228).

@sash-a Thanks for your reply. I think I am using donation correctly, since I am jitting the wrapper function `buf_add` rather than `buffer.add`, and `buf_add` takes only one argument.
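As a point of comparison, donation through a single-argument jitted wrapper can look like the following. This is a minimal sketch that assumes the buffer state is a pytree (here a dict) of arrays. `add_item`, the dict layout and the constant `ITEM` are hypothetical stand-ins for the real `buffer.add`, and only the wrapper name `buf_add` comes from the comment. `donate_argnums=0` lets XLA reuse the input state's memory for the output. The donated input must not be read after the call, so `state` is rebound on every iteration. If a backend cannot use a donated buffer, JAX warns and falls back to copying.

```python
import functools

import jax
import jax.numpy as jnp


def add_item(state, item):
    # Hypothetical ring-buffer add: write `item` at the current position
    # and advance the write pointer modulo the capacity.
    capacity = state["data"].shape[0]
    pos = state["pos"]
    data = state["data"].at[pos].set(item)
    return {"data": data, "pos": (pos + 1) % capacity}


ITEM = jnp.ones((3,))  # fixed item for this toy example


@functools.partial(jax.jit, donate_argnums=0)
def buf_add(state):
    # Argument 0 (the whole state pytree) is donated, so `data` can be
    # updated in place instead of being copied into a fresh array.
    return add_item(state, ITEM)


state = {"data": jnp.zeros((4, 3)), "pos": jnp.array(0, dtype=jnp.int32)}
for _ in range(5):
    # Rebind `state` on every call: the donated input is invalidated.
    state = buf_add(state)
print(state["pos"])
```

Donation applies only to the arguments of the function that is passed to `jax.jit`. Inside the trace, `add_item` is an ordinary Python call, and the donation is decided entirely by `buf_add`'s `donate_argnums`.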

I have also removed the wrapper, and the result is similar:

```Python
import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
from types import SimpleNamespace
import jax
import jax.numpy as jnp
import timeit
import...
```

@EdanToledo Thank you for the clarification! I have a few more questions. By the way, what is the main factor behind the difference in sum tree update speed between CPU and GPU?...

@EdanToledo Thank you so much for the detailed explanation! That answers some of my questions. However, I think the CPU `add` code might still be buggy here according to my...