Multi-Agent implementation only checks the first agent when computing the finished-episode info
Discussed in https://github.com/Toni-SM/skrl/discussions/284
Originally posted by William-SKC February 27, 2025
Hi, I have a question about the expected behavior of this line in the record_transition method of the MultiAgent class:
finished_episodes = (next(iter(terminated.values())) + next(iter(truncated.values()))).nonzero(as_tuple=False)
From my understanding, terminated and truncated are dictionaries where the keys represent agent IDs and the values are tensors indicating whether an agent's episode has ended.
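Below is a minimal sketch of what the quoted expression computes. It assumes two hypothetical agents ("agent_0", "agent_1") in four parallel environments, with done flags stored as boolean tensors of shape (num_envs, 1). Python dicts preserve insertion order, so next(iter(terminated.values())) always returns the first agent's tensor:

```python
import torch

# Hypothetical data: 4 parallel environments, 2 agents.
# Each value is a bool tensor of shape (num_envs, 1) (assumed layout).
terminated = {
    "agent_0": torch.tensor([[False], [False], [True], [False]]),
    "agent_1": torch.tensor([[True], [False], [False], [False]]),
}
truncated = {
    "agent_0": torch.tensor([[False], [False], [False], [False]]),
    "agent_1": torch.tensor([[False], [True], [False], [False]]),
}

# The expression from record_transition: next(iter(d.values())) yields only
# the value of the first key in insertion order, i.e. "agent_0".
finished_episodes = (next(iter(terminated.values())) + next(iter(truncated.values()))).nonzero(as_tuple=False)

# nonzero(as_tuple=False) returns one [env_index, column] row per nonzero entry.
print(finished_episodes)
```

In this example, environments 0 and 1 end for agent_1, but only environment 2 (reported by agent_0) appears in finished_episodes.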
My questions:
1. Does this line only check the first agent (since it uses next(iter(...))) rather than all agents?
2. If so, shouldn't we stack all values, e.g. with torch.stack(list(terminated.values())), so that every agent is checked?
3. What is the expected output of finished_episodes? Should it track all finished agents or just the first one?
I appreciate any clarification on this! Thanks for your help. 🙏
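A sketch of the torch.stack approach raised in the questions, using the same hypothetical data as before. The right reduction depends on the intended semantics, so three variants are shown; none of them is claimed to be skrl's actual behavior or fix. The flags are combined with `|` (logical OR):

```python
import torch

# Same hypothetical data: 4 environments, 2 agents, bool tensors of shape (num_envs, 1).
terminated = {
    "agent_0": torch.tensor([[False], [False], [True], [False]]),
    "agent_1": torch.tensor([[True], [False], [False], [False]]),
}
truncated = {
    "agent_0": torch.tensor([[False], [False], [False], [False]]),
    "agent_1": torch.tensor([[False], [True], [False], [False]]),
}

# Stack per-agent flags into shape (num_agents, num_envs, 1).
done = torch.stack(list(terminated.values())) | torch.stack(list(truncated.values()))

# Variant 1: an environment's episode counts as finished if ANY agent is done.
finished_any = done.any(dim=0).nonzero(as_tuple=False)

# Variant 2: it counts as finished only if ALL agents are done.
finished_all = done.all(dim=0).nonzero(as_tuple=False)

# Variant 3: track finished episodes separately for each agent.
finished_per_agent = {
    uid: (terminated[uid] | truncated[uid]).nonzero(as_tuple=False) for uid in terminated
}

print(finished_any)
print(finished_all)
print(finished_per_agent)
```

torch.stack requires every agent's tensor to have the same shape, which holds when all agents share the same vectorized environments. If every agent always reports identical done flags for a given environment, all three variants agree with checking only the first agent; if the flags can differ, the first-agent check silently drops episodes.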