flax icon indicating copy to clipboard operation
flax copied to clipboard

Modifying params of a flax.linen.Module model

Open lopa23 opened this issue 1 year ago • 1 comment

I am trying to add both a `flax.linen.Dense` layer and a stax `Dense` layer to one model (I need this for something else). However, the stax parameters do not show up in the model parameters. I have tried in vain to modify the params, but it keeps raising `ValueError('FrozenDict is immutable.')`. Any help is appreciated. The code is below:

from jax.example_libraries import stax import flax.linen as nn import jax import jax.numpy as jnp import optax from optax import adam from flax.core.frozen_dict import unfreeze class StaxLayerModule(nn.Module): def setup(self): # Initialize the stax layer and get initial parameters init_fn,self.apply_fn=stax.Dense(5) _, self.stax_params = init_fn(jax.random.PRNGKey(0), (5,))

    # Unpack the stax Dense parameters produced by init_fn earlier in setup:
    # index 0 is used as the kernel (W) and index 1 as the bias (b) below.
    W=self.stax_params[0]
    b=self.stax_params[1]
    # Flax Dense layer with 5 output features and a normal(0.5) kernel init.
    self.Dense1 = nn.Dense(5,kernel_init=nn.initializers.normal(.5))
    # Initialize Dense1 on its own with a fixed PRNGKey(0) and a dummy (10,)
    # input.
    # NOTE(review): nothing in setup calls self.param(...), so this variables
    # dict (and the stax arrays) are plain attributes rather than registered
    # parameters of StaxLayerModule. That is presumably why they do not show up
    # in model.init(...)'s output, as the issue reports. Registering them with
    # self.param(...) would make them trainable params. Confirm.
    self.nn_params = self.Dense1.init(jax.random.PRNGKey(0), jnp.ones((10,)))
    # Convert to a mutable dict so the stax arrays can be added to it.
    # NOTE(review): Flax appears to re-freeze dicts assigned to module
    # attributes inside setup(). If so, the item assignment in the loop below
    # raises "FrozenDict is immutable." (the error quoted in the issue).
    # Building the dict in a local variable and assigning it once would avoid
    # that. Confirm.
    self.nn_params = (unfreeze(dict(self.nn_params)))
    # Map the stax arrays onto Flax-style names.
    stx_param_dict=dict()
    stx_param_dict['kernel']=W
    stx_param_dict['bias']=b
    
    # Copy them in under the names 'kernel_stax' and 'bias_stax'.
    # NOTE(review): Dense1.init(...) presumably returns {'params': {...}}, so
    # these keys would land beside 'params' rather than inside it. Verify.
    for key, value in stx_param_dict.items():
        self.nn_params[key + '_stax'] = value

    # Convert back to FrozenDict before using it as params
    # NOTE(review): __call__ never reads self.nn_params. The stax layer is
    # applied with self.stax_params, so these copies are not what gets used.
    self.nn_params = nn.freeze(self.nn_params)
   
def __call__(self, inputs):
    """Apply the Flax Dense layer, then the stax Dense layer, to ``inputs``.

    The stax layer runs with ``self.stax_params``, the arrays produced by the
    stax ``init_fn`` in ``setup``. It does not use anything taken from the
    Flax ``params`` collection.
    """
    hidden = self.Dense1(inputs)
    return self.apply_fn(self.stax_params, hidden)

Example usage

# Build the model, a constant input batch, all-zero targets and an Adam optimizer.
model = StaxLayerModule()
input_data = jnp.full((3, 10), 3)  # Example input data

print("Input", input_data)
ground_truth = jnp.zeros((3, 5))
optimizer = adam(learning_rate=.01)
rng = jax.random.PRNGKey(0)
params = model.init(rng, input_data)

# Initialize the optimizer state from the params returned by model.init.
# (The original referenced the undefined name `modified_params` here.)
opt_state = optimizer.init(params)


@jax.jit
def mse(params, x, y):
    """Return the mean squared error of ``model.apply(params, x)`` against ``y``."""
    pred = model.apply(params, x)
    # Use the ``y`` argument rather than the global ``ground_truth`` so the
    # loss depends only on what the caller passes in.
    return jnp.mean((y - pred) ** 2)


loss_grad_fn = jax.value_and_grad(mse)


@jax.jit
def update_params(grads, params, opt_state):
    """Apply one optimizer update; return the new ``(params, opt_state)``."""
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state


@jax.jit
def train_step(params, opt_state, model_input, ground_truth):
    """Compute gradients of the MSE loss and take one optimizer step."""
    _, grads = loss_grad_fn(params, model_input, ground_truth)
    params, opt_state = update_params(grads, params, opt_state)
    return params, opt_state


for step in range(11):
    params, opt_state = train_step(params, opt_state, input_data, ground_truth)

    model_output = model.apply(params, input_data)

    loss_val = jnp.mean((ground_truth - model_output) ** 2)
    print(f'Loss step {step}: ', loss_val)
    # Early stopping is only checked on the steps where the output is printed.
    if step % 10 == 0:
        print(model_output)
        if loss_val < 0.0005:  # .0005 for 256
            break

#output_data = model.apply(model.init(jax.random.PRNGKey(0), input_data), input_data) #print(output_data)

lopa23 avatar Feb 29 '24 23:02 lopa23

The code is a bit hard to read in its current format, but judging from the error, it seems solvable by passing `mutable=True` to your `apply()` calls. See the documentation of `Module.apply` for details.

IvyZX avatar Mar 05 '24 01:03 IvyZX