dreamerv3
dreamerv3 copied to clipboard
Obtain World Model Predictions during Inference.
Hi,
How can I obtain the World Model's predictions during inference?
I tried the following call inside a simple inference loop, but it raises an error: `agent.agent.wm.imagine(agent.policy, obs, 10)`
Error and stack trace:
│ │
│ 57 │ act = {'action': act['action'][0], 'reset': obs['is_last'][0]} │
│ 58 │ │
│ 59 │ if i > 100: │
│ ❱ 60 │ agent.agent.wm.imagine(agent.policy, obs, 10) │
│ 61 │
│ 62 │
│ 63 │
│ │
│ /home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/ninjax.py:380 in wrapper │
│ │
│ 377 def wrapper(self, *args, **kwargs): │
│ 378 │ with scope(self._path, absolute=True): │
│ 379 │ with jax.named_scope(self._path.split('/')[-1]): │
│ ❱ 380 │ │ return method(self, *args, **kwargs) │
│ 381 return wrapper │
│ 382 │
│ 383 │
│ │
│ /home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/agent.py:183 in imagine │
│ │
│ 180 │
│ 181 def imagine(self, policy, start, horizon): │
│ 182 │ first_cont = (1.0 - start['is_terminal']).astype(jnp.float32) │
│ ❱ 183 │ keys = list(self.rssm.initial(1).keys()) │
│ 184 │ start = {k: v for k, v in start.items() if k in keys} │
│ 185 │ start['action'] = policy(start) │
│ 186 │ def step(prev, _): │
│ │
│ /home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/ninjax.py:380 in wrapper │
│ │
│ 377 def wrapper(self, *args, **kwargs): │
│ 378 │ with scope(self._path, absolute=True): │
│ 379 │ with jax.named_scope(self._path.split('/')[-1]): │
│ ❱ 380 │ │ return method(self, *args, **kwargs) │
│ 381 return wrapper │
│ 382 │
│ 383 │
│ │
│ /home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/nets.py:34 in initial │
│ │
│ 31 def initial(self, bs): │
│ 32 │ if self._classes: │
│ 33 │ state = dict( │
│ ❱ 34 │ │ deter=jnp.zeros([bs, self._deter], f32), │
│ 35 │ │ logit=jnp.zeros([bs, self._stoch, self._classes], f32), │
│ 36 │ │ stoch=jnp.zeros([bs, self._stoch, self._classes], f32)) │
│ 37 │ else: │
│ │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/num │
│ py/lax_numpy.py:2317 in zeros │
│ │
│ 2314 if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise │
│ 2315 dtypes.check_user_dtype_supported(dtype, "zeros") │
│ 2316 shape = canonicalize_shape(shape) │
│ ❱ 2317 return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_t │
│ 2318 │
│ 2319 @util.implements(np.ones) │
│ 2320 def ones(shape: Any, dtype: DTypeLike | None = None, *, │
│ │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/lax │
│ /lax.py:1226 in full │
│ │
│ 1223 │ return dtype._rules.full(shape, fill_value, dtype) # type: igno │
│ 1224 weak_type = dtype is None and dtypes.is_weakly_typed(fill_value) │
│ 1225 dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value)) │
│ ❱ 1226 fill_value = _convert_element_type(fill_value, dtype, weak_type) │
│ 1227 out = broadcast(fill_value, shape) │
│ 1228 if sharding is not None: │
│ 1229 │ return array.make_array_from_callback(shape, sharding, lambda id │
│ │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/lax │
│ /lax.py:560 in _convert_element_type │
│ │
│ 557 │ │ isinstance(core.get_aval(operand), core.ConcreteArray))): │
│ 558 │ return type_cast(Array, operand) │
│ 559 else: │
│ ❱ 560 │ return convert_element_type_p.bind(operand, new_dtype=new_dtype, │
│ 561 │ │ │ │ │ │ │ │ │ weak_type=bool(weak_type)) │
│ 562 │
│ 563 def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) - │
│ │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/cor │
│ e.py:444 in bind │
│ │
│ 441 def bind(self, *args, **params): │
│ 442 │ assert (not config.enable_checks.value or │
│ 443 │ │ │ all(isinstance(arg, Tracer) or valid_jaxtype(arg) for ar │
│ ❱ 444 │ return self.bind_with_trace(find_top_trace(args), args, params) │
│ 445 │
│ 446 def bind_with_trace(self, trace, args, params): │
│ 447 │ out = trace.process_primitive(self, map(trace.full_raise, args), │
│ │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/cor │
│ e.py:447 in bind_with_trace │
│ │
│ 444 │ return self.bind_with_trace(find_top_trace(args), args, params) │
│ 445 │
│ 446 def bind_with_trace(self, trace, args, params): │
│ ❱ 447 │ out = trace.process_primitive(self, map(trace.full_raise, args), │
│ 448 │ return map(full_lower, out) if self.multiple_results else full_l │
│ 449 │
│ 450 def def_impl(self, impl): │
│ │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/cor │
│ e.py:935 in process_primitive │
│ │
│ 932 lift = sublift = pure │
│ 933 │
│ 934 def process_primitive(self, primitive, tracers, params): │
│ ❱ 935 │ return primitive.impl(*tracers, **params) │
│ 936 │
│ 937 def process_call(self, primitive, f, tracers, params): │
│ 938 │ return primitive.impl(f, *tracers, **params) │
│ │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/dis │
│ patch.py:87 in apply_primitive │
│ │
│ 84 if xla_extension_version >= 218: │
│ 85 │ prev = lib.jax_jit.swap_thread_local_state_disable_jit(False) │
│ 86 │ try: │
│ ❱ 87 │ outs = fun(*args) │
│ 88 │ finally: │
│ 89 │ lib.jax_jit.swap_thread_local_state_disable_jit(prev) │
│ 90 else: │
╰─────────────────────────────────────────────────────────────────────────────╯
XlaRuntimeError: INVALID_ARGUMENT: Disallowed host-to-device transfer:
aval=ShapedArray(float32[]), dst_sharding=GSPMDSharding({replicated})