Documentation/Implementation mismatch in sum_independent_dims function
📚 Documentation
I've noticed a potential mismatch between the implementation and the documentation of the sum_independent_dims function in stable_baselines3.common.distributions.
According to its docstring, the function accepts tensors of shape (n_batch, n_actions) or (n_batch,) and returns a tensor of shape (n_batch,) in both cases. However, when given a 1D tensor, the implementation sums all elements into a scalar (shape ()), which contradicts the documented output shape.
```python
import torch as th


def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
    """
    Continuous actions are usually considered to be independent,
    so we can sum components of the ``log_prob`` or the entropy.
    :param tensor: shape: (n_batch, n_actions) or (n_batch,)
    :return: shape: (n_batch,)
    """
    if len(tensor.shape) > 1:
        tensor = tensor.sum(dim=1)
    else:
        tensor = tensor.sum()
    return tensor
```
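The docstring's justification for summing rests on the fact that the log-density of independent components is the sum of their per-component log-densities. The stdlib-only sketch below checks this numerically, assuming a diagonal Gaussian action distribution purely for illustration (the function itself does not care which distribution produced the values). The helper names `normal_log_prob` and `diag_gaussian_log_prob` are hypothetical.

```python
import math


def normal_log_prob(x: float, mean: float, std: float) -> float:
    # Log-density of a 1D Gaussian N(mean, std^2) evaluated at x.
    return -((x - mean) ** 2) / (2 * std**2) - math.log(std) - 0.5 * math.log(2 * math.pi)


def diag_gaussian_log_prob(x, mean, std) -> float:
    # Joint log-density of a d-dimensional Gaussian with diagonal covariance,
    # computed from the closed-form quadratic term and log-determinant.
    d = len(x)
    quad = sum(((xi - mi) / si) ** 2 for xi, mi, si in zip(x, mean, std))
    log_det = sum(2 * math.log(si) for si in std)
    return -0.5 * (quad + log_det + d * math.log(2 * math.pi))


action = [0.3, -1.2, 2.0]
mean = [0.0, -1.0, 1.5]
std = [1.0, 0.5, 2.0]

# Summing per-dimension log-probs, as sum_independent_dims does along dim=1.
per_dim_sum = sum(normal_log_prob(a, m, s) for a, m, s in zip(action, mean, std))
joint = diag_gaussian_log_prob(action, mean, std)
print(math.isclose(per_dim_sum, joint))  # True
```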
This snippet demonstrates the unexpected behavior:
```python
>>> import torch
>>> from stable_baselines3.common.distributions import sum_independent_dims
>>> a = torch.zeros((16, 2))
>>> sum_independent_dims(a).shape
torch.Size([16])  # expected
>>> b = torch.zeros((16, ))
>>> sum_independent_dims(b).shape
torch.Size([])  # docstring suggests torch.Size([16])
```
To align the function's behavior with its documentation, I recommend either changing the implementation so that a 1D tensor is returned unchanged (if the intention is to keep the batch dimension in all cases), or updating the documentation to state that a 1D tensor is summed to a scalar (if that behavior is relied upon internally). I am not sure which behavior was intended, so clarification from the developers would be greatly appreciated.
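For comparison, the first option could look roughly like the sketch below. The name `sum_independent_dims_keep_batch` is hypothetical and not part of Stable-Baselines3; it assumes a 1D input already holds one value per batch element and therefore returns it as-is.

```python
import torch as th


def sum_independent_dims_keep_batch(tensor: th.Tensor) -> th.Tensor:
    """
    Hypothetical variant: sum over the action dimension when present,
    otherwise return the (n_batch,) tensor unchanged.
    """
    if tensor.dim() > 1:
        return tensor.sum(dim=1)
    # 1D input: treated as already being one value per batch element.
    return tensor


a = th.zeros((16, 2))
b = th.zeros((16,))
print(sum_independent_dims_keep_batch(a).shape)  # torch.Size([16])
print(sum_independent_dims_keep_batch(b).shape)  # torch.Size([16])
```

Switching to this variant would change the result for any caller that currently relies on the scalar output for 1D input.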
Checklist
- [X] I have checked that there is no similar issue in the repo
- [X] I have read the documentation
Hello, you are right: the doc is wrong. However, I would keep the behavior as-is to avoid a breaking change. A PR that updates the doc would be welcome =)
I quickly looked into the contributing guidelines, but unfortunately I don't have time to go through all the steps right now. Since only two lines of the documentation would change, I'll just leave my proposed docstring here. If anyone wants to incorporate it into a PR, feel free :)
```python
import torch as th


def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
    """
    Continuous actions are usually considered to be independent,
    so we can sum components of the ``log_prob`` or the entropy.
    # TODO: Potential issue with the implementation for the 1D input case.
    :param tensor: shape: (n_batch, n_actions) or (n_batch,)
    :return: shape: (n_batch,) for (n_batch, n_actions) input, scalar for (n_batch,) input
    """
    if len(tensor.shape) > 1:
        tensor = tensor.sum(dim=1)
    else:
        tensor = tensor.sum()
    return tensor
```
Hello, I went ahead and took the opportunity to put myself through the first PR exercise :)
Awesome, thank you! :)