[Feature Request] Information about RNNs' expected inputs and outputs is difficult to access when they are part of larger architectures
## Motivation
When RNNs are used in isolation, creating a `TensorDictPrimer` transform for the environment to populate the TensorDicts with the expected tensors is pretty straightforward:
```python
from torchrl.modules import GRUModule

gru_module = GRUModule(
    input_size=10,
    hidden_size=10,
    num_layers=1,
    in_keys=["input", "recurrent_state", "is_init"],
    out_keys=["output", ("next", "recurrent_state")],
)
transform = gru_module.make_tensordict_primer()
```
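Conceptually, the primer only declares which extra entries the environment should zero-fill at reset time, so that the first GRU call finds its recurrent state in the TensorDict. The idea can be sketched without torchrl (`make_primer_spec` and `primed_reset` below are illustrative stand-ins, not torchrl APIs):

```python
def make_primer_spec(num_layers, hidden_size):
    # The zero-filled entries (and their shapes) the GRU expects at reset time.
    return {"recurrent_state": (num_layers, hidden_size)}

def primed_reset(observation, primer_spec):
    """Mimic an environment reset that zero-fills the primed entries."""
    td = {"observation": observation}
    for key, (rows, cols) in primer_spec.items():
        td[key] = [[0.0] * cols for _ in range(rows)]
    return td

td = primed_reset([0.1, 0.2], make_primer_spec(num_layers=1, hidden_size=10))
```

This is exactly the information `make_tensordict_primer()` packages up for you; the feature request below is about exposing it when the RNN is buried inside a larger model.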
However, when RNNs are part of a larger architecture, this can become tricky, e.g.:
```python
from torchrl.modules import GRUModule, MLP
from tensordict.nn import TensorDictModule, TensorDictSequential

gru_module = GRUModule(
    input_size=10,
    hidden_size=10,
    num_layers=1,
    in_keys=["input", "recurrent_state", "is_init"],
    out_keys=["features", ("next", "recurrent_state")],
)
head = TensorDictModule(
    MLP(
        in_features=10,
        out_features=10,
        num_cells=[],
    ),
    in_keys=["features"],
    out_keys=["output"],
)
model = TensorDictSequential(gru_module, head)
```
If you know the architecture, it is still possible to do:

```python
transform = model[0].make_tensordict_primer()
```
But this is not ideal. Moreover, beyond creating the transform automatically, a user may want to inspect the required shapes and other information about the model's inputs, which now include both the RNN's inputs and the wrapped modules' own inputs.
## Solution
A solution would be to make it possible to access all the information about a model's expected inputs and outputs through some model specs.
Defining specs at model-creation time should probably remain optional, but being able to attach input specs would make it much easier to create the primer transform in these cases.
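One possible shape for such specs (a purely hypothetical API, for illustration only — none of these names exist in torchrl):

```python
# Hypothetical model-level input specs: each entry records the shape and
# dtype of an expected input, and flags the ones an environment primer
# would need to pre-populate.
input_specs = {
    "input": {"shape": (10,), "dtype": "float32"},
    "recurrent_state": {"shape": (1, 10), "dtype": "float32", "primer": True},
    "is_init": {"shape": (1,), "dtype": "bool", "primer": True},
}

# A primer-building helper could then filter the entries flagged as primers.
primer_keys = [k for k, v in input_specs.items() if v.get("primer")]
```

With something like this attached to the model, a user could both build the primer transform and answer "what shapes does this model expect?" without knowing where the RNN sits.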
## Checklist
- [x] I have checked that there is no similar issue in the repo (required)
The way I solved it in my case was to create a custom spec for every model I have and simply assign it to the model (`model.rnn_spec = spec`). This way I can access the info in other parts of the code.
What about:

```python
from torchrl.envs.transforms import Compose


def get_primers_from_model(model):
    primers = []

    def make_primers(module):
        # Collect a primer from every submodule that can build one.
        if hasattr(module, "make_tensordict_primer"):
            primers.append(module.make_tensordict_primer())

    model.apply(make_primers)
    if not primers:
        raise RuntimeError("No submodule with a `make_tensordict_primer` method was found.")
    elif len(primers) == 1:
        return primers[0]
    else:
        return Compose(*primers)
```
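The key ingredient above is that `nn.Module.apply` visits every submodule, so nested RNNs are found no matter how deeply they sit in the architecture. That traversal pattern can be illustrated without torchrl (`Node`, `RNNNode`, and `collect_primers` below are stand-ins for illustration, not real APIs):

```python
class Node:
    """Tiny stand-in for nn.Module: holds children and supports .apply()."""
    def __init__(self, *children):
        self.children = list(children)

    def apply(self, fn):
        # Like nn.Module.apply: recurse into children, then visit self.
        for child in self.children:
            child.apply(fn)
        fn(self)
        return self


class RNNNode(Node):
    """Stand-in for a submodule that can build its own primer."""
    def make_tensordict_primer(self):
        return ("primer", id(self))


def collect_primers(model):
    primers = []

    def visit(module):
        # Duck-typed check, same as in the proposed helper above.
        if hasattr(module, "make_tensordict_primer"):
            primers.append(module.make_tensordict_primer())

    model.apply(visit)
    return primers


# Two RNNs at different depths are both discovered.
model = Node(RNNNode(), Node(RNNNode()))
primers = collect_primers(model)
```

A duck-typed `hasattr` check keeps the helper agnostic to which RNN class is used, as long as it exposes `make_tensordict_primer`.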