
Feature request: support the prioritized replay buffer with Reverb in a parallel environment

Open Axel-CH opened this issue 3 years ago • 8 comments

First of all, thank you for this awesome repo; it saved months of my life ;-)

I'm just letting you know that I'm closely following the ongoing implementation of the prioritized replay buffer with Reverb. I already tried to use it on my workload, but unfortunately I can't, because I'm using "parallel_py_environment".

Regards, Axel

Axel-CH avatar Aug 12 '20 02:08 Axel-CH

Do you mean you run into an error? If so, can you provide error logs for more context?

kuanghuei avatar Aug 14 '20 18:08 kuanghuei

Hello,

I'm using a custom "train_eval" function inspired by https://github.com/tensorflow/agents/blob/master/tf_agents/agents/dqn/examples/v2/train_eval.py

Here is the code that I use to initialize the environments:

    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(parrallel_envs_train))

And here is the code that initializes the prioritized replay buffer:

            num_steps = 1000
            sequence_length = 2 # TODO set automatically

            prioritized_table = reverb.Table(
                name='my_prioritized_experience_replay_buffer',
                sampler=reverb.selectors.Prioritized(0.8),
                remover=reverb.selectors.Fifo(),
                max_size=1000,
                rate_limiter=reverb.rate_limiters.MinSize(100),
            )
            
            rpbf_server = reverb.Server([prioritized_table])
            rpbf_py_client = reverb.Client('localhost:{}'.format(rpbf_server.port))

            replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
                tf_agent.collect_data_spec,
                "my_prioritized_experience_replay_buffer",
                local_server=rpbf_server,
                sequence_length=sequence_length)

            traj_obs = reverb_utils.ReverbAddEpisodeObserver(
                rpbf_py_client, "my_prioritized_experience_replay_buffer", max_sequence_length=sequence_length)

            initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
                    tf_env,
                    collect_policy,
                    observers=[replay_buffer.add_batch] + train_metrics,
                    num_episodes=collect_episodes_per_iteration)
                    
            collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
                    tf_env,
                    collect_policy,
                    observers=[replay_buffer.add_batch] + train_metrics,
                    num_episodes=collect_episodes_per_iteration
                )

            data_spec = tf_agent.collect_data_spec

            # Before calling client.dataset ( took that from https://github.com/tensorflow/agents/issues/410 )
            get_dtype = lambda x: tf.as_dtype(x.dtype)
            get_shape = lambda x: (sequence_length,) + x.shape
            shapes = tf.nest.map_structure(get_shape, data_spec)
            dtypes = tf.nest.map_structure(get_dtype, data_spec)


            # Dataset generates trajectories
            dataset = replay_buffer.tf_client.dataset(
                'my_prioritized_experience_replay_buffer',
                dtypes=dtypes,
                shapes=shapes)

            iterator = iter(dataset)

So, when I execute the code, I'm getting this error:

Traceback (most recent call last):
  File "prepare_best_models_hps.py", line 190, in <module>
    app.run(main)
  File "/home/cuser/.local/lib/python3.8/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/cuser/.local/lib/python3.8/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "prepare_best_models_hps.py", line 166, in main
    train_eval(populated_space, save_models, generate_plots, generate_dataframes, trial_desc)
  File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/agents/dqn/hp_search_dqn/hp_search_dqn/__init__.py", line 664, in train_eval
    time_step, policy_state = collect_driver.run(
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 787, in __call__
    result = self._call(*args, **kwds)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 830, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 702, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2948, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3319, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3171, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 987, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 613, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 191, in run
    return self._run_fn(
  File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/utils/common.py", line 185, in with_check_resource_vars
    return fn(*fn_args, **fn_kwargs)
  File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 218, in _run
    tf.while_loop(
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2489, in while_loop_v2
    return while_loop(
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 987, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
    outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 144, in loop_body
    observer_ops = [observer(traj) for observer in self._observers]
  File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 144, in <listcomp>
    observer_ops = [observer(traj) for observer in self._observers]
  File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/replay_buffers/replay_buffer.py", line 83, in add_batch
    return self._add_batch(items)
  File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/replay_buffers/reverb_replay_buffer.py", line 173, in _add_batch
    raise NotImplementedError(
NotImplementedError: ReverbReplayBuffer does not support `add_batch`. See `reverb_utils.ReverbObserver` for more information on how to add data to the buffer.
[reverb/cc/platform/default/server.cc:64] Shutting down replay server

Obviously the issue comes from the way I initialize the observer with "add_batch", so I tried another approach with this combination:

            traj_obs = reverb_utils.ReverbAddEpisodeObserver(
                rpbf_py_client, "my_prioritized_experience_replay_buffer", max_sequence_length=sequence_length)

            initial_collect_driver =  dynamic_episode_driver.DynamicEpisodeDriver(
                tf_env,
                collect_policy,
                observers=traj_obs,
                num_episodes=3)

And now I'm getting this error:

Traceback (most recent call last):
  File "prepare_best_models_hps.py", line 190, in <module>
    app.run(main)
  File "/home/cuser/.local/lib/python3.8/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/cuser/.local/lib/python3.8/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "prepare_best_models_hps.py", line 166, in main
    train_eval(populated_space, save_models, generate_plots, generate_dataframes, trial_desc)
  File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/agents/dqn/hp_search_dqn/hp_search_dqn/__init__.py", line 658, in train_eval
    time_step, policy_state = collect_driver.run(
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 787, in __call__
    result = self._call(*args, **kwds)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 830, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 702, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2948, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3319, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3171, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 987, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 613, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 191, in run
    return self._run_fn(
  File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/utils/common.py", line 185, in with_check_resource_vars
    return fn(*fn_args, **fn_kwargs)
  File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 218, in _run
    tf.while_loop(
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2489, in while_loop_v2
    return while_loop(
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 987, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
    outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 144, in loop_body
    observer_ops = [observer(traj) for observer in self._observers]
TypeError: 'ReverbAddEpisodeObserver' object is not iterable
[reverb/cc/platform/default/server.cc:64] Shutting down replay server

Regards,

Axel-CH avatar Aug 17 '20 07:08 Axel-CH

Just saw this. Taking a look.

ebrevdo avatar Aug 27 '20 04:08 ebrevdo

You're close; instead of passing traj_obs to observers, pass the list [traj_obs]. Let me know if that works.
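
In other words, something like this (a minimal sketch reusing the variable names from your snippet; appending train_metrics just mirrors the earlier add_batch version):

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[traj_obs] + train_metrics,  # a list of observers, not a bare observer
        num_episodes=collect_episodes_per_iteration)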

ebrevdo avatar Aug 27 '20 04:08 ebrevdo

OK, now I'm getting a new error:

[reverb/pybind.cc:416] Tensor can't be extracted from the source represented as ndarray: Invalid argument: Provided input could not be interpreted as an ndarray
Traceback (most recent call last):
  File "prepare_best_models_hps.py", line 191, in <module>
    app.run(main)
  File "/home/userone/.local/lib/python3.8/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/userone/.local/lib/python3.8/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "prepare_best_models_hps.py", line 166, in main
    train_eval(populated_space, save_models, generate_plots, generate_dataframes, trial_desc)
  File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/agents/dqn/hp_search_dqn/hp_search_dqn/__init__.py", line 664, in train_eval
    time_step, policy_state = collect_driver.run(
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 787, in __call__
    result = self._call(*args, **kwds)
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 830, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 702, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2948, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3319, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3171, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 987, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 613, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 191, in run
    return self._run_fn(
  File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/utils/common.py", line 185, in with_check_resource_vars
    return fn(*fn_args, **fn_kwargs)
  File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 218, in _run
    tf.while_loop(
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2489, in while_loop_v2
    return while_loop(
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 987, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
    outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 144, in loop_body
    observer_ops = [observer(traj) for observer in self._observers]
  File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 144, in <listcomp>
    observer_ops = [observer(traj) for observer in self._observers]
  File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/replay_buffers/reverb_utils.py", line 143, in __call__
    self._writer.append(trajectory)
  File "/home/userone/anaconda3/envs/ag3/lib/python3.8/site-packages/reverb/client.py", line 150, in append
    self._writer.Append(tree.flatten(data))
TypeError: Append(): incompatible function arguments. The following argument types are supported:
    1. (self: reverb.libpybind.Writer, arg0: List[tensorflow::Tensor]) -> Status

Invoked with: <reverb.libpybind.Writer object at 0x7f27b25166b0>, [<tf.Tensor 'driver_loop/Placeholder_1:0' shape=(7,) dtype=int32>, <tf.Tensor 'driver_loop/Placeholder_4:0' shape=(7, 16, 6) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_5:0' shape=(7, 16, 6) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_6:0' shape=(7, 16, 22) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_7:0' shape=(7, 16, 22) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_8:0' shape=(7, 16, 22) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_9:0' shape=(7, 16, 22) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_10:0' shape=(7, 16, 22) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_11:0' shape=(7, 16, 22) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_12:0' shape=(7, 16, 5) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_13:0' shape=(7, 16, 20) dtype=float32>, <tf.Tensor 'driver_loop/clip_by_value_2:0' shape=(7,) dtype=int64>, <tf.Tensor 'driver_loop/add_1:0' shape=(7,) dtype=float32>, <tf.Tensor 'driver_loop/step/step_type:0' shape=(7,) dtype=int32>, <tf.Tensor 'driver_loop/step/reward:0' shape=(7,) dtype=float32>, <tf.Tensor 'driver_loop/step/discount:0' shape=(7,) dtype=float32>]

Note that I have a dictionary of preprocessing layers built from multiple tf.keras.models.Sequential([list]) instances, plus a tf.keras.models.Sequential([list]) as the preprocessing combiner. Apparently this is not a best practice, according to this warning:

WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'list'>

I don't know yet if the issue with Reverb is related to that deprecated usage of Sequential models.

Axel-CH avatar Aug 27 '20 08:08 Axel-CH

I will rewrite my model using the Functional API and then retry with Reverb. I'll keep you updated.
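
For illustration, here is a minimal sketch of what a Functional-API rewrite of one preprocessing branch could look like; the input shape and layers below are placeholders, not my actual model:

    import tensorflow as tf

    # Hypothetical example: the shape and layers are placeholders for one
    # entry of list_of_layers, just to show the Functional-API pattern.
    branch_in = tf.keras.Input(shape=(16, 6))
    x = tf.keras.layers.Flatten()(branch_in)
    x = tf.keras.layers.Dense(32, activation='relu')(x)
    branch_model = tf.keras.Model(inputs=branch_in, outputs=x)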

Axel-CH avatar Aug 27 '20 16:08 Axel-CH

After further analysis, I don't think I can fix the "Sequential" warning at my model level; it seems to be caused at a higher level, inside the agent. Just to be clearer, here is an update on my issue:

When I execute the train_eval function using the prioritized replay buffer, I'm getting this warning/error combo:

WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a
<class 'list'> input: [<tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_3/flatten/Reshape:0' shape=(5, 6)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_2/flatten_1/Reshape:0' shape=(5, 6)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_9/flatten_3/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_8/flatten_2/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_7/flatten_7/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_6/flatten_6/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_5/flatten_4/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_4/flatten_5/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential/flatten_8/Reshape:0' shape=(5, 5)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_1/flatten_9/Reshape:0' shape=(5, 20)
dtype=float32>]
Consider rewriting this model with the Functional API.

W0831 11:13:24.864628 140579592951616 sequential.py:362] Layers in a Sequential model should only have a single input tensor, but we receive a
<class 'list'> input: [<tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_3/flatten/Reshape:0' shape=(5, 6)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_2/flatten_1/Reshape:0' shape=(5, 6)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_9/flatten_3/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_8/flatten_2/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_7/flatten_7/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_6/flatten_6/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_5/flatten_4/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_4/flatten_5/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential/flatten_8/Reshape:0' shape=(5, 5)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_1/flatten_9/Reshape:0' shape=(5, 20)
dtype=float32>] 
Consider rewriting this model with the Functional API.
[reverb/pybind.cc:416] Tensor can't be extracted from the source represented as ndarray: Invalid argument: Provided input could not be interpreted as an ndarray

(Note: the warning also appears without the prioritized replay buffer, but in that case everything works fine and no error is raised.)

I think this error is raised because I'm using a dictionary of preprocessing layers with multiple inputs, and Reverb can't extract the tensors from it.

Below is the structure of my preprocessing layers:

            preprocessing_layers = {
                'a': tf.keras.models.Sequential(list_of_layers['a']),
                'b': tf.keras.models.Sequential(list_of_layers['b']),
                'c': tf.keras.models.Sequential(list_of_layers['c']),
                'd': tf.keras.models.Sequential(list_of_layers['d']),
                'e': tf.keras.models.Sequential(list_of_layers['e']),
                'f': tf.keras.models.Sequential(list_of_layers['f']),
                'g': tf.keras.models.Sequential(list_of_layers['g']),
                'h': tf.keras.models.Sequential(list_of_layers['h']),
                'i': tf.keras.models.Sequential(list_of_layers['i']),
                'j': tf.keras.models.Sequential(list_of_layers['j']),
            }

Observation space: Dict(a:Box(16, 5), b:Box(16, 20), c:Box(16, 6), d:Box(16, 6), e:Box(16, 22), f:Box(16, 22), g:Box(16, 22), h:Box(16, 22), i:Box(16, 22), j:Box(16, 22))

Observation shape: {'a': (16, 5), 'b': (16, 20), 'c': (16, 6), 'd': (16, 6), 'e': (16, 22), 'f': (16, 22), 'g': (16, 22), 'h': (16, 22), 'i': (16, 22), 'j': (16, 22)}
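
For context, this is roughly how the dict of preprocessing layers is wired into the network; the Concatenate combiner and lstm_size shown here are placeholders (my actual combiner is a Keras Sequential):

    import tensorflow as tf
    from tf_agents.networks import q_rnn_network

    # Rough wiring sketch only; combiner and lstm_size are placeholders.
    q_net = q_rnn_network.QRnnNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(axis=-1),
        lstm_size=(40,))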

Let me know what you think, @ebrevdo.

Axel-CH avatar Aug 31 '20 15:08 Axel-CH

Have you tried using tf_agents.networks.Sequential instead of Keras Sequential? Try that and report back.
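
A sketch of what that swap could look like, assuming list_of_layers is a dict of Keras layer lists keyed by observation name, as in the snippet above:

    from tf_agents.networks import sequential

    # Replace tf.keras.models.Sequential with the TF-Agents Sequential network
    # for each preprocessing branch.
    preprocessing_layers = {
        key: sequential.Sequential(layers)
        for key, layers in list_of_layers.items()
    }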

ebrevdo avatar Apr 17 '21 20:04 ebrevdo