
Problems with MBOP in offline examples

Open • ZhengyaoJiang opened this issue 1 year ago • 0 comments

Greetings, I ran into two problems running the MBOP baseline in https://github.com/deepmind/acme/blob/master/examples/offline/run_mbop_jax.py and am looking for help.

The first one comes from: https://github.com/deepmind/acme/blob/c7aac29c40183a191d9c39e66fd80deea9299977/examples/offline/run_mbop_jax.py#L25

There isn't a module called helpers under acme.examples.offline. I worked around this by using the helpers module from https://github.com/deepmind/acme/blob/master/examples/baselines/rl_continuous/helpers.py.
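
One quick way to confirm, in a given environment, whether the `helpers` module the example expects is importable is to ask `importlib` for its spec before running anything. This is a minimal diagnostic sketch using only the standard library; `module_available` is a hypothetical helper name, not part of Acme.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Returns True if `name` can be imported in the current environment.

    Hypothetical diagnostic helper (not part of Acme). For a dotted name,
    importlib.util.find_spec imports the parent packages first, so a missing
    parent raises ModuleNotFoundError (an ImportError) instead of returning None.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:
        return False


if __name__ == "__main__":
    # The module path referenced by the offline MBOP example.
    print("acme.examples.offline.helpers:",
          module_available("acme.examples.offline.helpers"))
```

The check only looks the module up; it does not run the example, so it can be used to verify a workaround (such as a locally copied helpers.py) before launching training.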

With that workaround in place, the code runs on most of the gym-locomotion tasks, but fails on hopper (with every dataset) with the following traceback:

Traceback (most recent call last):
  File "/home/zhengyao/project/mbop/main.py", line 174, in <module>
    app.run(main)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/zhengyao/project/mbop/main.py", line 169, in main
    learner.step()
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/learning.py", line 199, in step
    self._world_model.step()
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/bc/learning.py", line 175, in step
    self._state, metrics = self._sgd_step(self._state, transitions)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/api.py", line 2158, in cache_miss
    out_tree, out_flat = f_pmapped_(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/api.py", line 2034, in pmap_f
    out = pxla.xla_pmap(
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/core.py", line 2022, in bind
    return map_bind(self, fun, *args, **params)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/core.py", line 2054, in map_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/core.py", line 2025, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/core.py", line 687, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 829, in xla_pmap_impl
    compiled_fun, fingerprint = parallel_callable(
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/linear_util.py", line 295, in memoized_fun
    ans = call(fun, *args)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 857, in parallel_callable
    pmap_computation = lower_parallel_callable(
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 1030, in lower_parallel_callable
    jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 937, in stage_parallel_callable
    jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2154, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2089, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/bc/learning.py", line 118, in sgd_step
    loss_result, gradients = loss_and_grad(network.apply, state.policy_params,
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/api.py", line 1070, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/api.py", line 2578, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/ad.py", line 134, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/ad.py", line 123, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 800, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/learning.py", line 120, in loss_fn
    return loss(functools.partial(apply_fn, params), transitions)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/losses.py", line 63, in world_model_loss
    predicted_reward_t) = apply_fn(observation_t, action_t)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/ensemble.py", line 139, in apply_all
    return jax.vmap(base_apply)(params, *args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/api.py", line 1564, in vmap_f
    out_flat = batching.batch(
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/multi_transform.py", line 298, in apply_fn
    return f.apply(params, None, *args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/transform.py", line 127, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/transform.py", line 354, in apply_fn
    out = f(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/networks.py", line 77, in _world_model_fn
    network(jnp.concatenate([observation_t, action_t], axis=-1)),
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 421, in wrapped
    out = f(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 271, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/nets/mlp.py", line 113, in __call__
    out = layer(out)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 421, in wrapped
    out = f(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 271, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/basic.py", line 176, in __call__
    w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/base.py", line 522, in get_parameter
    raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'mlp/~/linear_0/w' with retrieved shape (23, 64) does not match shape=[14, 64] dtype=dtype('float32')

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/zhengyao/project/mbop/main.py", line 174, in <module>
    app.run(main)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/zhengyao/project/mbop/main.py", line 169, in main
    learner.step()
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/learning.py", line 199, in step
    self._world_model.step()
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/bc/learning.py", line 175, in step
    self._state, metrics = self._sgd_step(self._state, transitions)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/bc/learning.py", line 118, in sgd_step
    loss_result, gradients = loss_and_grad(network.apply, state.policy_params,
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/learning.py", line 120, in loss_fn
    return loss(functools.partial(apply_fn, params), transitions)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/losses.py", line 63, in world_model_loss
    predicted_reward_t) = apply_fn(observation_t, action_t)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/ensemble.py", line 139, in apply_all
    return jax.vmap(base_apply)(params, *args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/multi_transform.py", line 298, in apply_fn
    return f.apply(params, None, *args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/transform.py", line 127, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/transform.py", line 354, in apply_fn
    out = f(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/networks.py", line 77, in _world_model_fn
    network(jnp.concatenate([observation_t, action_t], axis=-1)),
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 421, in wrapped
    out = f(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 271, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/nets/mlp.py", line 113, in __call__
    out = layer(out)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 421, in wrapped
    out = f(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 271, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/basic.py", line 176, in __call__
    w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
  File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/base.py", line 522, in get_parameter
    raise ValueError(
ValueError: 'mlp/~/linear_0/w' with retrieved shape (23, 64) does not match shape=[14, 64] dtype=dtype('float32')

Process finished with exit code 1
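
The two sizes in the error can be related to the world-model network: per the networks.py frame above, `_world_model_fn` concatenates `observation_t` and `action_t` before the MLP, so the first Linear weight has obs_dim + act_dim rows. In Haiku's message, the "retrieved shape" is the stored parameter and `shape=` is what the current input requires. The sketch below does that arithmetic, assuming the standard Gym MuJoCo sizes (hopper 11/3, halfcheetah 17/6, walker2d 17/6); the table and function names are illustrative only. It suggests, without confirming, that the world-model parameters were initialized for a 23-feature (halfcheetah- or walker2d-sized) input while the hopper batches carry 14 features.

```python
# Observation and action sizes of the Gym MuJoCo locomotion tasks.
# Assumption: the standard environments that the D4RL datasets are built on.
TASK_DIMS = {
    "hopper": (11, 3),
    "halfcheetah": (17, 6),
    "walker2d": (17, 6),
}


def world_model_input_size(obs_dim: int, act_dim: int) -> int:
    # The world model concatenates observation_t and action_t on the last axis,
    # so its first Linear layer sees obs_dim + act_dim input features.
    return obs_dim + act_dim


def tasks_with_input_size(in_features: int) -> list:
    """Returns the tasks whose concatenated input has `in_features` features."""
    return sorted(
        task
        for task, (obs_dim, act_dim) in TASK_DIMS.items()
        if world_model_input_size(obs_dim, act_dim) == in_features
    )


if __name__ == "__main__":
    retrieved_rows, expected_rows = 23, 14  # taken from the ValueError above
    print("stored params fit:", tasks_with_input_size(retrieved_rows))
    print("current inputs fit:", tasks_with_input_size(expected_rows))
```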

ZhengyaoJiang, Aug 10 '22 14:08