returnn
RF cross_entropy (matmul, gather) should maybe have allow_broadcast?
I had this bug:
log_prob = ... # [B,T+1,D]
targets = ... # [B,T] -> D
loss = rf.cross_entropy(target=targets, estimated=log_prob, ...)
loss.mark_as_loss(...)
What you get here is no error. It just works. I only noticed in my log that training did not seem to converge, but otherwise everything looked reasonable.
The bug is: loss here is [B,T+1,T], because T != T+1. The two dims do not match, so both are kept and the result gets broadcasted over them.
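To make the shape behavior concrete, here is a minimal sketch in plain Python that only models the dims of a gather over named dims. It is not RETURNN code: the function gather_out_dims and the string dim names are made up for illustration, and the order of the output dims is only illustrative.

```python
from typing import List, Sequence


def gather_out_dims(source: Sequence[str], indices: Sequence[str], axis: str) -> List[str]:
    """Dims of a gather along `axis` over named dims (simplified model, not RETURNN code)."""
    if axis not in source:
        raise ValueError(f"axis {axis!r} not in source dims {list(source)}")
    # Source dims other than `axis` are kept.
    kept = [d for d in source if d != axis]
    # Index dims that also occur in the source are shared; all others are added.
    extra = [d for d in indices if d not in kept]
    return kept + extra


log_prob_dims = ["B", "T+1", "D"]
target_dims = ["B", "T"]  # sparse, values index into D

# No error: T and T+1 are different dims, so both end up in the result.
print(gather_out_dims(log_prob_dims, target_dims, axis="D"))  # ['B', 'T+1', 'T']
```

With matching dims (["B", "T", "D"] and ["B", "T"]), the same function gives ["B", "T"], which is what I actually wanted.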
For other functions like rf.combine, rf.compare, rf.clip_by_value, rf.where, we have the arg allow_broadcast_all_sources. Or rf.concat has allow_broadcast. The default is False.
Note, the implementation of rf.cross_entropy in the general case just uses rf.gather or rf.matmul, neither of which has such an argument. For both matmul/gather, there are also many valid cases where broadcasting would be wanted (maybe broadcasting is also the wrong term, not sure). So if we added such an argument there, maybe the default would be True.
I'm not sure if there are many valid cases for rf.cross_entropy where broadcasting would make sense. So here the default should rather be False. I think it would break almost no setups, except my bugged code above, which was unintentionally wrong anyway.
But I'm not sure.
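As a rough sketch of what such a check could look like, again in plain Python over dim names only. Everything here is hypothetical: cross_entropy_out_dims and its allow_broadcast argument are not existing RETURNN API, and the real check would have to work on Dim objects, not strings.

```python
from typing import List, Sequence


def cross_entropy_out_dims(
    estimated: Sequence[str], target: Sequence[str], axis: str, allow_broadcast: bool = False
) -> List[str]:
    """Dims of the loss for dense `estimated` and sparse `target` (hypothetical sketch)."""
    if axis not in estimated:
        raise ValueError(f"axis {axis!r} not in estimated dims {list(estimated)}")
    remaining = [d for d in estimated if d != axis]
    # With the proposed default, all dims except `axis` must match exactly.
    if not allow_broadcast and set(remaining) != set(target):
        raise ValueError(
            f"dims do not match: estimated without {axis!r} is {remaining}, "
            f"target is {list(target)}"
        )
    return remaining + [d for d in target if d not in remaining]


print(cross_entropy_out_dims(["B", "T", "D"], ["B", "T"], axis="D"))  # ['B', 'T']

try:
    cross_entropy_out_dims(["B", "T+1", "D"], ["B", "T"], axis="D")
except ValueError as exc:
    print("rejected:", exc)
```

With allow_broadcast=True, the second call would go through and return ['B', 'T+1', 'T'] as before, so setups that really want this could opt in.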