Matthew Johnson
Should we make pypolyagamma optional, instead of listing it in the `install_requires` argument to `setup` in setup.py? It might make the build/install process simpler, though users who want to do...
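One common way to make a dependency optional is to move it into `extras_require` and import it lazily. The sketch below assumes a placeholder package name `mypackage`, an extra called `polyagamma`, and a hypothetical helper `require_pypolyagamma`. None of these come from the actual project. The `setup()` call itself is left out so the snippet runs without building anything.

```python
# Keyword arguments one might pass to setuptools.setup() in setup.py.
# "mypackage" and the "polyagamma" extra name are placeholders.
SETUP_KWARGS = dict(
    name="mypackage",
    install_requires=["numpy"],  # hard requirements only
    extras_require={"polyagamma": ["pypolyagamma"]},  # pip install mypackage[polyagamma]
)

# In library code, import the optional dependency lazily so the rest of
# the package still imports when pypolyagamma is missing.
try:
    import pypolyagamma
except ImportError:
    pypolyagamma = None


def require_pypolyagamma():
    """Return the pypolyagamma module, or raise an informative ImportError."""
    if pypolyagamma is None:
        raise ImportError(
            "This model needs pypolyagamma; install it with "
            "`pip install mypackage[polyagamma]`."
        )
    return pypolyagamma
```

Code paths that need Pólya-gamma sampling would call `require_pypolyagamma()` first. Users who never touch those models would skip the C extension build entirely.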
Currently the code only works with the `working-factorial` branch of pyhsmm, which is a bit old.
The concentration parameter resampling strategy is suboptimal and possibly inaccurate. Try replacing it with either slice sampling or discretization.
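The sketch below shows both suggested replacements for a single concentration parameter. It makes several assumptions that may not match this model's actual conditional. First, it assumes a Dirichlet-process-style posterior `p(alpha | K, n) ∝ p(alpha) · alpha^K · Γ(alpha) / Γ(alpha + n)` with a Gamma(a, rate=b) prior, where `K` is the number of occupied components among `n` items. Second, `slice_sample` is a generic univariate slice sampler with stepping out and shrinkage (Neal, 2003). Third, `grid_sample` illustrates the discretization alternative.

```python
import math

import numpy as np


def log_post_alpha(alpha, K, n, a=1.0, b=1.0):
    """Unnormalized log posterior of a DP-style concentration alpha (assumed form)."""
    if alpha <= 0:
        return -math.inf
    log_prior = (a - 1) * math.log(alpha) - b * alpha
    log_lik = K * math.log(alpha) + math.lgamma(alpha) - math.lgamma(alpha + n)
    return log_prior + log_lik


def slice_sample(logf, x0, rng, w=1.0, max_steps=50):
    """One univariate slice-sampling update with stepping out and shrinkage."""
    # Auxiliary height: log(U * f(x0)) = logf(x0) - Exp(1); avoids log(0).
    log_y = logf(x0) - rng.exponential()
    # Randomly place an interval of width w around x0, then step out.
    left = x0 - w * rng.uniform()
    right = left + w
    j = int(max_steps * rng.uniform())
    k = max_steps - 1 - j
    while j > 0 and logf(left) > log_y:
        left -= w
        j -= 1
    while k > 0 and logf(right) > log_y:
        right += w
        k -= 1
    # Shrink toward x0 until a proposal lands inside the slice.
    while True:
        x1 = rng.uniform(left, right)
        if logf(x1) > log_y:
            return x1
        if x1 < x0:
            left = x1
        else:
            right = x1


def grid_sample(logf, grid, rng):
    """Discretization alternative: sample from logf restricted to a fixed grid."""
    logw = np.array([logf(x) for x in grid])
    w = np.exp(logw - logw.max())  # subtract max for numerical stability
    return rng.choice(grid, p=w / w.sum())


# Usage with illustrative counts (K=12 components, n=500 items).
rng = np.random.default_rng(0)
logf = lambda a: log_post_alpha(a, K=12, n=500)
alpha, samples = 1.0, []
for _ in range(2000):
    alpha = slice_sample(logf, alpha, rng)
    samples.append(alpha)
grid_draw = grid_sample(logf, np.linspace(0.01, 10.0, 500), rng)
```

Slice sampling targets the continuous conditional exactly and needs no step-size tuning beyond the initial width `w`. The grid approach is simpler, but its accuracy depends on the grid's range and resolution.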
Based on some [autograd code](https://github.com/HIPS/autograd/pull/175#issuecomment-306524258) from @j-towns, @alextp and I recently had this idea for getting forward-mode JVPs in TensorFlow by calling `tf.gradients` twice:

```python
import numpy as np
import...
```
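The idea behind the trick is that a VJP `v ↦ Jᵀv` is linear in `v`. Differentiating it again with respect to `v` with cotangent `u` therefore gives `(Jᵀ)ᵀu = Ju`, which is a JVP. The sketch below shows the same double-VJP construction using JAX's `jax.vjp` in place of `tf.gradients`. The framework swap is an assumption made for a self-contained example; it is not the TensorFlow code from the post. The dummy cotangent (zeros here) only fixes the shape, because the result does not depend on its value.

```python
import jax
import jax.numpy as jnp


def jvp_via_double_vjp(f, x, u):
    """Compute (f(x), J(x) @ u) using only reverse-mode (VJP) calls."""
    y, vjp_fun = jax.vjp(f, x)

    def vjp_of_v(v):
        # Linear in v: returns J^T v.
        (g,) = vjp_fun(v)
        return g

    # Transpose the linear map v -> J^T v by taking its VJP with cotangent u.
    _, vjp_of_vjp = jax.vjp(vjp_of_v, jnp.zeros_like(y))
    (jvp_out,) = vjp_of_vjp(u)
    return y, jvp_out
```

In JAX itself you would just call `jax.jvp`. The construction is mainly interesting in systems that expose only reverse mode, such as graph-mode `tf.gradients`.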
When an inner jit simply forwards some of its inputs to outputs, we can prune those outputs and use the caller's value for them.

```python
import jax

@jax.jit
def f(x):...
```
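To make "forwards some of its inputs" concrete, the sketch below traces a function with `jax.make_jaxpr` and reports which outputs are just input variables passed through unchanged. The helper `forwarded_inputs` is hypothetical and only illustrates the analysis; it is not how JAX implements the pruning.

```python
import jax
import jax.numpy as jnp


def inner(x, y):
    # The first output forwards x unchanged; only the second does any work.
    return x, jnp.sin(y)


def forwarded_inputs(fn, *args):
    """For each jaxpr output, return the index of the input it forwards, or None."""
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    return [
        next((i for i, invar in enumerate(jaxpr.invars) if invar is outvar), None)
        for outvar in jaxpr.outvars
    ]


print(forwarded_inputs(inner, 1.0, 2.0))  # [0, None]
```

A caller that knows output 0 is forwarded from input 0 can reuse its own `x` and drop that output from the inner jitted call, leaving the inner computation with only `sin(y)`.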
```python
import jax
import jax.numpy as jnp

print(jax.make_jaxpr(lambda x: jnp.sin(jnp.sin(x)) + 2)(3))
```

Before:

```
{ lambda ; a:i32[]. let
    b:f32[] = convert_element_type[new_dtype=float32 weak_type=True] a
    c:f32[] = sin b
    d:f32[]...
```
Come on:

```
In [16]: jax.random.normal(jax.random.key(0), 1000)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[16], line 1
----> 1 jax.random.normal(jax.random.key(0), 1000)

File ~/packages/jax/jax/_src/random.py:710, in normal(key, shape, dtype)
    707 raise...
```
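In the version shown, `shape` must be a sequence, so the call that works is `jax.random.normal(key, (1000,))`. The sketch below adds a hypothetical wrapper, `normal`, that also accepts a bare int as a 1-D shape. That is the convenience this complaint asks for. The wrapper is an illustration, not an existing JAX API.

```python
import jax


def normal(key, shape=()):
    """Hypothetical wrapper around jax.random.normal that accepts an int shape."""
    if isinstance(shape, int):
        shape = (shape,)  # treat a bare int n as the 1-D shape (n,)
    return jax.random.normal(key, shape)


key = jax.random.key(0)
x = jax.random.normal(key, (1000,))  # works today: shape given as a tuple
y = normal(key, 1000)                # wrapper: int shape canonicalized to (1000,)
```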
This way we don't get lots of spurious failures when running things like `pytest tests`.
To make `inline_jaxpr_into_trace` work with dynamic shapes, we need to perform variable substitution into the types of the jaxpr eqns being inlined. We also need to substitute in the appropriate...
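The toy model below shows why inlining needs substitution into types. With dynamic shapes, an equation's output type such as `f32[n]` can mention variables. Those variables must be renamed to the caller's values, just like the equation's operands. The `Var`, `ArrayType`, `Eqn`, `subst_type`, and `inline_eqns` definitions are hypothetical stand-ins for illustration. They are not JAX's internal data structures, and they do not reproduce `inline_jaxpr_into_trace`.

```python
import itertools
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class ArrayType:
    dtype: str
    shape: tuple  # entries are ints (static sizes) or Vars (dynamic sizes)


@dataclass(frozen=True)
class Eqn:
    prim: str
    invars: tuple
    outvars: tuple  # (Var, ArrayType) pairs


def subst_type(ty, env):
    """Rename any Vars appearing in a type's shape according to env."""
    return ArrayType(ty.dtype, tuple(env.get(d, d) if isinstance(d, Var) else d
                                     for d in ty.shape))


def inline_eqns(invars, eqns, caller_args, fresh):
    """Inline eqns, substituting into both operands and output types."""
    env = dict(zip(invars, caller_args))
    out = []
    for eqn in eqns:
        new_in = tuple(env.get(v, v) for v in eqn.invars)
        new_outs = []
        for v, ty in eqn.outvars:
            nv = fresh()
            # The type may mention jaxpr inputs or earlier outputs, so substitute.
            new_outs.append((nv, subst_type(ty, env)))
            env[v] = nv
        out.append(Eqn(eqn.prim, new_in, tuple(new_outs)))
    return out, env


# Jaxpr: { lambda ; n:i32[] x:f32[n]. let y:f32[n] = sin x in (y,) }
counter = itertools.count()
fresh = lambda: Var(f"t{next(counter)}")
n, x, y = Var("n"), Var("x"), Var("y")
eqns = [Eqn("sin", (x,), ((y, ArrayType("f32", (n,))),))]
m, z = Var("m"), Var("z")  # the caller's size variable and array
new_eqns, env = inline_eqns((n, x), eqns, (m, z), fresh)
print(new_eqns[0].outvars[0][1])  # ArrayType(dtype='f32', shape=(Var(name='m'),))
```

Without the `subst_type` call, the inlined `sin` would claim an output of type `f32[n]`, which refers to a variable that does not exist in the caller's trace.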