Explore as behavior for dreamerv3
Hi,
-
expl.py was copied from the /dreamerv3 repo, but nets.Input has a new constructor now. In lines 16 and 17, the following part needs to be removed, as it causes an exception:
, dims='deter' -
In behaviors.py a reward scale was added (line 86). But the scaled lambda also ends up in self.rewards, so line 102:
mets = rewfn.train(data)
throws an exception, because the lambda has no .train method. To fix that, move self.rewards[key] = rewfn before the lambda wrapper, or use another variable name for the scaled lambda. -
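The failure mode can be reproduced outside the repo with a minimal sketch (RewFn and the scale value are stand-ins, not the repo's real classes):

```python
# Minimal sketch of the bug: wrapping a reward object in a scaling
# lambda and storing the lambda loses the object's .train method.
# RewFn and `scale` are illustrative stand-ins.

class RewFn:
    def train(self, data):
        return {'loss': 0.0}  # dummy training metrics
    def __call__(self, traj):
        return 1.0  # dummy reward

scale = 2.0
rewfn = RewFn()

# Buggy order: the scaled lambda is stored, so .train() later fails.
rewards_buggy = {}
rewards_buggy['expl'] = lambda traj, fn=rewfn: scale * fn(traj)
# rewards_buggy['expl'].train(data)  # AttributeError: no .train

# Fixed order: store the original object first; keep the scaled
# lambda under a different name, used only for reward computation.
rewards_fixed = {}
rewards_fixed['expl'] = rewfn                      # still has .train
scaled = lambda traj, fn=rewfn: scale * fn(traj)   # scaled rewards

mets = rewards_fixed['expl'].train(None)  # works
reward = scaled(None)                     # 2.0
```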
Very cool change to the act_space flow (multi-action support is great), but it introduces a "critical" issue for Explore. I have a simple env with act_space:
{'action': Space(dtype=int32, shape=(), low=0, high=15), 'reset': Space(dtype=bool, shape=(), low=False, high=True)}
The new act_space flow no longer includes the OneHotAction wrapper.
The MLP for expl.Disag gets its input dimension at construction time (expl.py, line 18):
[deter] + [stoch] + [action]
In my case it looks like:
Traced<ShapedArray(float16[32,31,1537])>with<DynamicJaxprTrace(level=1/0)>
But during loss calculation in train, the action is converted to a one-hot vector of length 15, and the MLP is called with:
Traced<ShapedArray(float16[16,1024,1551])>with<DynamicJaxprTrace(level=1/0)>
So it throws an exception: TypeError: dot_general requires contracting dimensions to have the same shape, got (1551,) and (1537,) (nets.py, line 563).
For now, as a workaround, I just removed [action] from disag_head.inputs, but I am not sure how much this changes the behavior.
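The mismatch is plain shape arithmetic: at construction the action contributes one scalar feature (1536 + 1 = 1537), while at train time the one-hot action contributes 15 (1536 + 15 = 1551). A minimal numpy sketch of the same mismatch (the 1024/512 split between deter and stoch is my guess; only the 1536 total is implied by the traces above):

```python
import numpy as np

# Illustrative sizes: only the total 1536 for deter + stoch matters.
deter, stoch = np.zeros(1024), np.zeros(512)

# At construction time the action is a single scalar feature.
action_scalar = np.zeros(1)
built_dim = np.concatenate([deter, stoch, action_scalar]).shape[0]

# At train time the action has been one-hot encoded to 15 features.
action_onehot = np.zeros(15)
train_dim = np.concatenate([deter, stoch, action_onehot]).shape[0]

# A weight matrix built for `built_dim` inputs cannot contract with
# a `train_dim` vector, which is exactly the dot_general TypeError.
```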
With Best Regards, Yaroslav
I found a fix for point 3.
In dreamerv3/agent.py, line 95:

```python
context = {**data, **outs}
```

needs to be replaced with the following block:

```python
prev_state, prev_action = carry
prev_acts = {
    k: jnp.concatenate([prev_action[k][:, None], data[k][:, :-1]], 1)
    for k in self.act_space}
prev_acts = jaxutils.onehot_dict(prev_acts, self.act_space)
context = {**data, **outs, **prev_acts}
```
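The idea of that block, sketched with numpy (jaxutils.onehot_dict is repo-specific, so I substitute a plain one-hot; batch size, sequence length, and the number of actions are illustrative):

```python
import numpy as np

num_actions = 15   # illustrative discrete action count
B, T = 2, 4        # illustrative batch and sequence length

# Action carried over from the previous step, shape (B,).
prev_action = np.array([3, 7])
# Actions taken within the current batch, shape (B, T).
data_actions = np.array([[1, 5, 9, 2],
                         [0, 14, 3, 8]])

# Shift: prepend the carried action and drop the last step, so each
# timestep sees the action that *preceded* it.
prev_acts = np.concatenate([prev_action[:, None],
                            data_actions[:, :-1]], 1)

# One-hot encode, as jaxutils.onehot_dict would for discrete spaces;
# result has shape (B, T, num_actions).
onehot = np.eye(num_actions)[prev_acts]
```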