equinox
Why uniform bias initializer?
I noticed that the bias parameters in linear and conv modules use a uniform initializer. Is there a good justification for this? I noticed that PyTorch does this too. I was expecting a zero initializer, which is also what Flax uses.
The justification is really just "that is what PyTorch does". (I copied the initialisers over when making Equinox.)
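For reference, the PyTorch scheme in question draws both the weight and the bias of a linear layer from U(-1/sqrt(in_features), 1/sqrt(in_features)). A minimal sketch in plain JAX, assuming that bound; `torch_style_linear_init` is a hypothetical name, not an Equinox function:

```python
import math

import jax
import jax.numpy as jnp


def torch_style_linear_init(key, in_features, out_features):
    # Hypothetical helper: PyTorch-style init with bound = 1 / sqrt(fan_in),
    # applied to both the weight and the bias.
    bound = 1.0 / math.sqrt(in_features)
    wkey, bkey = jax.random.split(key)
    weight = jax.random.uniform(
        wkey, (out_features, in_features), minval=-bound, maxval=bound
    )
    bias = jax.random.uniform(bkey, (out_features,), minval=-bound, maxval=bound)
    return weight, bias


weight, bias = torch_style_linear_init(jax.random.PRNGKey(0), in_features=16, out_features=4)
```

Switching to a Flax-style zero bias would amount to replacing the second draw with `jnp.zeros((out_features,))`.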
If you want a zero bias then this can be done by replacing the leaf:
import equinox as eqx
import jax.numpy as jnp

linear = eqx.nn.Linear(...)
linear = eqx.tree_at(lambda l: l.bias, linear, jnp.zeros_like(linear.bias))
Maybe this deserves a separate issue / feature request, but as a first-time user (coming from Haiku) it was non-obvious to me how to override the default initialization scheme in nn.Linear.
I think it would be cool if there were a w_init argument (similar to Haiku), or perhaps just a comment in the docs explaining the most idiomatic way to do this.
For what it's worth, I find the eqx.tree_at approach to be a bit "too clever": it's not immediately obvious to a new user how it works. Maybe that's just me, though.
As an aside, I am really enjoying Equinox so far and appreciate all the work you have put into it!
Thanks, I'm glad you like it!
Regarding adding this to the docs: yep, this is on the horizon (#185).
Regarding making this an additional argument: so far I've been resisting this as it adds quite a lot of complexity. It would add a lot of arguments that I think most folks don't use (PyTorch doesn't have this), and it would necessitate adding a new eqx.nn.init namespace to hold the existing initialisers. Both of these clutter the documentation.
But that isn't a strong feeling on my part. This request does come up every now and again, and I'd be happy to be persuaded otherwise.
(cc @typedfemale as we were talking about this recently)
Is there a good way to change all bias initializers to zero?
That is a fair point!
It is true that PyTorch doesn't take a w_init, but because everything is mutable in PyTorch it's very easy to just set linear.bias = ... (or modify it in place).
But I guess this is just teething pains of adjusting to immutability / PyTrees.
I will give this more thought as I adjust to the eqx.tree_at model-surgery approach!
Is there a good way to change all bias initializers to zero?
+1 for this. Right now I have a hacky method relying on a recursive getattr to first get a sequence of target attribute strings, and then pass them to eqx.tree_at for replacement. I would like to know of a cleaner method to achieve this.
So I can think of two non-hacky methods for setting all biases to zero.
Option 1:
import equinox as eqx
import jax.numpy as jnp

def linear(*args, **kwargs):
    out = eqx.nn.Linear(*args, **kwargs)
    out = eqx.tree_at(lambda l: l.bias, out, replace_fn=jnp.zeros_like)
    return out
and then using linear (or an analogous conv2d etc.) everywhere you were previously using eqx.nn.Linear.
Option 2:
import equinox as eqx
import jax
import jax.numpy as jnp

model = ...
has_bias = lambda x: hasattr(x, "bias")
where = lambda m: [x.bias for x in jax.tree_util.tree_leaves(m, is_leaf=has_bias) if has_bias(x)]
model = eqx.tree_at(where, model, replace_fn=jnp.zeros_like)
FWIW I think I am hearing consensus in favour of this proposal! So perhaps let's just add it :D
The idea would be to:
- Add an eqx.nn.init module with all the current initialisers.
- Add extra arguments weight_init=..., bias_init=... to eqx.nn.{Linear, Conv, MultiheadAttention, ...}.
- Make sure that any of the "more complicated" layers like MLP and MultiheadAttention forward such arguments to their sublayers.
- Add tests.
- Add the eqx.nn.init namespace to the documentation.
This is a fair amount of work so this probably isn't something I'll find time to implement in the near future. If anyone involved in this thread feels suitably motivated, then I'd be very happy to accept a PR on this.
I'm not pushing to add init args. I agree that it adds unnecessary complexity (and work :sweat_smile:)
Option 2 covers my use cases with a minor modification: if has_bias(x) and x.bias is not None. Thanks!
Does either option work with nn.MLP or would I need to do more surgery? It's unclear to me how I would use tree_at to access each linear layer but maybe my pytree skills are just weak...
Approach 1: the above "Option 2" should work out of the box: it will detect each linear layer, as these all have a bias attribute.
Approach 2: if you wanted, you could also swap out has_bias for is_linear = lambda x: isinstance(x, eqx.nn.Linear), and then do exactly as before. Since eqx.nn.MLP uses linear layers internally, and these are the only layers in it with biases, this will work equally well.
Approach 3: if you're working with an MLP specifically then you could also just list out all its linear layers manually:
import equinox as eqx
import jax.numpy as jnp

mlp = eqx.nn.MLP(...)
where = lambda m: [lin.bias for lin in m.layers]
mlp = eqx.tree_at(where, mlp, replace_fn=jnp.zeros_like)
Any one of these approaches is equally fine.
Ultimately I think this kind of model surgery is one of the greatest strengths of Equinox. It takes a little getting used to, but variations on the above pattern allow you to perform almost any kind of adjustment to your model.
I've been trying to experiment with different initialization schemes for MLPs and came across this issue. Is there a simple way to use tree_at to apply e.g. LeCun initialization to all weights and biases?
Something like:
import math

import equinox as eqx
import jax
import jax.random as jr
import jax.tree_util as jtu
from jaxtyping import Array, Float

key = jr.PRNGKey(...)
model = ...  # your model

def lecun_init(weight: Float[Array, "out in"], key: jax.Array) -> Float[Array, "out in"]:
    out, in_ = weight.shape
    stddev = math.sqrt(1 / in_)
    # Pass `shape`, otherwise truncated_normal returns a single scalar.
    return stddev * jr.truncated_normal(key, lower=-2, upper=2, shape=weight.shape, dtype=weight.dtype)

is_linear = lambda x: isinstance(x, eqx.nn.Linear)
get_weights = lambda m: [x.weight for x in jtu.tree_leaves(m, is_leaf=is_linear) if is_linear(x)]
weights = get_weights(model)
new_weights = [lecun_init(weight, subkey) for weight, subkey in zip(weights, jr.split(key, len(weights)))]
new_model = eqx.tree_at(get_weights, model, new_weights)
(The jaxtyping annotations are just a nice-to-have, they don't affect runtime.)
Add extra arguments weight_init=..., bias_init=... to eqx.nn.{Linear, Conv, MultiheadAttention, ...}.
May I suggest instead having a single init dict argument? Then passing initializers to nested sublayers could follow the same hierarchy as the sublayers themselves.
This should have been in #622 but I put it here as this issue is still open.
In the JAX world, it is all about transformations, and Equinox already follows the model = some_map(model) pattern. So layer initialization could be a transformation too, e.g. a proposed layer = eqx.init(layer, key, weight=default_init, bias=default_init):
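import functools

import equinox as eqx
import jax.random as jrandom
from equinox import nn

def init(pytree, key, **kwargs):
    for attr, init_method in kwargs.items():
        key, init_key = jrandom.split(key)
        pytree = eqx.tree_at(
            lambda t: getattr(t, attr),
            pytree,
            replace_fn=functools.partial(init_method, key=init_key),
            is_leaf=lambda x: x is None,  # lets a None bias be selected and replaced
        )
    return pytree

class MyModule(eqx.Module):
    proj: nn.Linear

    def __init__(self, in_feat, out_feat, *, key):
        lkey, ikey = jrandom.split(key)
        # xavier_init / optional_default_init: user-supplied (leaf, *, key) -> array functions
        self.proj = init(nn.Linear(in_feat, out_feat, key=lkey), ikey, weight=xavier_init, bias=optional_default_init)  # optional_default_init accepts a None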
But how to initialize large models (which might rely on distributed RNG behavior) remains open.