Mahmoud Asem
Hi @mgoulao, I made a JAX [package](https://github.com/ASEM000/kernex) for stencil computations that can be used to compute a Gaussian blur. I compared the performance of `kernex` vs. `dm_pix` on [Colab CPU](https://github.com/ASEM000/kernex)...
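For context, a stencil computation applies the same function to every fixed-size window of an array. Below is a minimal pure-JAX sketch of the idea. It is not `kernex`'s API: `stencil_map` is a hypothetical helper that only handles "valid" windows (no padding) on 2D arrays, built from `jax.lax.dynamic_slice` and nested `jax.vmap`.

```python
import jax
import jax.numpy as jnp


def stencil_map(func, array, size):
    """Apply `func` to every (size x size) window of a 2D array.

    Hypothetical helper: 'valid' windows only (no padding), and `size`
    must be a Python int so that the window shape is static.
    """
    h, w = array.shape
    rows = jnp.arange(h - size + 1)
    cols = jnp.arange(w - size + 1)

    def apply_at(i, j):
        # Extract the window whose top-left corner is (i, j)
        window = jax.lax.dynamic_slice(array, (i, j), (size, size))
        return func(window)

    # Vectorize over all column offsets, then over all row offsets
    return jax.vmap(lambda i: jax.vmap(lambda j: apply_at(i, j))(cols))(rows)


# Example: a 3x3 box (mean) filter on a 5x5 array gives a 3x3 result
image = jnp.arange(25.0).reshape(5, 5)
box = stencil_map(jnp.mean, image, 3)
```

A Gaussian blur is the special case where `func` is a weighted sum of the window with a Gaussian kernel, e.g. `lambda win: jnp.sum(win * kernel2d)`.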
Noted, thanks. I will try to contribute in the coming days. Best.
Hey, I implemented `dm_pix.gaussian_blur` with no extra dependencies in this [Colab](https://colab.research.google.com/drive/119tQIZPmq-XznFgSavtb7e13HfIU70Ju?usp=sharing), where you can find tests and benchmarks against the depthwise-convolution-based implementation. On Colab CPU I'm getting the following speedup...
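As an illustration of a dependency-free Gaussian blur, here is a minimal sketch assuming an (H, W, C) image and zero padding at the borders. It uses a separable pair of 1D convolutions (`jnp.convolve` applied along each spatial axis); this is not necessarily how the linked Colab or `dm_pix` implement it, and the names `gaussian_kernel_1d` and `gaussian_blur` are hypothetical.

```python
import jax.numpy as jnp


def gaussian_kernel_1d(sigma: float, radius: int) -> jnp.ndarray:
    """Normalized 1D Gaussian kernel of length 2 * radius + 1."""
    x = jnp.arange(-radius, radius + 1, dtype=jnp.float32)
    kernel = jnp.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()


def gaussian_blur(image: jnp.ndarray, sigma: float, radius: int) -> jnp.ndarray:
    """Blur an (H, W, C) image with a separable Gaussian and zero-padded borders.

    Assumes H and W are both at least 2 * radius + 1, so 'same' mode keeps the shape.
    """
    kernel = gaussian_kernel_1d(sigma, radius)

    def blur_1d(v):
        # 'same' mode returns an output with the same length as `v`
        return jnp.convolve(v, kernel, mode="same")

    # A 2D Gaussian is separable: blur along H (axis 0), then along W (axis 1)
    out = jnp.apply_along_axis(blur_1d, 0, image)
    return jnp.apply_along_axis(blur_1d, 1, out)
```

Separability is the design choice that matters here: each 1D pass costs 2r + 1 multiplies per pixel instead of (2r + 1)^2 for a full 2D kernel. To compile it, `radius` has to be static, e.g. `jax.jit(gaussian_blur, static_argnames="radius")`.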
Yes, you are right; sorry for the typo.
Hello Adam, thank you for your question. For background, there are a couple of libraries with similar ideas (a PyTorch-like API) that predate `Treex` and `Equinox`, and as you have seen,...
~~Except for `flax.struct`, I think most Pytree libraries should behave similarly in terms of memory and speed. `PytreeClass` is slightly faster because no logic (for static fields) is executed when flattening/unflattening.~~ Check the README for benchmarks...
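To make the Pytree-based, PyTorch-like idea concrete, here is a minimal sketch using only `jax.tree_util.register_pytree_node_class`. The `Linear` class is hypothetical and not taken from any of the libraries mentioned: arrays are the leaves, while configuration such as the activation name is kept static in `aux_data`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    """Hypothetical PyTorch-like layer that is also a JAX pytree."""

    def __init__(self, weight, bias, activation="relu"):
        self.weight = weight  # leaf
        self.bias = bias  # leaf
        self.activation = activation  # static field, kept in aux_data

    def tree_flatten(self):
        children = (self.weight, self.bias)
        aux_data = (self.activation,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)

    def __call__(self, x):
        y = x @ self.weight + self.bias
        return jax.nn.relu(y) if self.activation == "relu" else y


layer = Linear(jnp.ones((3, 2)), jnp.zeros(2))
x = jnp.ones((4, 3))
# The module passes straight through jax.grad; the gradient is itself a Linear
grads = jax.grad(lambda m, x: m(x).sum())(layer, x)
```

Because the static field lives in `aux_data`, it must be hashable for `jax.jit` caching; a string or tuple of strings is fine.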
For reference:

- [1] Pytree-based implementations: one that predates `equinox`/`treex`, [flax PyTreeNode](https://github.com/google/flax/commit/291a5f65549cf4522f0de033451cd83c0d0168d9), and another that postdates it, [pax](https://github.com/NTT123/pax)
- [2] equinox `tree_at`: [sample issue](https://github.com/patrick-kidger/equinox/issues/194)
- [3] `filter` inconsistent behavior, sample issues: [1](https://github.com/patrick-kidger/equinox/issues/189), [2](https://github.com/patrick-kidger/equinox/issues/318)
TL;DR: as of version `0.8.0`, use:

```python
import pytreeclass as pytc
import jax


@pytc.autoinit
class Tree(pytc.TreeClass):
    # Frozen on assignment, unfrozen automatically on attribute access
    frozen_a: int = pytc.field(on_getattr=[pytc.unfreeze], on_setattr=[pytc.freeze])

    def __call__(self, x):
        return self.frozen_a...
```
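To illustrate the freeze-on-set / unfreeze-on-get idea behind `on_setattr=[pytc.freeze]` and `on_getattr=[pytc.unfreeze]` without depending on `pytreeclass`, here is a minimal pure-JAX sketch. `Frozen`, `FrozenField`, and this `Tree` are hypothetical names and not the library's implementation. The frozen value is stored as static `aux_data`, so it contributes no leaves and is ignored by `jax.grad`; it must be hashable (e.g. a Python scalar) for `jax.jit` caching.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Frozen:
    """Hypothetical frozen wrapper: the value lives in static aux_data."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), self.value  # no children, so no leaves

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


class FrozenField:
    """Hypothetical descriptor: freeze on set, unfreeze on get."""

    def __set_name__(self, owner, name):
        self.private_name = "_" + name

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        return getattr(obj, self.private_name).value  # unfreeze on read

    def __set__(self, obj, value):
        setattr(obj, self.private_name, Frozen(value))  # freeze on write


@register_pytree_node_class
class Tree:
    frozen_a = FrozenField()

    def __init__(self, frozen_a, b):
        self.frozen_a = frozen_a  # stored internally as Frozen(frozen_a)
        self.b = b

    def tree_flatten(self):
        # The Frozen child is a subtree with no leaves; only `b` is traced
        return (self._frozen_a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)
        obj._frozen_a, obj.b = children  # bypass the descriptor: already frozen
        return obj

    def __call__(self, x):
        return self.frozen_a + self.b * x


tree = Tree(frozen_a=2.0, b=jnp.array(3.0))
```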
For a deeply nested instance with frozen attributes all over the place, you only need to write it once (usually inside your loss function), something like this:

```python
from typing...
```
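You are right; fortunately, it's easy to do just that.

```python
import functools as ft

import jax
import pytreeclass as pytc


def unfreeze_func(func):
    # Unfreeze every frozen leaf of `tree` before calling `func`
    @ft.wraps(func)
    def wrapper(tree, *a, **k):
        tree = jax.tree_util.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)
        return func(tree, *a, **k)

    return wrapper


@jax.jit...
```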
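An end-to-end sketch of the "unfreeze once inside the loss function" pattern follows. So that it runs without `pytreeclass`, a hypothetical pure-JAX `Frozen` wrapper with `freeze`/`unfreeze`/`is_frozen` helpers stands in for `pytc.freeze`/`pytc.unfreeze`/`pytc.is_frozen`; it is an assumption for illustration, not the library's implementation. Frozen entries are skipped when `jax.grad` differentiates the nested dict, and the decorator hands the loss plain values.

```python
import functools as ft

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Frozen:
    """Hypothetical stand-in for a frozen leaf: the value is static aux_data."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), self.value  # no children, so nothing to trace

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


def freeze(x):
    return Frozen(x)


def is_frozen(x):
    return isinstance(x, Frozen)


def unfreeze(x):
    return x.value if is_frozen(x) else x


def unfreeze_func(func):
    # Unfreeze every frozen node of the first argument before calling `func`
    @ft.wraps(func)
    def wrapper(tree, *a, **k):
        tree = jax.tree_util.tree_map(unfreeze, tree, is_leaf=is_frozen)
        return func(tree, *a, **k)

    return wrapper


@unfreeze_func
def loss(params, x):
    # Here `scale` is a plain Python float again
    return params["layer"]["scale"] * params["w"] * x + params["layer"]["b"]


grad_loss = jax.jit(jax.grad(loss))

params = {"w": jnp.array(2.0), "layer": {"scale": freeze(3.0), "b": jnp.array(1.0)}}
# No gradient is computed for "scale"; it comes back still wrapped in Frozen
grads = grad_loss(params, 2.0)
```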