
`rf.control_flow_ctx` or not?

Open · albertz opened this issue 2 years ago · 3 comments

The relative_positional_encoding implementation in RC (returnn_common) uses:

```python
from returnn_common import nn

with nn.control_flow_ctx(None):  # root context, i.e. outside of any loop/cond
    ...
```

This is relevant for graph-based backends, once we have control flow logic like nn.Cond and nn.Loop.

I wonder how we should deal with that. Should we have that? Is it simple to reason about this? In this case, we want the computation to happen only once and to be cached. So when it is called within some loop, it should execute only once, outside the loop. And when it is called again from some other place, it should reuse the result of the first call.
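A minimal sketch of these semantics, assuming a toy context stack in plain Python: `control_flow_ctx`, `current_ctx` and the cache below are hypothetical stand-ins, not the RC or RF implementation, and the sinusoidal formula is only a placeholder for the real encoding. The first call switches to the root context (None), computes and caches the result; later calls, including ones from inside a loop body, just return the cached value.

```python
import contextlib
import math

# Hypothetical stack of control flow contexts; None stands for the root context
# (outside of any loop or cond). A toy model, not the RETURNN implementation.
_ctx_stack = [None]


def current_ctx():
    return _ctx_stack[-1]


@contextlib.contextmanager
def control_flow_ctx(ctx):
    """Temporarily switch the current context (None = root), like nn.control_flow_ctx(None)."""
    _ctx_stack.append(ctx)
    try:
        yield
    finally:
        _ctx_stack.pop()


_cache = {}
computed_in_ctx = []  # records the context of every actual (non-cached) computation


def _sinusoid(pos, i, dim):
    angle = pos / 10000 ** (2 * (i // 2) / dim)
    return math.sin(angle) if i % 2 == 0 else math.cos(angle)


def relative_positional_encoding(length, dim):
    """Toy stand-in: sinusoidal encoding of relative positions -(length-1)..(length-1).

    Computed once in the root context and cached by (length, dim), so repeated
    calls, also from inside a loop body, reuse the first result.
    """
    key = (length, dim)
    if key not in _cache:
        with control_flow_ctx(None):  # move the computation out of any loop
            computed_in_ctx.append(current_ctx())
            _cache[key] = [
                [_sinusoid(pos, i, dim) for i in range(dim)]
                for pos in range(-(length - 1), length)
            ]
    return _cache[key]


# Simulate several calls from inside a loop body.
with control_flow_ctx("loop"):
    for step in range(3):
        enc = relative_positional_encoding(length=4, dim=6)

print(computed_in_ctx)  # [None]: computed once, in the root context
print(len(enc), len(enc[0]))  # 7 6
```

In a graph-based backend the same idea would mean the ops are actually created in the root graph context, so the loop body only references them.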

In general, when the user writes pure RF code, and it works fine with e.g. the PT backend, does that imply that it should also work with any other backend?

We probably can never truly guarantee this, but I think we should try, i.e. avoid undefined behavior as much as possible, even when the user does something wrong. Buggy user code often leads to undefined behavior. It would be bad if such code did not raise an error but just ran with some undefined behavior, because that behavior would then likely be different in another backend. So we should try to detect wrong usage as much as is reasonable.

But coming back to this specific example: it would not really be a bug to leave out nn.control_flow_ctx here. In an eager framework, there is not really such a concept. However, RF code uses nn.Cond/nn.Loop in this case as well, so we know the control flow context. So we can still store the control flow context in the Tensor in all cases, and check if some invalid access happened? Or is this too much overhead?

If we decide to leave out nn.control_flow_ctx, how would we then solve this for graph-based frameworks? Somehow automatically figure out that the code is independent of the outer loop? Can we easily do that?
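Automatically detecting independence from the outer loop amounts to a loop-invariance analysis, as in the classic compiler optimization of loop-invariant code motion. A toy sketch over a hypothetical `Node` graph: an op depends on the loop if it is a loop variable or any of its inputs does, and everything else could be hoisted before the loop. This assumes all dependencies are visible as inputs; ops with side effects or randomness (e.g. dropout) would need extra handling and must not be hoisted just because their inputs are invariant.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(eq=False)
class Node:
    """Toy graph node: an op name plus its input nodes."""
    name: str
    inputs: Tuple["Node", ...] = ()


def is_loop_variant(node, loop_vars, memo):
    """A node depends on the loop if it is a loop variable or any input does."""
    if id(node) not in memo:
        memo[id(node)] = any(node is v for v in loop_vars) or any(
            is_loop_variant(inp, loop_vars, memo) for inp in node.inputs
        )
    return memo[id(node)]


def hoistable(body, loop_vars):
    """Names of ops in the loop body that could be moved before the loop."""
    memo = {}
    return [n.name for n in body if not is_loop_variant(n, loop_vars, memo)]


seq_len = Node("seq_len")  # defined outside the loop
i = Node("i")  # loop counter, changes every iteration
rel_pos_enc = Node("rel_pos_enc", (seq_len,))
query = Node("query", (i,))
att = Node("att", (query, rel_pos_enc))

print(hoistable([rel_pos_enc, query, att], loop_vars=[i]))  # ['rel_pos_enc']
```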

Originally posted by @albertz in https://github.com/rwth-i6/returnn/issues/1120#issuecomment-1479227790

albertz, Mar 27 '23 12:03

Note that the API around Cond/Loop (#1282) is very related here.

I am currently thinking that the best way forward is to avoid too much magic on the RETURNN side and make it all very straightforward. This implies that we need to handle ControlFlowContext and also add control_flow_ctx. But it should be needed only rarely anyway.

albertz, Mar 29 '23 09:03

> So we can still store the control flow context in the Tensor in all cases, and check if some invalid access happened? Or is this too much overhead?

Here you were thinking about turning off all control flow logic in the Tensor implementation for eager frameworks?

I think if we really want to use the frontend to regularly switch the backend of our models, then the control flow has to be checked and correct for all backends.

> Is it simple to reason about this?

Having rf.control_flow_ctx is not really great, in the sense that the control flow context is not something that is normally exposed to the user, especially in PyTorch but also in TF. But it is probably OK if it is really used rarely. As a workaround, it is always possible to do the calculations outside the context and then pass the result into it, right? (That is of course ugly if the inside and outside calculations logically belong together, like the attention and the positional embedding here.)
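The workaround could look like this toy sketch: compute the encoding once before the loop and pass it into the per-step function. `rel_pos_encoding` and `attention_step` are hypothetical placeholders (plain Python lists, a placeholder sinusoid, a dot product instead of real attention); the point is only that the step function's interface has to grow an extra argument.

```python
import math


def rel_pos_encoding(length, dim):
    """Placeholder encoding: one sinusoidal row per relative position -(length-1)..(length-1)."""
    def value(pos, i):
        angle = pos / 10000 ** (2 * (i // 2) / dim)
        return math.sin(angle) if i % 2 == 0 else math.cos(angle)

    return [[value(pos, i) for i in range(dim)] for pos in range(-(length - 1), length)]


def attention_step(query, pos_enc):
    """Hypothetical loop-body step that receives the precomputed encoding as an
    argument instead of computing it itself (a dot product stands in for attention)."""
    return [sum(q * e for q, e in zip(query, row)) for row in pos_enc]


pos_enc = rel_pos_encoding(length=3, dim=4)  # computed once, outside the loop
query = [1.0, 0.0, 0.0, 0.0]
outputs = []
for step in range(5):  # the loop body only reuses pos_enc
    outputs.append(attention_step(query, pos_enc))

print(len(pos_enc), len(outputs))  # 5 5
```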

In general, the control flow context might be an argument for people to use pure PyTorch instead of the frontend. As long as everything is working, the only drawback is that the Cond and Loop interfaces are a bit inconvenient, and that might be OK. But it would be frustrating to debug bugs in the control flow logic when using PyTorch. And the related logic in Dim.get_for_batch_ctx etc. is quite complex. I know it is planned to rework that, but some of it will remain.

patrick-wilken, Mar 29 '23 15:03

>> So we can still store the control flow context in the Tensor in all cases, and check if some invalid access happened? Or is this too much overhead?
>
> Here you were thinking about turning off all control flow logic in the Tensor implementation for eager frameworks?

In an eager framework, rf.while_loop and rf.cond would basically just run the canonical code using while and if. Nothing more is really needed to make it work. But then, when you do an invalid access, e.g. you create a tensor inside the loop and access it from outside, there would be no error; it would also just work. In a graph-based framework, however, this is an error. So, do we want this to also raise an error in an eager framework? Otherwise you might have pure RF code that you tested with PT and that runs fine, but that does not run in TF. I'm not sure how bad this is. Maybe it is a rare situation, and then the problem is easy to identify. We could have checks for that, as I explained, but that would mean a lot of overhead from extra checks everywhere. There is already quite a bit of overhead, which is bad for performance in an eager framework.
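A minimal sketch of what eager versions could reduce to, assuming hypothetical `while_loop(cond, body, initial)` and `cond(pred, true_fn, false_fn)` signatures (not necessarily the actual rf ones). The loop body lets a value created inside the loop escape into an outer list; eager execution accepts this silently, whereas a graph-based backend would only allow it through an explicit loop output.

```python
def while_loop(cond, body, initial):
    """Hypothetical eager while_loop: just a plain Python while."""
    state = initial
    while cond(state):
        state = body(state)
    return state


def cond(pred, true_fn, false_fn):
    """Hypothetical eager cond: just a plain Python if."""
    return true_fn() if pred else false_fn()


leaked = []  # lives outside the loop


def body(state):
    i, acc = state
    tmp = acc * 2  # a value created inside the loop body
    leaked.append(tmp)  # escapes the loop: no error in eager mode
    return i + 1, acc + 1


final = while_loop(lambda s: s[0] < 3, body, (0, 1))
print(final, leaked)  # (3, 4) [2, 4, 6]
print(cond(final[1] > 3, lambda: "big", lambda: "small"))  # big
```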

Btw, regarding Dim.get_for_batch_ctx, see #975. Also, I don't really want to use any of that for RF, although the details might still need to be worked out.

albertz, Mar 29 '23 21:03