Storing CoLA Linear Operators as a Field in a Module
A new linear operator library, CoLA, was recently released (https://cola.readthedocs.io/en/latest/index.html). It supports lazy evaluation like this:

```python
lazy_A = cola.fns.lazify(A)
```

Here `lazy_A` is a `cola.ops.LinearOperator`. Overall the library is incredibly powerful, and I expect it to become a core tool in the JAX ecosystem.
Is it possible to store such an operator as a field in an Equinox Module? The following minimal example raises an error when calling `eqx.filter`:
```python
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import cola


class MyModule(eqx.Module):
    # CoLA lazy operator stored as a Module field
    lazy_A: cola.ops.LinearOperator

    def __init__(self, A):
        self.lazy_A = cola.fns.lazify(A)

    def __call__(self, x):
        return self.lazy_A @ x


seed = jr.PRNGKey(0)
A = jr.normal(seed, (10, 10))
X = jnp.ones((10, 1))

model = MyModule(A)
# Raises AttributeError: 'bool' object has no attribute 'dtype'
result = eqx.filter(model, eqx.is_inexact_array)
```
```
  File "/media/adam/shared_drive/PycharmProjects/test_equinox_lazy_variable/test_equinox_lazy_variable.py", line 24, in <module>
    result = eqx.filter(model, eqx.is_inexact_array)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_filters.py", line 129, in filter
    filter_tree = jtu.tree_map(_make_filter_tree(is_leaf), filter_spec, pytree)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 252, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 252, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_filters.py", line 71, in _filter_tree
    return jtu.tree_map(mask, arg, is_leaf=is_leaf)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 252, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 245, in tree_unflatten
    return cls(*new_args, **aux[0])
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 21, in __init__
    super().__init__(dtype=A.dtype, shape=A.shape)
AttributeError: 'bool' object has no attribute 'dtype'
```
It looks like the problem is the type of the lazy variable, which is `<10x10 Dense with dtype=float32>`.
Is there an easy way around this, so that Equinox can be used with CoLA lazified variables?
Hey there. This looks like a bug in CoLA. The issue is that they're calling `__init__` inside the unflatten rule of their pytrees: https://github.com/wilson-labs/cola/blob/50d61041b94b0f092e89cc05169538b6e9caeeba/cola/ops/operator_base.py#L250
`eqx.filter` maps over the pytree to build a mask with `True`/`False` leaves, so that unflatten rule ends up passing a `bool` to the operator's `__init__`, which then fails on `A.dtype`.
This is a common mistake when implementing custom pytrees in JAX; see the documentation here: https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
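The failure can be reproduced without CoLA in a few lines of plain JAX. `NaiveDenseOp` below is a hypothetical stand-in for CoLA's `Dense` operator, not CoLA's actual code: it only mimics the pattern described above, where the unflatten rule calls `__init__` and `__init__` reads `A.dtype`. The boolean `tree_map` plays the role of the mask that `eqx.filter` builds.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


# Hypothetical class mimicking the problematic pattern (not CoLA's real code).
@jtu.register_pytree_node_class
class NaiveDenseOp:
    def __init__(self, A):
        self.A = A
        # Inspecting the leaf here breaks as soon as a leaf is not an array.
        self.dtype = A.dtype
        self.shape = A.shape

    def tree_flatten(self):
        return (self.A,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        # Anti-pattern: re-runs __init__ on whatever the leaves happen to be.
        return cls(*children)


op = NaiveDenseOp(jnp.ones((3, 3)))

# Mapping arrays to arrays works: __init__ still receives an array.
doubled = jtu.tree_map(lambda x: 2 * x, op)

# Mapping arrays to bools (like eqx.filter's mask) makes __init__ receive a bool.
try:
    jtu.tree_map(lambda x: isinstance(x, jax.Array), op)
    failed = False
except AttributeError as err:
    failed = True
    print("AttributeError:", err)  # 'bool' object has no attribute 'dtype'
```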
I'd recommend reporting the bug to CoLA.
They could fix this by either (a) having their `LinearOperator` inherit from `eqx.Module`, or (b) unflattening via `__new__` and `__setattr__`; see the Equinox implementation for reference.
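Option (b) can be sketched in plain JAX as follows. `FixedDenseOp` is again a hypothetical stand-in, not CoLA's code and not Equinox's exact implementation: the static metadata travels in the pytree's aux data, and `tree_unflatten` allocates the instance with `__new__` and assigns attributes with `object.__setattr__` instead of calling `__init__`, so it accepts whatever leaves JAX hands it. `object.__setattr__` matters for frozen classes such as Equinox Modules; on an ordinary class plain assignment would also work. Option (a) is not shown here.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


# Hypothetical operator showing option (b): unflatten without calling __init__.
@jtu.register_pytree_node_class
class FixedDenseOp:
    def __init__(self, A):
        self.A = A
        self.dtype = A.dtype
        self.shape = A.shape

    def tree_flatten(self):
        # Static metadata goes in aux_data (must be hashable: a dtype and a
        # shape tuple are), so unflatten never recomputes it from the leaves.
        return (self.A,), (self.dtype, self.shape)

    @classmethod
    def tree_unflatten(cls, aux, children):
        # Allocate the instance directly and set attributes, skipping __init__.
        obj = cls.__new__(cls)
        (A,) = children
        dtype, shape = aux
        object.__setattr__(obj, "A", A)
        object.__setattr__(obj, "dtype", dtype)
        object.__setattr__(obj, "shape", shape)
        return obj


op = FixedDenseOp(jnp.ones((3, 3)))

# A boolean mask (as eqx.filter builds) now unflattens without error.
mask = jtu.tree_map(lambda x: isinstance(x, jax.Array), op)

# Ordinary array-to-array maps keep working and preserve the metadata.
doubled = jtu.tree_map(lambda x: 2 * x, op)
```

The trade-off is that any validation in `__init__` no longer runs on unflatten, which is exactly what lets JAX and Equinox pass placeholder leaves such as bools through the structure.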
Incidentally, it's interesting to see another JAX-compatible library for linear operators. We also have Lineax, which is part of the JAX+Equinox ecosystem and provides both linear operators and linear solves. The two libraries appear to have substantially overlapping (but not identical) functionality: for example, Lineax supports pseudoinverses, whereas CoLA supports eigendecompositions.
Thanks for the fast response. I will report it on their github.