trax
trax copied to clipboard
Why does the error "NameError: name 'base' is not defined" occur?
When I run the official code, I get the error "NameError: name 'base' is not defined".
The code below is the same as the official code.
import numpy as np
import trax
from trax.fastmath import numpy as fastnp
from trax.layers import base
from trax.layers import initializers as init

trax.fastmath.use_backend('jax')
# Basic array math with trax's backend-agnostic NumPy namespace (fastnp).
matrix = fastnp.array([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
print(f'matrix = \n{matrix}')

vector = fastnp.ones(3)
print(f'vector = {vector}')

# Vector-matrix product: shape (3,) with (3, 3) -> (3,).
product = fastnp.dot(vector, matrix)
print(f'product = {product}')

tanh = fastnp.tanh(product)  # applied element-wise
print(f'tanh(product) = {tanh}')
def f(x):
    """Return 2 * x**2; its derivative is 4 * x."""
    return 2.0 * x * x


# trax.fastmath.grad(f) builds a new function that evaluates df/dx.
grad_f = trax.fastmath.grad(f)
print(f'grad(2x^2) at 1 = {grad_f(1.0)}')
# Requires `from trax.layers import base` and
# `from trax.layers import initializers as init` at module level; without them
# this class definition raises "NameError: name 'base' is not defined".
class Embedding(base.Layer):
    """Trainable layer that maps discrete token IDs to vectors."""

    def __init__(self, vocab_size, d_feature,
                 kernel_initializer=init.RandomNormalInitializer(1.0)):
        """Returns an embedding layer with given vocabulary size and vector size.

        Args:
          vocab_size: Size of the input vocabulary. The layer will assign a
              unique vector to each ID in ``range(vocab_size)``.
          d_feature: Dimensionality/depth of the output vectors.
          kernel_initializer: Function that creates (random) initial vectors
              for the embedding; it is called as
              ``kernel_initializer(shape, rng)``.
        """
        # This must be the constructor `__init__`. A method named `init` would
        # shadow `base.Layer.init(input_signature)`, which callers use to
        # initialize the layer's weights.
        super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
        self._d_feature = d_feature  # feature dimensionality
        self._vocab_size = vocab_size
        self._kernel_initializer = kernel_initializer

    def forward(self, x):
        """Returns embedding vectors corresponding to input token IDs.

        Args:
          x: Tensor of token IDs.

        Returns:
          Tensor of embedding vectors: one row of ``self.weights`` per ID,
          gathered along axis 0.
        """
        # Use fastnp (jax.numpy under the jax backend) rather than plain NumPy,
        # so the lookup can be traced for jit and differentiated by JAX.
        return fastnp.take(self.weights, x, axis=0)

    def init_weights_and_state(self, input_signature):
        """Initializes ``self.weights`` with fresh embedding vectors.

        The weight matrix has shape ``(vocab_size, d_feature)`` and is created
        by the configured kernel initializer using the layer's ``rng``.
        Nothing is returned.

        Args:
          input_signature: Unused; the weight shape depends only on the
              constructor arguments.
        """
        del input_signature
        shape_w = (self._vocab_size, self._d_feature)
        w = self._kernel_initializer(shape_w, self.rng)
        self.weights = w
from trax import layers as tl

# Input tensor x: the token IDs 0..14.
x = np.arange(15)
print(f'x={x}')

# Build trax's built-in embedding layer (tl.Embedding, not a locally defined
# class), then initialize its weights from the input's signature.
embedding = tl.Embedding(vocab_size=20, d_feature=32)
embedding.init(trax.shapes.signature(x))

# Run the layer -- y = embedding(x).
y = embedding(x)
print(f'shape of y = {y.shape}')