acme
acme copied to clipboard
Problems with MBOP in offline examples
Greetings, I ran into two problems while running the MBOP baseline in https://github.com/deepmind/acme/blob/master/examples/offline/run_mbop_jax.py, and I'm looking for help.
The first one comes from: https://github.com/deepmind/acme/blob/c7aac29c40183a191d9c39e66fd80deea9299977/examples/offline/run_mbop_jax.py#L25
There is no module called `helpers` under `acme.examples.offline`.
I got around it by using helpers from https://github.com/deepmind/acme/blob/master/examples/baselines/rl_continuous/helpers.py.
The second problem: after that change, the code runs for most of the gym-locomotion tasks, but it fails on hopper (for all of the datasets) with the following traceback:
Traceback (most recent call last):
File "/home/zhengyao/project/mbop/main.py", line 174, in <module>
app.run(main)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/zhengyao/project/mbop/main.py", line 169, in main
learner.step()
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/learning.py", line 199, in step
self._world_model.step()
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/bc/learning.py", line 175, in step
self._state, metrics = self._sgd_step(self._state, transitions)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/api.py", line 2158, in cache_miss
out_tree, out_flat = f_pmapped_(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/api.py", line 2034, in pmap_f
out = pxla.xla_pmap(
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/core.py", line 2022, in bind
return map_bind(self, fun, *args, **params)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/core.py", line 2054, in map_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/core.py", line 2025, in process
return trace.process_map(self, fun, tracers, params)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/core.py", line 687, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 829, in xla_pmap_impl
compiled_fun, fingerprint = parallel_callable(
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/linear_util.py", line 295, in memoized_fun
ans = call(fun, *args)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 857, in parallel_callable
pmap_computation = lower_parallel_callable(
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 1030, in lower_parallel_callable
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 937, in stage_parallel_callable
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2154, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2089, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/bc/learning.py", line 118, in sgd_step
loss_result, gradients = loss_and_grad(network.apply, state.policy_params,
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/api.py", line 1070, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/api.py", line 2578, in _vjp
out_primal, out_vjp = ad.vjp(
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/ad.py", line 134, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/ad.py", line 123, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 800, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/learning.py", line 120, in loss_fn
return loss(functools.partial(apply_fn, params), transitions)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/losses.py", line 63, in world_model_loss
predicted_reward_t) = apply_fn(observation_t, action_t)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/ensemble.py", line 139, in apply_all
return jax.vmap(base_apply)(params, *args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/_src/api.py", line 1564, in vmap_f
out_flat = batching.batch(
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/multi_transform.py", line 298, in apply_fn
return f.apply(params, None, *args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/transform.py", line 127, in apply_fn
out, state = f.apply(params, {}, *args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/transform.py", line 354, in apply_fn
out = f(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/networks.py", line 77, in _world_model_fn
network(jnp.concatenate([observation_t, action_t], axis=-1)),
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 421, in wrapped
out = f(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 271, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/nets/mlp.py", line 113, in __call__
out = layer(out)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 421, in wrapped
out = f(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 271, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/basic.py", line 176, in __call__
w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/base.py", line 522, in get_parameter
raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'mlp/~/linear_0/w' with retrieved shape (23, 64) does not match shape=[14, 64] dtype=dtype('float32')
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/zhengyao/project/mbop/main.py", line 174, in <module>
app.run(main)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/zhengyao/project/mbop/main.py", line 169, in main
learner.step()
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/learning.py", line 199, in step
self._world_model.step()
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/bc/learning.py", line 175, in step
self._state, metrics = self._sgd_step(self._state, transitions)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/bc/learning.py", line 118, in sgd_step
loss_result, gradients = loss_and_grad(network.apply, state.policy_params,
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/learning.py", line 120, in loss_fn
return loss(functools.partial(apply_fn, params), transitions)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/losses.py", line 63, in world_model_loss
predicted_reward_t) = apply_fn(observation_t, action_t)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/ensemble.py", line 139, in apply_all
return jax.vmap(base_apply)(params, *args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/multi_transform.py", line 298, in apply_fn
return f.apply(params, None, *args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/transform.py", line 127, in apply_fn
out, state = f.apply(params, {}, *args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/transform.py", line 354, in apply_fn
out = f(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/acme/agents/jax/mbop/networks.py", line 77, in _world_model_fn
network(jnp.concatenate([observation_t, action_t], axis=-1)),
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 421, in wrapped
out = f(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 271, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/nets/mlp.py", line 113, in __call__
out = layer(out)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 421, in wrapped
out = f(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/module.py", line 271, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/basic.py", line 176, in __call__
w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
File "/home/zhengyao/miniconda3/envs/mbop/lib/python3.9/site-packages/haiku/_src/base.py", line 522, in get_parameter
raise ValueError(
ValueError: 'mlp/~/linear_0/w' with retrieved shape (23, 64) does not match shape=[14, 64] dtype=dtype('float32')
Process finished with exit code 1