dreamerv3

Obtain World Model Predictions during Inference.

Status: Open · defrag-bambino opened this issue 4 months ago · 0 comments

Hi,

How can I obtain the world model's predictions during inference? I tried the following call inside a simple inference loop, but it raises an error: `agent.agent.wm.imagine(agent.policy, obs, 10)`

Error and stack trace:

│                                                                             │
│   57 │   act = {'action': act['action'][0], 'reset': obs['is_last'][0]}     │
│   58 │                                                                      │
│   59 │   if i > 100:                                                        │
│ ❱ 60 │     agent.agent.wm.imagine(agent.policy, obs, 10)                    │
│   61                                                                        │
│   62                                                                        │
│   63                                                                        │
│                                                                             │
│ /home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/ninjax.py:380 in wrapper    │
│                                                                             │
│   377   def wrapper(self, *args, **kwargs):                                 │
│   378 │   with scope(self._path, absolute=True):                            │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                          │
│   381   return wrapper                                                      │
│   382                                                                       │
│   383                                                                       │
│                                                                             │
│ /home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/agent.py:183 in imagine     │
│                                                                             │
│   180                                                                       │
│   181   def imagine(self, policy, start, horizon):                          │
│   182 │   first_cont = (1.0 - start['is_terminal']).astype(jnp.float32)     │
│ ❱ 183 │   keys = list(self.rssm.initial(1).keys())                          │
│   184 │   start = {k: v for k, v in start.items() if k in keys}             │
│   185 │   start['action'] = policy(start)                                   │
│   186 │   def step(prev, _):                                                │
│                                                                             │
│ /home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/ninjax.py:380 in wrapper    │
│                                                                             │
│   377   def wrapper(self, *args, **kwargs):                                 │
│   378 │   with scope(self._path, absolute=True):                            │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                          │
│   381   return wrapper                                                      │
│   382                                                                       │
│   383                                                                       │
│                                                                             │
│ /home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/nets.py:34 in initial       │
│                                                                             │
│    31   def initial(self, bs):                                              │
│    32 │   if self._classes:                                                 │
│    33 │     state = dict(                                                   │
│ ❱  34 │   │     deter=jnp.zeros([bs, self._deter], f32),                    │
│    35 │   │     logit=jnp.zeros([bs, self._stoch, self._classes], f32),     │
│    36 │   │     stoch=jnp.zeros([bs, self._stoch, self._classes], f32))     │
│    37 │   else:                                                             │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/num │
│ py/lax_numpy.py:2317 in zeros                                               │
│                                                                             │
│   2314   if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise  │
│   2315   dtypes.check_user_dtype_supported(dtype, "zeros")                  │
│   2316   shape = canonicalize_shape(shape)                                  │
│ ❱ 2317   return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_t │
│   2318                                                                      │
│   2319 @util.implements(np.ones)                                            │
│   2320 def ones(shape: Any, dtype: DTypeLike | None = None, *,              │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/lax │
│ /lax.py:1226 in full                                                        │
│                                                                             │
│   1223 │   return dtype._rules.full(shape, fill_value, dtype)  # type: igno │
│   1224   weak_type = dtype is None and dtypes.is_weakly_typed(fill_value)   │
│   1225   dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))     │
│ ❱ 1226   fill_value = _convert_element_type(fill_value, dtype, weak_type)   │
│   1227   out = broadcast(fill_value, shape)                                 │
│   1228   if sharding is not None:                                           │
│   1229 │   return array.make_array_from_callback(shape, sharding, lambda id │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/lax │
│ /lax.py:560 in _convert_element_type                                        │
│                                                                             │
│    557 │   │      isinstance(core.get_aval(operand), core.ConcreteArray))): │
│    558 │   return type_cast(Array, operand)                                 │
│    559   else:                                                              │
│ ❱  560 │   return convert_element_type_p.bind(operand, new_dtype=new_dtype, │
│    561 │   │   │   │   │   │   │   │   │      weak_type=bool(weak_type))    │
│    562                                                                      │
│    563 def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) - │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/cor │
│ e.py:444 in bind                                                            │
│                                                                             │
│    441   def bind(self, *args, **params):                                   │
│    442 │   assert (not config.enable_checks.value or                        │
│    443 │   │   │   all(isinstance(arg, Tracer) or valid_jaxtype(arg) for ar │
│ ❱  444 │   return self.bind_with_trace(find_top_trace(args), args, params)  │
│    445                                                                      │
│    446   def bind_with_trace(self, trace, args, params):                    │
│    447 │   out = trace.process_primitive(self, map(trace.full_raise, args), │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/cor │
│ e.py:447 in bind_with_trace                                                 │
│                                                                             │
│    444 │   return self.bind_with_trace(find_top_trace(args), args, params)  │
│    445                                                                      │
│    446   def bind_with_trace(self, trace, args, params):                    │
│ ❱  447 │   out = trace.process_primitive(self, map(trace.full_raise, args), │
│    448 │   return map(full_lower, out) if self.multiple_results else full_l │
│    449                                                                      │
│    450   def def_impl(self, impl):                                          │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/cor │
│ e.py:935 in process_primitive                                               │
│                                                                             │
│    932   lift = sublift = pure                                              │
│    933                                                                      │
│    934   def process_primitive(self, primitive, tracers, params):           │
│ ❱  935 │   return primitive.impl(*tracers, **params)                        │
│    936                                                                      │
│    937   def process_call(self, primitive, f, tracers, params):             │
│    938 │   return primitive.impl(f, *tracers, **params)                     │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/dis │
│ patch.py:87 in apply_primitive                                              │
│                                                                             │
│    84   if xla_extension_version >= 218:                                    │
│    85 │   prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)     │
│    86 │   try:                                                              │
│ ❱  87 │     outs = fun(*args)                                               │
│    88 │   finally:                                                          │
│    89 │     lib.jax_jit.swap_thread_local_state_disable_jit(prev)           │
│    90   else:                                                               │
╰─────────────────────────────────────────────────────────────────────────────╯
XlaRuntimeError: INVALID_ARGUMENT: Disallowed host-to-device transfer: 
aval=ShapedArray(float32[]), dst_sharding=GSPMDSharding({replicated})

Posted by defrag-bambino on Feb 27 '24, 12:02.