
[Question] Modifying the Static Variable in a Model

Open ahahn2813 opened this issue 1 year ago • 5 comments

Hello,

My advisor and I are attempting something atypical with Equinox: we are trying to figure out how to change the static part of a neural network architecture on the fly.

Suppose we seek to learn the best activation function for a single layer of our neural network (all other architecture features are pre-chosen). The line of code below:

params, static = equinox.partition(model, equinox.is_array)

allows one to separate the model into a "static" variable and a "params" variable. The params variable contains all the weights, while the static variable contains the information about the activation function for the layer. Since we will need to change the activation function when we update the architecture, we would like to know whether it is possible to make modifications within the static variable. In other words, is there an easy way to convert the static variable to a params variable and then back to a static variable?

One approach we can think of is to convert part of the static variable into an array so that we can modify it, but we do not know how to convert it back into the static variable once it has been changed to an array. Thank you!
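For concreteness, a minimal sketch of the round trip we have in mind, using eqx.nn.MLP as a stand-in model:

import jax
import equinox as eqx

model = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2,
                   activation=jax.nn.relu, key=jax.random.PRNGKey(0))
params, static = eqx.partition(model, eqx.is_array)
# recombining is straightforward...
model = eqx.combine(params, static)
# ...but what is the idiomatic way to modify the non-array part
# (here, model.activation) before partitioning again?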

ahahn2813 avatar Aug 20 '24 19:08 ahahn2813

If you just want to update one member variable of the module, you can just use eqx.tree_at:

import jax
from jax import numpy as jnp
import equinox as eqx
from typing import Callable

class NN(eqx.Module):
    w: jax.Array
    b: jax.Array
    act_fn: Callable

    def __call__(self, x):
        return self.act_fn(self.w @ x + self.b)

net = NN(jnp.ones((10, 10)), jnp.ones(10), jax.nn.relu)
print(net(jnp.ones(10)))
print(net)
# tree_at is out-of-place: it returns a new module with act_fn replaced
net = eqx.tree_at(lambda x: x.act_fn, net, jax.nn.sigmoid)
print(net(jnp.ones(10)))
print(net)

where act_fn is the part that would be partitioned into static in your code above.
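To confirm how this interacts with partitioning, a quick sketch continuing from the code above:

params, static = eqx.partition(net, eqx.is_array)
# w and b (arrays) land in params; act_fn (a function) lands in static
net = eqx.combine(params, static)  # reassembles the module unchanged

So the pattern is: combine, update with tree_at, and partition again; there is no need to turn the static part into an array.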

lockwo avatar Aug 20 '24 19:08 lockwo

Thank you for your response; just so I understand properly: I can define fields inside the class corresponding to each quantity that I want to change dynamically, for instance the number of layers or the activation function, and then call eqx.tree_at in a loop to replace those values with what they should be. In pseudocode:

import jax
from jax import numpy as jnp
import equinox as eqx
from typing import Callable

class NN(eqx.Module):
    w: jax.Array
    b: jax.Array
    act_fn: Callable
    width: int

    def __init__(self, width, act, key):
        self.act_fn = act
        self.width = width
        # every field must be set in __init__, so build the weights here
        self.w = jax.random.normal(key, (width, width))
        self.b = jnp.zeros(width)

    def reinitialize(self, key):
        # modules are immutable after __init__, so return a copy with
        # freshly initialized weights instead of mutating self
        w = jax.random.normal(key, (self.width, self.width))
        b = jnp.zeros(self.width)
        return eqx.tree_at(lambda m: (m.w, m.b), self, (w, b))

    def __call__(self, x):
        return self.act_fn(self.w @ x + self.b)


net = NN(width=10, act=jax.nn.relu, key=jax.random.PRNGKey(0))
print(net(jnp.ones(10)))
print(net)

# training network architecture loop:
for width, act in [(5, jax.nn.sigmoid)]:
    net = eqx.tree_at(lambda x: x.width, net, width)
    net = eqx.tree_at(lambda x: x.act_fn, net, act)

    # training network weights loop
    ...

As long as I have the right field names within my NN class, I would be able to assign them on the fly with eqx.tree_at(). The pseudocode might be crude, but have I understood what you meant?

krm9c avatar Aug 20 '24 21:08 krm9c

Sure, that would work. Note, though, that width impacts/determines other variables (such as w), so those would also need updating; the tree_at calls themselves will run as written.
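For instance, a minimal sketch of rebuilding w and b after shrinking width (continuing from your pseudocode; the PRNG key is just a placeholder):

key = jax.random.PRNGKey(1)
net = eqx.tree_at(lambda m: m.width, net, 5)
# w and b must be re-created to match the new width
net = eqx.tree_at(
    lambda m: (m.w, m.b),
    net,
    (jax.random.normal(key, (5, 5)), jnp.zeros(5)),
)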

lockwo avatar Aug 20 '24 21:08 lockwo

This is exactly what we want to do: determine the architecture/hyperparameters on the fly.

We are trying to build a neural architecture/hyperparameter search setup with Equinox, and this will be helpful in that regard.

Equinox is a wonderful library. Thank you for maintaining it and for working on this.

krm9c avatar Aug 20 '24 21:08 krm9c

I just answer a few issues; all credit goes to Patrick.

lockwo avatar Aug 21 '24 06:08 lockwo