pytreeclass
Frozen / Static leaves
Reading the documentation, I understand that you can freeze variables by using a mask based on name or type.
Is it possible to set a variable to "frozen" within the class definition itself, i.e. in the way Equinox has a static_field option?
While I understand the concept behind being able to mask out a set of variables contained in a PyTree (or a PyTree of PyTrees), there are many situations where you know, when creating a new class, that certain variables will only ever be constant. Furthermore, as models become more complicated (or if others use elements of your model), it becomes more cumbersome to have to mask these out, and others have to know to do this.
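For reference, the mask-based approach I mean looks roughly like this (a minimal sketch that masks every int leaf by type, using the pytc.freeze helper discussed below):

import jax
import pytreeclass as pytc

# sketch: freeze every int leaf by type, hiding it from jax transformations
tree = {"a": 1, "b": 2.0}
masked = jax.tree_map(lambda x: pytc.freeze(x) if isinstance(x, int) else x, tree)
print(jax.tree_util.tree_leaves(masked))  # [2.0] -- the int leaf is masked out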
As of version 0.8.0
TL;DR: use
import pytreeclass as pytc
import jax

@pytc.autoinit
class Tree(pytc.TreeClass):
    frozen_a: int = pytc.field(on_getattr=[pytc.unfreeze], on_setattr=[pytc.freeze])

    def __call__(self, x):
        return self.frozen_a + x

tree = Tree(frozen_a=1)  # 1 is a non-jax type

# the tree can be used in jax transformations
@jax.jit
def f(tree, x):
    return tree(x)

print(f(tree, 1.0))  # 2.0
print(jax.grad(f)(tree, 1.0))  # Tree(frozen_a=#1)
print(jax.tree_util.tree_leaves(tree))  # []
More details about the freezing/unfreezing mechanism:
If you prefer manual masking, you can apply pytc.freeze to the value directly, but then you have to pass is_leaf=pytc.is_frozen whenever you want to interact with that value using tree_map. Using this style, the end user only has to unmask before calling, while still having access to the masked values via is_leaf=pytc.is_frozen.
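For example (a minimal sketch of how a frozen value interacts with tree_map):

import jax
import pytreeclass as pytc

frozen = pytc.freeze(1)
# without is_leaf, the frozen value yields no leaves and is invisible to tree_map
print(jax.tree_util.tree_leaves([frozen, 2.0]))  # [2.0]
# with is_leaf=pytc.is_frozen, tree_map can reach it and unfreeze it
print(jax.tree_map(pytc.unfreeze, [frozen, 2.0], is_leaf=pytc.is_frozen))  # [1, 2.0]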
You can do something like this:
Style 1: no __init__ body, using callbacks
Here, callbacks is a list of functions applied to your in_features before it is set on the instance.
import pytreeclass as pytc

class Tree(pytc.TreeClass):
    in_features: int = pytc.field(callbacks=[pytc.freeze])
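Assuming the field-based __init__ that TreeClass generates here, usage would look something like:

t = Tree(2)
print(t)  # Tree(in_features=#2) -- `#` marks the frozen value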
Style 2: with an __init__ body

import jax
import pytreeclass as pytc

class Tree(pytc.TreeClass):
    def __init__(self, in_features: int):
        # some logic using in_features
        # ...
        # lastly, freeze it
        self.in_features = pytc.freeze(in_features)

    def __call__(self, x: float):
        return x * self.in_features

t1 = Tree(2)

@jax.value_and_grad
def jax_func(tree: Tree):
    tree = jax.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)
    return tree(1.0)

jax_func(t1)
# (2.0, Tree(in_features=#2))  # `#` is the frozen marker
For background, an earlier version of pytreeclass had static_field-like behaviour, but this has three problems:
1. Even if these fields are constants, with static_field you lose the ability to filter your models based on that always-non-trainable field using jax.tree_map.
2. .at uses jax.tree_map under the hood, so letting the user designate a permanently static field leads to an asymmetric design. For example, if you select a as a static field of model nn, then nn.a will work while nn.at['a'].get() will not work at all.
3. static_field leads to repetitive code, because you have to declare each field twice: once as a field and once inside the __init__ body. Something like this (from the equinox Conv code):
class Conv(Module):
    """General N-dimensional convolution."""

    num_spatial_dims: int = static_field()
    weight: Array
    bias: Optional[Array]
    in_channels: int = static_field()
    out_channels: int = static_field()
    kernel_size: Tuple[int, ...] = static_field()
    stride: Tuple[int, ...] = static_field()
    padding: Tuple[Tuple[int, int], ...] = static_field()
    dilation: Tuple[int, ...] = static_field()
    groups: int = static_field()
    use_bias: bool = static_field()

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[int, Sequence[int], Sequence[Tuple[int, int]]] = 0,
        dilation: Union[int, Sequence[int]] = 1,
        groups: int = 1,
        use_bias: bool = True,
        *,
        key: PRNGKey,
        **kwargs,
    ):
This gets worse as you write more and more code.
Lastly, the output of pytc.freeze is just a pytree that yields no leaves during the flattening rule, so you can use pytc.freeze on any pytree (no special treatment inside a TreeClass).
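For instance (a minimal sketch with a plain dict, no TreeClass involved):

import jax
import pytreeclass as pytc

# the frozen wrapper works in any pytree, e.g. a plain dict
plain = {"lr": pytc.freeze(1e-3), "w": 1.0}
print(jax.tree_util.tree_leaves(plain))  # [1.0] -- the frozen leaf yields nothing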
This design eliminates static field logic during the flattening/unflattening of a tree, leading to faster flattening/unflattening for non-masked trees and simplifying the code. Let me know if this answers your question.
So if I am understanding this correctly, tree = jax.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen) needs to be called prior to any call to a pytc.TreeClass containing static/frozen variables. So if I have a class within a class within a class, all containing frozen variables (or a class that contains numerous other classes which use frozen variables), every call to the methods of that class must be preceded by this unfreezing.
That doesn't seem ideal once models get much more complicated, or if you wish to build a library of functions (as every class would need to be "wrapped" to hide this from the user).
For a deeply nested instance with frozen attributes all over the place, you only need to write it once (usually inside your loss function), something like this:
import jax
import pytreeclass as pytc

class A(pytc.TreeClass):
    a: int = pytc.freeze(1)
    b: float = 2.0

    def __call__(self, x):
        return self.a * x + self.b

class B(pytc.TreeClass):
    c: int = pytc.freeze(1)
    d: A = A()

    def __call__(self, x):
        return self.c * x + self.d(x)

b = B()
# B(c=#1, d=A(a=#1, b=2))

@jax.jit
@jax.value_and_grad
def loss_func(b: B):
    # unfreeze once at the top of the loss function
    b = jax.tree_map(pytc.unfreeze, b, is_leaf=pytc.is_frozen)
    return b(1.0)

loss_func(b)
# (Array(4., dtype=float32, weak_type=True),
#  B(c=#1, d=A(a=#1, b=f32[](μ=1.00, σ=0.00, ∈[1.00,1.00]))))
For comparison, under the hood, equinox filter-decorated functions do something similar in two steps: first equinox splits the tree into trainable/non-trainable parts before the JAX boundary, then combines them inside the jax function on each call. The pytreeclass scheme should be faster because you only do one step.
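Roughly, the two steps look like this (a simplified sketch using equinox's partition/combine helpers, not the actual filter_jit internals):

import equinox as eqx
import jax.numpy as jnp

tree = {"w": jnp.array(1.0), "hparam": 42}
# step 1: split into array (dynamic) and non-array (static) parts at the jax boundary
dynamic, static = eqx.partition(tree, eqx.is_array)
# dynamic: {"w": Array(1.0), "hparam": None}; static: {"w": None, "hparam": 42}
# step 2: recombine inside the jax-transformed function on each call
restored = eqx.combine(dynamic, static)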
import equinox as eqx
import jax
import jax.numpy as jnp
import pytreeclass as pytc

class TreeEqx(eqx.Module):
    a: int = eqx.static_field(default=1)
    b: jax.Array = jnp.array(1.0)

class TreePyTC(pytc.TreeClass):
    a: int = pytc.freeze(1)
    b: jax.Array = jnp.array(1.0)

tree = TreePyTC()

@jax.jit
def some_func(t):
    t = jax.tree_map(pytc.unfreeze, t, is_leaf=pytc.is_frozen)
    return t.a + t.b

%timeit some_func(tree)
# 12.1 µs ± 836 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

tree = TreeEqx()

@eqx.filter_jit
def some_func(t):
    return t.a + t.b

%timeit some_func(tree)
# 26.7 µs ± 5.86 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Let me know if you have any questions.
Ah OK, I understand. That is obviously more manageable.
I wonder if adding a wrapper function/decorator to hide this from users might be useful? The equinox decorated functions are very useful in this respect of hiding the complexity away.
I can see cases where somebody might want to use your model and try a different loss function, or incorporate your model/NN into a pipeline with others, without realising this behaviour. The ability to wrap your model so that another user doesn't even need to think about jax.tree_map(pytc.unfreeze, b, is_leaf=pytc.is_frozen) might prove helpful in preventing obvious mistakes.
You are right; fortunately, it's easy to do just that.
import functools as ft

import jax
import pytreeclass as pytc

def unfreeze_func(func):
    @ft.wraps(func)
    def wrapper(tree, *a, **k):
        # unfreeze all frozen leaves before handing the tree to func
        tree = jax.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)
        return func(tree, *a, **k)
    return wrapper

@jax.jit
@jax.value_and_grad
@unfreeze_func
def loss_func(b: B):
    return b(1.0)
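Using the nested B instance from earlier, the decorated loss can then be called directly on a tree with frozen leaves:

loss_func(b)
# (Array(4., dtype=float32, weak_type=True),
#  B(c=#1, d=A(a=#1, b=f32[](μ=1.00, σ=0.00, ∈[1.00,1.00]))))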