jaxrl
Assert always evaluates to True
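In these lines, each assert statement is split across multiple lines by wrapping the condition and the message together in parentheses. As a result, assert evaluates the resulting two-element tuple instead of the condition. A non-empty tuple is always truthy, so both assertions always pass and never check anything.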
https://github.com/ikostrikov/jaxrl/blob/36748627147c5914b3a0ee896c4313aadd062018/jaxrl/networks/common.py#L80-L81
https://github.com/ikostrikov/jaxrl/blob/36748627147c5914b3a0ee896c4313aadd062018/jaxrl/networks/common.py#L89-L90
This can be fixed by dropping the outer parentheses and using a backslash to continue each statement onto the next line:
assert loss_fn is not None or grads is not None, \
'Either a loss function or grads must be specified.'
assert has_aux, \
'When grads are provided, expects no aux outputs.'
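To make the difference concrete, here is a small self-contained sketch (not code from jaxrl). `check` is a hypothetical stand-in for the first assertion in `apply_gradient`. Both versions are compiled from strings with `compile()` so the script can also capture the `SyntaxWarning` that CPython 3.8+ emits for the parenthesized form ("assertion is always true, perhaps remove parentheses?"). The demo assumes Python is not run with `-O`, which strips asserts entirely.

```python
import warnings

# Original pattern: the parentheses turn condition + message into one tuple.
BUGGY_SRC = r"""
def check(loss_fn=None, grads=None):
    assert ((loss_fn is not None or grads is not None),
            'Either a loss function or grads must be specified.')
    return 'passed'
"""

# Proposed fix: condition and message are separate, joined by a backslash.
FIXED_SRC = r"""
def check(loss_fn=None, grads=None):
    assert loss_fn is not None or grads is not None, \
        'Either a loss function or grads must be specified.'
    return 'passed'
"""


def load(src):
    """Compile `src` and return its `check` function plus any warning messages."""
    namespace = {}
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        exec(compile(src, "<demo>", "exec"), namespace)
    return namespace["check"], [str(w.message) for w in caught]


def raises_assertion(fn):
    """Return True if calling fn() with no arguments raises AssertionError."""
    try:
        fn()
    except AssertionError:
        return True
    return False


buggy_check, buggy_warnings = load(BUGGY_SRC)
fixed_check, fixed_warnings = load(FIXED_SRC)

# Neither loss_fn nor grads is given, yet the buggy version does not complain.
print("buggy:", buggy_check(), buggy_warnings)
# The fixed version rejects the same call.
print("fixed raises AssertionError:", raises_assertion(fixed_check))
```

Running a linter or simply paying attention to the `SyntaxWarning` at import time is usually enough to catch this pattern. If the parentheses are wanted for line wrapping, putting them around the condition alone (`assert (a or b), 'msg'`) also works.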