
RF cross_entropy (matmul, gather) should maybe have allow_broadcast?

Open albertz opened this issue 4 months ago • 0 comments

I had this bug:

log_prob = ...  # [B,T+1,D]
targets = ...  # [B,T] -> D
loss = rf.cross_entropy(target=targets, estimated=log_prob, ...)
loss.mark_as_loss(...)

You get no error here; it just runs. I only noticed in my log that training did not seem to converge, but otherwise everything looked reasonable.

The bug: loss here is [B,T+1,T]. Since T != T+1, the two time dims are not matched against each other, so the loss gets broadcast over both.
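To see why this fails silently, here is a minimal pure-Python sketch of name-based dim matching. It is not RETURNN code: dims are modeled as hypothetical (name, size) tuples and `cross_entropy_out_dims` is a made-up helper. The only assumption is that, as in RF, dims are matched by identity rather than by position, so every non-class dim of either input ends up in the output.

```python
from typing import List, Tuple

# Hypothetical stand-in for an RF Dim: two dims are equal only if name and size match.
Dim = Tuple[str, int]


def cross_entropy_out_dims(estimated: List[Dim], target: List[Dim], axis: Dim) -> List[Dim]:
    """Output dims of gathering `estimated` at the class indices in `target`.

    Dims are matched by identity, not by position. The class axis is consumed;
    every other dim of either input is kept, so a dim that only one input has
    is broadcast instead of raising an error.
    """
    out = [d for d in estimated if d != axis]
    for d in target:
        if d not in out:
            out.append(d)
    return out


B, T, T_plus_1, D = ("B", 8), ("T", 20), ("T+1", 21), ("D", 100)

# Intended: estimated is [B,T,D], so the loss is [B,T].
print(cross_entropy_out_dims([B, T, D], [B, T], axis=D))  # → [('B', 8), ('T', 20)]
# The bug: estimated is [B,T+1,D]; T and T+1 do not match, so the loss is [B,T+1,T].
print(cross_entropy_out_dims([B, T_plus_1, D], [B, T], axis=D))  # → [('B', 8), ('T+1', 21), ('T', 20)]
```

The sizes are arbitrary; only the dim identities matter. Nothing in this logic ever compares T with T+1, which is why no error is raised.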

Other functions such as rf.combine, rf.compare, rf.clip_by_value and rf.where have the argument allow_broadcast_all_sources, and rf.concat has allow_broadcast. In all these cases, the default is False.

Note that in the general case, rf.cross_entropy is implemented via rf.gather or rf.matmul, and neither of them has such an argument. For both matmul and gather, there are also many valid cases where broadcasting is wanted (maybe "broadcasting" is not even the right term here, not sure). So if we added such an argument to them, the default should maybe be True.
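A hedged sketch of why a strict default would be wrong for matmul and gather. It uses the same hypothetical (name, size) dim model; `contract_out_dims` and `some_input_has_all_out_dims` are made-up helpers, and the second one is only an assumed criterion in the spirit of allow_broadcast_all_sources=False (no single input already has all output dims), not RETURNN's actual check. For a plain linear layer and an embedding lookup, both ordinary uses, neither input contains all output dims, so such a check would reject them.

```python
from typing import List, Tuple

Dim = Tuple[str, int]  # hypothetical stand-in for an RF Dim: equal iff name and size match


def contract_out_dims(a: List[Dim], b: List[Dim], consumed: Dim) -> List[Dim]:
    """Output dims of a name-matched matmul or gather.

    `consumed` is the reduced dim (matmul) or the indexed axis (gather).
    All other dims of either input are kept, shared dims only once.
    """
    out = [d for d in a if d != consumed]
    out += [d for d in b if d != consumed and d not in out]
    return out


def some_input_has_all_out_dims(inputs: List[List[Dim]], out: List[Dim]) -> bool:
    """Assumed criterion: True if at least one input already has every output dim."""
    return any(set(out) <= set(x) for x in inputs)


B, T, F, G, V = ("B", 8), ("T", 20), ("F", 512), ("G", 256), ("V", 1000)

# Linear layer: x [B,T,F] matmul w [F,G], reducing F, gives [B,T,G].
x, w = [B, T, F], [F, G]
linear = contract_out_dims(x, w, consumed=F)
print(some_input_has_all_out_dims([x, w], linear))  # → False

# Embedding lookup: gather table [V,F] at indices [B,T] along V gives [F,B,T].
table, idx = [V, F], [B, T]
emb = contract_out_dims(table, idx, consumed=V)
print(some_input_has_all_out_dims([table, idx], emb))  # → False
```

Both results are False even though both operations are correct, which is the argument for defaulting such a flag to True on matmul and gather.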

I'm not sure there are many valid cases for rf.cross_entropy where broadcasting makes sense, so here the default should rather be False. I think this would break almost no setups, except buggy code like mine above, which was unintentionally wrong anyway.
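What such an argument could look like for rf.cross_entropy, as a minimal pure-Python sketch of the dim check only. The function `cross_entropy_dims`, its `allow_broadcast` keyword and the (name, size) dim model are all hypothetical; this is not the RF signature. It implements the strict variant: with the default False, every non-class dim must be present in both inputs.

```python
from typing import List, Tuple

Dim = Tuple[str, int]  # hypothetical stand-in for an RF Dim


def cross_entropy_dims(
    estimated: List[Dim], target: List[Dim], axis: Dim, *, allow_broadcast: bool = False
) -> List[Dim]:
    """Hypothetical dim check for a cross_entropy with an allow_broadcast flag.

    Returns the loss dims. Unless allow_broadcast is True, every non-class dim
    must appear in both `estimated` and `target`, otherwise ValueError is raised.
    """
    est_rest = [d for d in estimated if d != axis]
    if not allow_broadcast:
        only_est = [d for d in est_rest if d not in target]
        only_tgt = [d for d in target if d not in est_rest]
        if only_est or only_tgt:
            raise ValueError(
                f"cross_entropy: dims would be broadcast "
                f"(only in estimated: {only_est}, only in target: {only_tgt}); "
                f"pass allow_broadcast=True if this is intended"
            )
    return est_rest + [d for d in target if d not in est_rest]


B, T, T_plus_1, D = ("B", 8), ("T", 20), ("T+1", 21), ("D", 100)

print(cross_entropy_dims([B, T, D], [B, T], axis=D))  # → [('B', 8), ('T', 20)]
try:
    cross_entropy_dims([B, T_plus_1, D], [B, T], axis=D)  # the bugged case from above
except ValueError as exc:
    print(exc)
```

A looser variant could allow `estimated` to have extra dims (the target being broadcast over them) and only forbid target dims that `estimated` lacks; either variant would have caught the T vs. T+1 mismatch.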

But I'm not sure.

albertz, Oct 19 '24 23:10