
Add ability to pass multiple inputs to a single preprocessing layer in EncodingNetwork

Open boomanaiden154 opened this issue 1 year ago • 5 comments

Currently, there is no way to pass multiple input tensors to an individual preprocessing layer. This isn't a large problem, but it can be quite useful for some niche use cases, especially when writing custom models/layers: otherwise the data has to be piped in as a single tensor and then split back into multiple tensors inside the custom model/layer, which can be quite inefficient when the two (or more) tensors involved have very different shapes.

This patch builds on existing functionality in the EncodingNetwork (with some modifications, since many of the tf.nest.* functions don't support tuples as keys) so that using a tuple as a key in the preprocessing nest dictionary grabs all the inputs named in the tuple and passes them along to the specified layer/model.
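For illustration, a rough sketch of what the proposed usage could look like (the key names and layers below are placeholders, and the exact API shape may differ from the patch):

import tensorflow as tf

# Hypothetical sketch of the proposed behavior (not existing TF-Agents API):
# a tuple key in the preprocessing nest selects several observation entries
# and routes them, together, to a single preprocessing layer/model.
preprocessing_layers = {
    ('lidar', 'camera'): tf.keras.layers.Concatenate(),  # both entries to one layer
    'state': tf.keras.layers.Dense(32),                   # ordinary single entry
}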

If this patch breaks any existing input combination that I'm not aware of, let me know and I should be able to add a test for it/modify my code to make it work. If this patch isn't desired for some reason (e.g. maintenance, performance, something else), I can definitely work around it (it's very easy to subclass EncodingNetwork); I just thought that if I was going to put the effort into this, someone else might find it useful.

boomanaiden154 avatar Jul 26 '22 06:07 boomanaiden154

Instead of modifying the encoding network, we recommend using NestMap to create networks with multiple inputs. See example here and here

sguada avatar Jul 28 '22 13:07 sguada

If I'm understanding everything correctly, it still doesn't look like it is possible to pass multiple inputs from an observation into a single preprocessing layer using NestMap. I'm currently able to pass multiple inputs into EncodingNetwork; I just can't pass multiple inputs into a single preprocessing layer, since each input tensor gets passed to a different preprocessing layer with the current behavior.

NestMap, in its call function, eventually does this:

nested_layers_state = tf.nest.map_structure(
          lambda _: (), self._nested_layers)

That is pretty similar to what EncodingNetwork is doing currently. Is there something I'm missing here?
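For context, the one-to-one mapping described above can be sketched like this (a simplified illustration, not the actual EncodingNetwork/NestMap code):

import tensorflow as tf

# The layer nest and the observation nest must share the same structure, so
# map_structure hands exactly one observation entry to each layer; this
# pattern alone cannot route two entries into the same layer.
layers = {'a': tf.keras.layers.Dense(4), 'b': tf.keras.layers.Dense(4)}
obs = {'a': tf.ones((1, 3)), 'b': tf.ones((1, 5))}
outputs = tf.nest.map_structure(lambda layer, x: layer(x), layers, obs)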

boomanaiden154 avatar Jul 29 '22 01:07 boomanaiden154

Thinking about this a bit further: if it would be better to avoid modifying EncodingNetwork, another possible solution would be to create something similar to NestMap, called something like NestMapMultiInputLayers (it would probably need a better name than that), which would allow for the extra functionality I need. Would this approach work better, and would PRs implementing it be desired/accepted?
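For illustration only, a rough sketch of what such a wrapper could look like (MultiInputMap, key_groups_to_layers, and the tuple-key grouping convention are all hypothetical, not existing TF-Agents API):

import tensorflow as tf

class MultiInputMap(tf.keras.layers.Layer):
  """Routes groups of observation entries to a single layer per group."""

  def __init__(self, key_groups_to_layers):
    super().__init__()
    # e.g. {('obs_a', 'obs_b'): fusion_layer, ('obs_c',): other_layer}
    self._key_groups_to_layers = key_groups_to_layers

  def call(self, observations):
    outputs = {}
    for keys, layer in self._key_groups_to_layers.items():
      # Each layer receives the list of tensors named by its tuple of keys.
      outputs[keys] = layer([observations[k] for k in keys])
    return outputs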

boomanaiden154 avatar Jul 31 '22 06:07 boomanaiden154

@sguada Any updates on whether or not I'm missing something in terms of how NestMap works and if you'd be willing to accept a PR in the area?

boomanaiden154 avatar Aug 22 '22 00:08 boomanaiden154

If you want to pass multiple inputs to the same Network (assuming the Network knows how to handle multiple inputs), what you need to do is nest the inputs appropriately.

Ex:

Inputs: {"inp1": (A, B), "inp2": C}
Networks: {"inp1": net1, "inp2": net2}

import tensorflow as tf
from tf_agents.networks import nest_map, sequential

net = sequential.Sequential([
    nest_map.NestMap({
        # 'inp1' receives the tuple of tensors; Concatenate merges them first.
        'inp1': sequential.Sequential(
            [tf.keras.layers.Concatenate(), tf.keras.layers.Dense(8)]),
        'inp2': sequential.Sequential([tf.keras.layers.Dense(16)]),
    }),
])
inputs = {
    'inp1': (tf.ones((8, 10, 3), dtype=tf.float32),
             tf.zeros((8, 10, 3), dtype=tf.float32)),
    'inp2': tf.ones((8, 10, 8), dtype=tf.float32),
}
output, next_state = net(inputs)

The code that actually computes the outputs uses map_structure_up_to, which does the mapping at the correct level of grouping.

 outputs_and_next_state = nest_utils.map_structure_up_to(
        self._nested_layers, _mapper,
        inputs, self._nested_layers, nested_layers_state)
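As a rough illustration of that up-to mapping (assuming nest_utils here is tf_agents.utils.nest_utils and follows the usual shallow-structure semantics), the tuple under 'inp1' is handed to its layer whole instead of being descended into:

from tf_agents.utils import nest_utils

shallow = {'inp1': 'net1', 'inp2': 'net2'}   # stand-in for self._nested_layers
inputs = {'inp1': ('A', 'B'), 'inp2': 'C'}
result = nest_utils.map_structure_up_to(
    shallow, lambda x: ('mapped', x), inputs)
# result == {'inp1': ('mapped', ('A', 'B')), 'inp2': ('mapped', 'C')}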

sguada avatar Aug 22 '22 12:08 sguada