Jesse Grabowski

168 comments by Jesse Grabowski

I'm not sure that `environment.yml` is used for much; from a quick search, it looks like it's only used for the docs. I would support getting rid of it entirely.

PyMC puts its conda environment files in their own folder [here](https://github.com/pymc-devs/pymc/tree/main/conda-envs); the PR was structured this way at my recommendation. I guess there's no reason why they couldn't be in the `.github` folder instead,...

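This works:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

def step(x, epsilon, mu, sigma):
    next_x = x + mu + sigma * epsilon
    return next_x

rng = pytensor.shared(np.random.default_rng())
# The RandomVariable node's outputs are (next rng state, draws)
new_rng, epsilons = pt.random.normal(size=10, rng=rng).owner.outputs
traj, updates =...
```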
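Outside of PyTensor, the recurrence in `step` is easy to sanity-check: once all the innovations `epsilons` are drawn up front, iterating `step` is the same as a drift term plus a scaled cumulative sum. This is a minimal NumPy sketch, not the scan graph itself; `simulate_predrawn`, the seed, and the parameter values are illustrative assumptions.

```python
import numpy as np


def step(x, epsilon, mu, sigma):
    # One step of a Gaussian random walk with drift mu and scale sigma.
    next_x = x + mu + sigma * epsilon
    return next_x


def simulate_predrawn(x0, mu, sigma, epsilons):
    # Hypothetical helper: loop over pre-drawn innovations, like scanning over a sequence.
    traj = []
    x = x0
    for eps in epsilons:
        x = step(x, eps, mu, sigma)
        traj.append(x)
    return np.array(traj)


rng = np.random.default_rng(42)  # seed chosen arbitrarily
epsilons = rng.normal(size=10)
traj = simulate_predrawn(0.0, 0.1, 0.5, epsilons)

# Closed form: x_t = x0 + t * mu + sigma * (eps_1 + ... + eps_t)
closed_form = 0.0 + 0.1 * np.arange(1, 11) + 0.5 * np.cumsum(epsilons)
print(np.allclose(traj, closed_form))  # True
```

Because every draw happens before the loop, the trajectory depends on the generator only through `epsilons`, which is what lets them be passed to scan as an ordinary sequence.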

I guess the tags I chose for this issue are misleading, because I don't think I want any kind of special automatic handling here. It's more that it seems like...

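```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

mu = pt.dscalar('mu')
sigma = pt.dscalar('sigma')
x0 = pt.dscalar('x0')
rng = pytensor.shared(np.random.default_rng(), 'rng')

def step(x, mu, sigma, rng):
    # The RandomVariable node's outputs are (next rng state, draw)
    new_rng, epsilon = pm.Normal.dist(0, 1, rng=rng).owner.outputs
    next_x = x + mu...
```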
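Here the generator is passed into `step` and each innovation is drawn inside it. In PyTensor the updated state has to be carried along explicitly (that is the `new_rng` taken from `.owner.outputs`); a NumPy `Generator`, by contrast, advances its state in place. A minimal NumPy sketch of the same pattern, with `simulate_inner_draws` and the parameter values as illustrative assumptions:

```python
import numpy as np


def step(x, mu, sigma, rng):
    # Draw the innovation inside the step. NumPy's Generator mutates its own
    # state, so unlike the PyTensor version there is no new_rng to return.
    epsilon = rng.standard_normal()
    next_x = x + mu + sigma * epsilon
    return next_x


def simulate_inner_draws(x0, mu, sigma, n_steps, seed):
    # Hypothetical helper: one generator per call, shared by every step.
    rng = np.random.default_rng(seed)
    traj = []
    x = x0
    for _ in range(n_steps):
        x = step(x, mu, sigma, rng)
        traj.append(x)
    return np.array(traj)


a = simulate_inner_draws(0.0, 0.1, 0.5, 10, seed=1)
b = simulate_inner_draws(0.0, 0.1, 0.5, 10, seed=1)
print(np.array_equal(a, b))  # True: same seed, same trajectory
```

Sharing one generator across steps is what gives each step a fresh draw; if `step` rebuilt the generator from a fixed seed on every call, every innovation would be identical.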

First refactor all of pytensor (and pymc) to remove shared variables? :D

I just read https://github.com/pymc-devs/pytensor/issues/473, so your thoughts on `OpFromGraph` and shared variables are fresh in my mind.

I think I was wrong, though: it's casting to `int8`, not `uint8` (`int8` covers -128 to 127, while `uint8` covers 0 to 255). Wouldn't dropping the minus sign break negative indexes?
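For reference, the ranges in question can be checked with NumPy (used here as a standalone stand-in; the rewrite under discussion lives in PyTensor). A negative index such as -1 does not survive a cast to `uint8`: it wraps around to 255, which no longer means "last element".

```python
import numpy as np

# Integer ranges of the two candidate index dtypes.
i8, u8 = np.iinfo(np.int8), np.iinfo(np.uint8)
print(i8.min, i8.max)  # -128 127
print(u8.min, u8.max)  # 0 255

x = np.arange(10)

# A negative index stored as int8 still means "count from the end".
print(x[np.int8(-1)])  # 9

# The same index cast to uint8 wraps modulo 256 and becomes 255 ...
wrapped = np.array([-1]).astype(np.uint8)[0]
print(wrapped)  # 255

# ... which is out of bounds for a length-10 array.
try:
    x[wrapped]
except IndexError as err:
    print("IndexError:", err)
```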

So I guess I'm asking: why do we rewrite to `uint8` at all? Isn't that needlessly restrictive?
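One less restrictive alternative would be to pick the narrowest integer dtype that actually covers the range of index values, instead of always using `uint8`. The helper below, `narrowest_index_dtype`, is hypothetical (not an existing PyTensor function); it is a NumPy sketch of that idea, assuming the lowest and highest index values are known.

```python
import numpy as np

# Candidate dtypes from narrowest to widest; signed first at each width.
CANDIDATES = (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32, np.int64)


def narrowest_index_dtype(lo, hi):
    """Hypothetical helper: smallest candidate dtype holding every index in [lo, hi]."""
    for dtype in CANDIDATES:
        info = np.iinfo(dtype)
        if info.min <= lo and hi <= info.max:
            return np.dtype(dtype)
    raise ValueError(f"no candidate dtype holds [{lo}, {hi}]")


print(narrowest_index_dtype(-1, 5))       # int8: negative indexes survive
print(narrowest_index_dtype(0, 200))      # uint8
print(narrowest_index_dtype(-1, 200))     # int16: neither 8-bit type covers both ends
print(narrowest_index_dtype(0, 100_000))  # int32
```

Trying the signed type first at each width keeps negative indexes representable whenever that costs nothing extra.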

Maybe rewrite to something like this?

```python
import jax
import jax.numpy as jnp

def prod(x):
    return jnp.exp(jnp.sum(jnp.log(x)))

# @jax.jit
def foo(x):
    return jax.grad(prod)(x)

jax.make_jaxpr(foo)(jnp.arange(800, dtype="float32"))
```

Amusingly, this was suggested...
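The rewrite relies on the identity prod(x) = exp(sum(log(x))), whose gradient is exp(sum(log(x))) * (1 / x), i.e. prod(x) / x elementwise. The log form is only valid for strictly positive entries. A minimal NumPy check of that gradient against central finite differences (NumPy instead of JAX so it runs standalone; the test vector and step size are arbitrary choices):

```python
import numpy as np


def prod_via_logs(x):
    # Same trick as the JAX snippet; requires x > 0 elementwise.
    return np.exp(np.sum(np.log(x)))


def grad_prod(x):
    # Analytic gradient of exp(sum(log(x))): prod(x) / x_i for each i.
    return prod_via_logs(x) / x


def finite_diff_grad(f, x, h=1e-6):
    # Central differences, one coordinate at a time.
    g = np.empty_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2 * h)
    return g


x = np.array([0.5, 1.5, 2.0, 3.0])
print(np.isclose(prod_via_logs(x), np.prod(x)))  # True
print(np.allclose(grad_prod(x), finite_diff_grad(np.prod, x)))  # True
```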