Stergios Bachoumas
OK, I will add some tests for the `Probs` versions of the distributions and submit a PR for the discrete distributions; you can review it there. Thanks for the...
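For the `Probs` tests, one property worth checking is that a distribution built from probabilities gives the same log-probabilities as the one built from the equivalent logits. The sketch below uses plain JAX with hypothetical helper functions (`categorical_log_prob_from_logits`, `probs_and_logits_agree`, etc.). It is not the library's actual classes or API.

```python
import jax
import jax.numpy as jnp


def categorical_log_prob_from_logits(logits, k):
    # log p(k) = logits[k] - logsumexp(logits), i.e. log-softmax at index k
    return jax.nn.log_softmax(logits)[k]


def categorical_log_prob_from_probs(probs, k):
    # Assumes probs is already normalised (sums to 1).
    return jnp.log(probs[k])


def bernoulli_log_prob_from_logits(logit, x):
    # log p(x) = x * log(sigmoid(l)) + (1 - x) * log(sigmoid(-l)), computed stably
    return x * jax.nn.log_sigmoid(logit) + (1 - x) * jax.nn.log_sigmoid(-logit)


def bernoulli_log_prob_from_probs(p, x):
    # Same quantity from the probability p = sigmoid(l)
    return x * jnp.log(p) + (1 - x) * jnp.log1p(-p)


def probs_and_logits_agree(logits, atol=1e-5):
    """Property-style check: converting logits -> probs must not change log_prob."""
    probs = jax.nn.softmax(logits)
    cat_ok = all(
        bool(jnp.allclose(categorical_log_prob_from_logits(logits, k),
                          categorical_log_prob_from_probs(probs, k), atol=atol))
        for k in range(logits.shape[0])
    )
    p = jax.nn.sigmoid(logits)  # treat each logit as an independent Bernoulli
    bern_ok = all(
        bool(jnp.allclose(bernoulli_log_prob_from_logits(logits, x),
                          bernoulli_log_prob_from_probs(p, x), atol=atol))
        for x in (0.0, 1.0)
    )
    return cat_ok and bern_ok


print(probs_and_logits_agree(jnp.array([0.3, -1.2, 2.0])))
```

The test deliberately uses moderate logits. For very large |logit| the probability form loses precision (for example `log1p(-p)` once `p` rounds to 1), so the two forms can legitimately diverge there.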
Hey @yardenas, I am glad you have already taken some steps on this, because frankly time is of the essence. I took a brief look at your implementation and its...
So S4 layers are made to handle a one-dimensional sequence of size L and produce a one-dimensional output. This means that if you need to pass multiple "parallel" sequences you...
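A minimal JAX sketch of that idea, assuming an already-discretised single-input single-output SSM with dense matrices `Ab`, `Bb`, `Cb` (the function name `ssm_scan` and the shapes are illustrative assumptions): one SSM maps an `(L,)` sequence to an `(L,)` output, and `jax.vmap` runs H independent copies, one per channel, over an `(L, H)` input.

```python
import jax
import jax.numpy as jnp


def ssm_scan(Ab, Bb, Cb, u):
    """Run one discrete SISO SSM over a 1D input u of shape (L,).

    Ab: (N, N), Bb: (N,), Cb: (N,). Returns y of shape (L,).
    """
    def step(x_prev, u_k):
        x_k = Ab @ x_prev + Bb * u_k  # x_k = Ab x_{k-1} + Bb u_k
        y_k = Cb @ x_k                # y_k = Cb x_k (a scalar)
        return x_k, y_k

    x0 = jnp.zeros(Ab.shape[0])
    _, y = jax.lax.scan(step, x0, u)
    return y


# H independent SSMs, one per channel: map over the leading (H) axis of each
# parameter and over the channel axis (axis 1) of the (L, H) input.
multi_channel_ssm = jax.vmap(ssm_scan, in_axes=(0, 0, 0, 1), out_axes=1)

H, N, L = 4, 3, 10  # channels, state size, sequence length
k_a, k_b, k_c, k_u = jax.random.split(jax.random.PRNGKey(0), 4)
Ab = 0.3 * jax.random.normal(k_a, (H, N, N))  # illustrative values only
Bb = jax.random.normal(k_b, (H, N))
Cb = jax.random.normal(k_c, (H, N))
u = jax.random.normal(k_u, (L, H))
y = multi_channel_ssm(Ab, Bb, Cb, u)
print(y.shape)  # (10, 4)
```

Keeping `ssm_scan` strictly 1D and adding the channel dimension with `vmap` means the single-channel case can be tested on its own first, then reused unchanged for H channels.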
@yardenas, I saw that this was the case. I am trying to write it as Patrick says. It works for the 1D sequence; I will test the H-dimensional sequence today.
@yardenas In your code, the forward pass of the model in RNN form is the following:
```python
@jax.vmap
def __call__(self, x_k_1, u_k, ssm):
    ab, bb, cb = ssm
    if u_k.ndim...
```
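For reference, the recurrent (RNN) form and the convolutional form of a discrete SSM should give the same output, because unrolling `x_k = Ab x_{k-1} + Bb u_k` from `x = 0` gives `y_k = sum_j (Cb Ab^j Bb) u_{k-j}`. The sketch below checks this with dense matrices. The function names are hypothetical, and this is not the code quoted above.

```python
import jax
import jax.numpy as jnp


def rnn_form(Ab, Bb, Cb, u):
    # Recurrent view: x_k = Ab x_{k-1} + Bb u_k, y_k = Cb x_k, starting from x = 0.
    def step(x_prev, u_k):
        x_k = Ab @ x_prev + Bb * u_k
        return x_k, Cb @ x_k

    _, y = jax.lax.scan(step, jnp.zeros(Ab.shape[0]), u)
    return y


def conv_kernel(Ab, Bb, Cb, L):
    # Convolutional view: K_j = Cb Ab^j Bb for j = 0, ..., L-1.
    def step(v, _):
        return Ab @ v, Cb @ v  # emit Cb (Ab^j Bb), then advance to Ab^{j+1} Bb

    _, K = jax.lax.scan(step, Bb, None, length=L)
    return K


def conv_form(Ab, Bb, Cb, u):
    L = u.shape[0]
    K = conv_kernel(Ab, Bb, Cb, L)
    # Causal convolution y_k = sum_{j<=k} K_j u_{k-j}: first L entries of the full convolution.
    return jnp.convolve(u, K)[:L]


N, L = 4, 16
k_a, k_b, k_c, k_u = jax.random.split(jax.random.PRNGKey(0), 4)
Ab = 0.3 * jax.random.normal(k_a, (N, N))  # illustrative values only
Bb = jax.random.normal(k_b, (N,))
Cb = jax.random.normal(k_c, (N,))
u = jax.random.normal(k_u, (L,))
print(jnp.allclose(rnn_form(Ab, Bb, Cb, u), conv_form(Ab, Bb, Cb, u), atol=1e-4))
```

A check like this is a cheap way to validate an RNN-mode forward pass against a convolution-mode one built from the same `(Ab, Bb, Cb)`.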
```python
from typing import Tuple

import equinox as eqx
from jax import Array


class S4Layer(eqx.Module):
    log_step: Array
    Lambda_re: Array
    Lambda_im: Array
    P: Array
    B: Array
    C: Array
    D: Array
    cell_size: int = eqx.field(static=True)  # N
    sequence_size: int = eqx.field(static=True)  # L
    ssm: Tuple = ()
    ...
```
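Given fields like these, one way to produce the RNN-mode `ssm` tuple is a bilinear discretisation of `A = diag(Lambda) - P P^*` with `step = exp(log_step)`. The function below is a dense O(N^3) sketch under that assumption. The name `discretize_dplr` and the choice to return `C` unchanged are assumptions; some S4 implementations learn a transformed C̃ and correct for it, which is not modelled here. It suits small N in tests, not use as an efficient kernel.

```python
import jax
import jax.numpy as jnp


def discretize_dplr(log_step, Lambda_re, Lambda_im, P, B, C):
    """Bilinear discretisation of A = diag(Lambda) - P P^* (dense, O(N^3)).

    log_step: scalar; Lambda_re, Lambda_im: (N,) real; P, B, C: (N,) complex.
    Returns (Ab, Bb, Cb) for x_k = Ab x_{k-1} + Bb u_k, y_k = Re(Cb x_k).
    """
    step = jnp.exp(log_step)
    Lambda = Lambda_re + 1j * Lambda_im
    A = jnp.diag(Lambda) - jnp.outer(P, P.conj())  # diagonal plus low-rank
    I = jnp.eye(A.shape[0])
    BL = jnp.linalg.inv(I - (step / 2.0) * A)      # (I - step/2 A)^{-1}
    Ab = BL @ (I + (step / 2.0) * A)               # (I - step/2 A)^{-1} (I + step/2 A)
    Bb = (BL * step) @ B                           # (I - step/2 A)^{-1} step B
    return Ab, Bb, C                               # C passed through (assumption)


N = 4
k_p, k_b, k_c = jax.random.split(jax.random.PRNGKey(1), 3)
Lambda_re = -0.5 * jnp.ones(N)
Lambda_im = jnp.arange(N, dtype=jnp.float32)
P = jax.random.normal(k_p, (N,)) + 0j
B = jax.random.normal(k_b, (N,)) + 0j
C = jax.random.normal(k_c, (N,)) + 0j
Ab, Bb, Cb = discretize_dplr(jnp.log(0.1), Lambda_re, Lambda_im, P, B, C)
print(Ab.shape, Bb.shape)  # (4, 4) (4,)
```

`D` does not enter the discretisation; it acts as a skip connection, so the layer output would be `y_k + D * u_k`.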
Hey @lockwo, thanks for the comment. I have seen the PPO algo that Patrick wrote in the past as well as other implementations like CleanRL. I reported this as I...
> One possible difference between the implementations may be that Equinox and Flax use different initialisations for the same seed (e.g. sampling from a random vs a uniform distribution).

I...
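To make the initialisation point concrete: at the time of writing, `flax.linen.Dense` defaults to a LeCun-normal kernel and a zero bias, while `equinox.nn.Linear` samples both weight and bias uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)] (worth confirming against the installed versions). The pure-JAX helpers below only mimic those defaults; they are hypothetical and not part of either library.

```python
import jax
import jax.numpy as jnp


def equinox_style_linear_init(key, in_features, out_features):
    # Mimics a uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)] for weight and bias.
    wkey, bkey = jax.random.split(key)
    lim = 1.0 / jnp.sqrt(in_features)
    W = jax.random.uniform(wkey, (out_features, in_features), minval=-lim, maxval=lim)
    b = jax.random.uniform(bkey, (out_features,), minval=-lim, maxval=lim)
    return W, b


def flax_style_linear_init(key, in_features, out_features):
    # Mimics a LeCun-normal kernel (laid out as (in, out)) and a zero bias.
    W = jax.nn.initializers.lecun_normal()(key, (in_features, out_features))
    b = jnp.zeros((out_features,))
    return W.T, b  # transpose to the (out, in) layout used above


key = jax.random.PRNGKey(0)
W_eqx, b_eqx = equinox_style_linear_init(key, 8, 3)
W_flax, b_flax = flax_style_linear_init(key, 8, 3)
print(jnp.allclose(W_eqx, W_flax))  # the same seed does not give the same weights
```

Since identical seeds do not imply identical networks, copying one set of initial parameters into both models is a straightforward way to rule this difference out when comparing implementations.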
I tried the model written for a single observation, with vmap applied outside. I get the same performance as before (I mean **exactly** the same, down to the last decimal point). The model is...
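A small sketch of the two designs being compared, using a hypothetical two-layer MLP: one version is written for a single observation and vmapped over the batch, the other handles the batch axis itself. They should agree up to floating-point round-off (often bit-for-bit, though that is not guaranteed).

```python
import jax
import jax.numpy as jnp


def mlp_single(params, x):
    # Hypothetical model written for ONE observation x of shape (in_features,).
    W1, b1, W2, b2 = params
    h = jnp.tanh(W1 @ x + b1)
    return W2 @ h + b2


def mlp_batched(params, X):
    # The same model written directly for a batch X of shape (batch, in_features).
    W1, b1, W2, b2 = params
    h = jnp.tanh(X @ W1.T + b1)
    return h @ W2.T + b2


# "vmap outside": share params (in_axes=None) and map over axis 0 of the batch.
mlp_vmapped = jax.vmap(mlp_single, in_axes=(None, 0))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (
    jax.random.normal(k1, (16, 4)), jnp.zeros(16),
    jax.random.normal(k2, (2, 16)), jnp.zeros(2),
)
X = jax.random.normal(k3, (5, 4))
print(mlp_vmapped(params, X).shape)  # (5, 2)
print(jnp.max(jnp.abs(mlp_vmapped(params, X) - mlp_batched(params, X))))
```

Writing the model for one observation keeps shape logic simple and leaves batching to `vmap`, which is why it is mostly a design choice rather than something expected to change the numbers.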
> That is what I meant. It was mostly a design pattern note, I wouldn’t expect any numerical differences

Yeah, I was just paranoid and checked everything. I managed to...