ninjax
ninjax copied to clipboard
AttributeError: type object 'Module' has no attribute '__annotations__'
Hi, I tried to run the example script you provided with ninjax==2.3.3. The script is:
import flax
import jax
import jax.numpy as jnp
import ninjax as nj
import optax
# Supports all Flax and Haiku modules
Linear = nj.FromFlax(flax.linen.Dense)


class MyModel(nj.Module):
    """Small MLP regressor that trains itself with Adam via ninjax state.

    `lr` is declared as an annotated class attribute with a default of 0.01;
    the example passes it as a keyword (`MyModel(3, lr=0.01, name='model')`),
    presumably picked up by ninjax's field machinery -- confirm against the
    ninjax version in use.
    """

    lr: float = 0.01

    def __init__(self, size):
        """Create the optimizer and the two hidden layers.

        Args:
            size: Output dimension of the final layer and the bias vector.
        """
        self.size = size
        self.opt = optax.adam(self.lr)
        # Define submodules upfront
        self.h1 = Linear(128, name='h1')
        self.h2 = Linear(128, name='h2')

    def predict(self, x):
        """Forward pass: two ReLU Dense(128) layers, a bias-free Dense(size)
        layer, plus a learned bias vector initialized with zeros."""
        x = jax.nn.relu(self.h1(x))
        x = jax.nn.relu(self.h2(x))
        # Define submodules inline (created through `self.get` if missing)
        x = self.get('h3', Linear, self.size, use_bias=False)(x)
        # Create state entries inline
        x += self.get('bias', jnp.zeros, self.size)
        return x

    def train(self, x, y):
        """Run one Adam step on the squared-error loss and return the loss."""
        # Gradient with respect to submodules or state entries
        keys = [self.h1, self.h2, f'{self.path}/h3', f'{self.path}/bias']
        loss, params, grads = nj.grad(self.loss, keys)(x, y)
        # Update weights
        optstate = self.get('optstate', self.opt.init, params)
        updates, optstate = self.opt.update(grads, optstate)
        new_params = optax.apply_updates(params, updates)
        self.put(new_params)  # Store the new params
        # Persist the advanced optimizer state. Without this, `self.get`
        # keeps returning the initial state, so Adam's step count and moment
        # estimates never advance between calls.
        self.put({f'{self.path}/optstate': optstate})
        return loss

    def loss(self, x, y):
        """Mean squared error between `predict(x)` and `y`."""
        return ((self.predict(x) - y) ** 2).mean()
# Build the model and a batch of constant example data.
model = MyModel(3, lr=0.01, name='model')
x = jnp.ones((64, 32), jnp.float32)
y = jnp.ones((64, 3), jnp.float32)

# Populate the initial state by running `model.train` once under nj.init,
# starting from an empty state dict.
state = nj.init(model.train)({}, x, y, seed=0)
print(state['model/bias'])  # expected per the example: [0., 0., 0.]

# Turn the stateful method into a function of (state, *args) that returns
# (new_state, output), then compile it.
train = jax.jit(nj.pure(model.train))

# Ten steps on the same batch.
for _ in range(10):
    state, loss = train(state, x, y)
    print('Loss:', float(loss))

# Inspect the learned bias (example output; exact values vary):
# [-1.2e-09 1.8e-08 -2.5e-09]
print(state['model/bias'])
However, I got the following error:
File ~/Projects/dreamerv3/dreamerv3/ninjax.py:418, in ModuleMeta.__new__(mcs, name, bases, clsdict)
    415 method_names.append(key)
    416 cls = super(ModuleMeta, mcs).__new__(mcs, name, bases, clsdict)
    417 cls.__field_defaults = {
--> 418 k: getattr(cls, k) for k, v in cls.__annotations__.items()
    419 if hasattr(cls, k)}
    420 for key, value in cls.__annotations__.items():
    421 setattr(cls, key, property(lambda self, key=key: self.__fields[key]))
AttributeError: type object 'Module' has no attribute '__annotations__'
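This needs a newer Python version (3.10 or later). Before Python 3.10, a class that declares no annotations of its own does not get an `__annotations__` attribute. As a result, the `cls.__annotations__` lookup in `ModuleMeta.__new__` raises `AttributeError` when the base `Module` class, which has no annotated fields, is created. Since Python 3.10, an empty `__annotations__` dict is created on demand instead. On older Pythons, reading `cls.__dict__.get('__annotations__', {})` in the metaclass avoids the error.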