ninjax

AttributeError: type object 'Module' has no attribute '__annotations__'

Open · LYK-love opened this issue 1 year ago · 1 comment

Hi, I tried to run the script you provided with ninjax==2.3.3. The script is:

import flax
import jax
import jax.numpy as jnp
import ninjax as nj
import optax

# Wrap Flax's `Dense` layer so it can be used as a ninjax submodule: below it
# is constructed with an explicit `name` (`Linear(128, name='h1')`) and also
# created lazily via `self.get('h3', Linear, ...)`. Per the original example,
# other Flax and Haiku modules can be wrapped the same way.
Linear = nj.FromFlax(flax.linen.Dense)


class MyModel(nj.Module):
  """Small MLP regressor that trains itself with Adam.

  The two hidden layers are created in the constructor; the output layer
  (`h3`) and the output bias are created lazily on first use via `self.get`.
  Parameters and optimizer state are kept in the ninjax state dict under this
  module's path (e.g. `model/bias` for a module named 'model').
  """

  # Learning rate. Declared as a module field so callers can override it as a
  # constructor keyword, e.g. `MyModel(3, lr=0.01, name='model')`.
  lr: float = 0.01

  def __init__(self, size):
    """Args:
      size: number of output features produced by `predict`.
    """
    self.size = size
    self.opt = optax.adam(self.lr)
    # Define submodules upfront.
    self.h1 = Linear(128, name='h1')
    self.h2 = Linear(128, name='h2')

  def predict(self, x):
    """Two ReLU hidden layers, a bias-free output layer, then a learned bias."""
    x = jax.nn.relu(self.h1(x))
    x = jax.nn.relu(self.h2(x))
    # Define a submodule inline; it is created on first use.
    x = self.get('h3', Linear, self.size, use_bias=False)(x)
    # Create a state entry inline (initialized with zeros of length `size`).
    x += self.get('bias', jnp.zeros, self.size)
    return x

  def train(self, x, y):
    """Take one Adam step on the squared-error loss and return the loss.

    Both the updated parameters and the updated optimizer state are written
    back to the ninjax state. As a result, successive calls continue the same
    Adam trajectory instead of restarting from the initial optimizer state.
    """
    # Gradient with respect to submodules (h1, h2) and state entries given by
    # their full path (h3, bias).
    keys = [self.h1, self.h2, f'{self.path}/h3', f'{self.path}/bias']
    loss, params, grads = nj.grad(self.loss, keys)(x, y)
    # `self.get` creates the entry with `self.opt.init(params)` only if it does
    # not exist yet; otherwise it returns the stored optimizer state.
    optstate = self.get('optstate', self.opt.init, params)
    updates, optstate = self.opt.update(grads, optstate)
    new_params = optax.apply_updates(params, updates)
    self.put(new_params)  # Store the new params.
    # Store the new optimizer state too. Without this, Adam's moment estimates
    # and step count would revert to their initial values on every call.
    self.put('optstate', optstate)
    return loss

  def loss(self, x, y):
    """Mean squared error between `predict(x)` and the targets `y`."""
    return ((self.predict(x) - y) ** 2).mean()


# Build the model plus one synthetic batch: 64 examples, 32 input features,
# 3 regression targets.
model = MyModel(3, lr=0.01, name='model')
x = jnp.ones((64, 32), jnp.float32)
y = jnp.ones((64, 3), jnp.float32)

# Run `model.train` once through `nj.init`, starting from an empty state dict,
# to populate every parameter and state entry it touches.
state = nj.init(model.train)({}, x, y, seed=0)
print(state['model/bias'])  # expected: [0., 0., 0.]

# Purify the stateful method into `(state, *args) -> (state, output)` and
# compile it.
train = jax.jit(nj.pure(model.train))

# Ten optimization steps, all on the same batch.
for _ in range(10):
  state, loss = train(state, x, y)
  print('Loss:', float(loss))

# Inspect the learned bias. The original example reports values such as
# [-1.2e-09  1.8e-08 -2.5e-09]; exact numbers may differ.
print(state['model/bias'])

However, I got the following error:

File ~/Projects/dreamerv3/dreamerv3/ninjax.py:418, in ModuleMeta.__new__(mcs, name, bases, clsdict)
    415     method_names.append(key)
    416 cls = super(ModuleMeta, mcs).__new__(mcs, name, bases, clsdict)
    417 cls.__field_defaults = {
--> 418     k: getattr(cls, k) for k, v in cls.__annotations__.items()
    419     if hasattr(cls, k)}
    420 for key, value in cls.__annotations__.items():
    421   setattr(cls, key, property(lambda self, key=key: self.__fields[key]))

AttributeError: type object 'Module' has no attribute '__annotations__'

LYK-love avatar Apr 06 '24 01:04 LYK-love

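Needs a newer Python version. On Python 3.9 and earlier, a class that declares no annotations (such as `Module` itself) has no `__annotations__` attribute, which is exactly what `cls.__annotations__` trips over here. Python 3.10 and later return an empty dict instead.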

danijar avatar Apr 06 '24 02:04 danijar