
NNXWrapper

Open PhilipVinc opened this issue 1 year ago • 2 comments

Hi,

I have a large library that we decided to build on top of flax.linen several years ago. I'd now like to begin testing nnx. However, given the size of the repo and its user base, I cannot migrate everything to nnx at once; instead I would like to keep using linen-style code for a while, while allowing users to use models defined with nnx inside our library.

In brief, the way we use modules right now is

model = LinenModel(...)
model_state, parameters = fcore.pop(model.init(jax.random.key(1), ...), "params")
...
# jit boundary
variables = {"params": parameters, **model_state}
model.apply(variables, inputs...)

I tried to use nnx.split to this end, but it returns a special object rather than a simple dictionary, which makes it impossible to fit into this approach.

By inspecting nnx.compat/bridge I see that there are several utilities for using linen layers within nnx, but it is unclear to me how to do the opposite. It seems that nnx.bridge.NNXWrapper should do that, but it is unfinished, and it is also not clear to me how to use nnx.Module.

Is there anything I can use?

PhilipVinc avatar Jul 18 '24 07:07 PhilipVinc

Hey @PhilipVinc, as you point out #4081 is the solution we are working on to use Linen Modules in NNX and vice versa. Should be done soon-ish. In the meantime maybe you can use something simple like:

from typing import Any

import flax.linen as linen
from flax import nnx

class LinenToNNX(nnx.Module):
  def __init__(
    self,
    module: linen.Module,
    rngs: nnx.Rngs,
  ):
    self.module = module
    self.rngs = rngs
    self.initialized = False

  def __call__(
    self, *args: Any, **kwargs: Any
  ) -> Any:
    _rngs = {name: stream() for name, stream in self.rngs.items()}
    if 'params' not in _rngs and 'default' in _rngs:
      _rngs['params'] = _rngs.pop('default')
    
    if not self.initialized:
      self.initialized = True

      out, variables = self.module.init_with_output(_rngs, *args, **kwargs)
      self.params = nnx.Param(variables['params'])
    else:
      variables = {'params': self.params.value}
      # mutable=True so apply returns the (possibly updated) variables as well
      out, variables = self.module.apply(variables, *args, rngs=_rngs, mutable=True, **kwargs)
      self.params.value = variables['params']

    return out

cgarciae avatar Jul 18 '24 10:07 cgarciae

Hi! I am working on the NNXToLinen wrapper that allows you to use NNX within Linen. I will likely send out the actual PR in a few days, but for now here is my draft and an example of its use. Note that the final API might be slightly different.

import functools
from typing import Callable

import jax
import flax.linen as nn
from flax import nnx

class NNXToLinen(nn.Module):
  module_op: Callable[..., nnx.Module]

  def setup(self):
    if self.is_initializing():
      self.module = self.module_op(rngs=nnx.Rngs(self.make_rng()))
      self.gdef, state = nnx.split(self.module)
      self.put_variable('params', 'nnx_state', state)
      return
    self.nnx_state = self.variable('params', 'nnx_state').value

  def __call__(self, *args, **kwargs):
    if self.is_initializing():
      return self.module(*args, **kwargs)
    module = nnx.eval_shape(self.module_op, rngs=nnx.Rngs(0))  # dummy rng
    nnx.update(module, self.nnx_state)
    return module(*args, **kwargs)

class NNXInner(nnx.Module):
  def __init__(self, din, dout, rngs):
    self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))
    self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs)
  
  def __call__(self, x):
    return x @ self.w.value

class LinenOuter(nn.Module):
  dout: int
  @nn.compact
  def __call__(self, x):
    linear = NNXToLinen(functools.partial(NNXInner, x.shape[-1], self.dout))
    b = self.param('b', nn.initializers.lecun_normal(), (1, self.dout))
    return linear(x) + b

x = jax.random.normal(jax.random.key(0), (2, 4))
model = LinenOuter(3)
var = model.init(jax.random.key(0), x)
print(f'{var = }')
y = model.apply(var, x)
assert y.shape == (2, 3)
print(y)

IvyZX avatar Jul 18 '24 23:07 IvyZX