
RFC-0015-Tracer-module-scope-extension

Open BowenBao opened this issue 4 years ago • 16 comments

rendered

BowenBao avatar Jul 31 '21 01:07 BowenBao

Let's chat about this in VC some time; module scopes seem very similar to capabilities that FX provides, and I'm curious if we want to try hitting this with a more radical approach (e.g., rewriting ONNX to be an FX pass rather than a TorchScript trace based pass)

ezyang avatar Aug 16 '21 18:08 ezyang

@ezyang I've been told you have been invited to our regular meeting on 09-02, so we can discuss it then. However, this scope feature is needed in the short term, and we don't want to block it on rewriting the entire ONNX export. Could we discuss the FX topic independently while moving forward with this for now?

garymm avatar Aug 16 '21 22:08 garymm

I finally had a chance to think about this proposal. My primary skepticism is about adding new attributes to Graph to record the scope information. If you're going to maintain a mapping from string scope name to extra args/kwargs, there's no reason this couldn't just be done from Python: while tracing, you maintain this mapping yourself. You push a unique scope string and then stash in the dictionary the extra args/kwargs you want to record. An added benefit of this approach is that you don't have to deal with serializing Python arguments to IValues; you can record them verbatim (and store arbitrarily complicated stuff, like the nn.Module itself). WDYT? We can talk about this more on Tuesday too.
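The Python-side bookkeeping described above could be sketched roughly like this (`ScopeRecorder` and its method names are hypothetical illustrations, not a real PyTorch API):

```python
# A minimal sketch of the Python-side idea: during tracing, push a unique
# scope string per nn.Module call and stash the call's args/kwargs verbatim,
# with no IValue serialization involved. All names here are hypothetical.

class ScopeRecorder:
    def __init__(self):
        self._stack = []    # current scope path while tracing
        self._records = {}  # unique scope name -> (args, kwargs)
        self._counts = {}   # disambiguates repeated calls to one module

    def push(self, qualname, args, kwargs):
        # Mint a unique scope string for this particular module call.
        n = self._counts.get(qualname, 0)
        self._counts[qualname] = n + 1
        unique = f"{qualname}::{n}"
        self._stack.append(unique)
        # Store the Python objects verbatim (could even be the nn.Module).
        self._records[unique] = (args, dict(kwargs))
        return unique

    def pop(self):
        return self._stack.pop()

    def lookup(self, unique):
        return self._records[unique]
```

A tracer would call `push` on module entry and `pop` on exit; the stash can later be consulted by export passes.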

ezyang avatar Aug 23 '21 16:08 ezyang

@ezyang the Python-side idea is appealing for its simplicity. We are considering something like this on the attribute side, especially if they are scalars or Python numbers. It is not immediately clear to me, though, how to achieve the goal of identifying and preserving the mapping between Python args and IR Values with this approach.

Consider a simplified example for this purpose.

```python
class Layer(torch.nn.Module):
    def forward(self, x, y):
        return ...

class Model(torch.nn.Module):
    def __init__(self):
        ...
        self.layer = Layer()

    def forward(self, input):
        x = ...  # produced from some module
        y = ...  # produced from some module
        return self.layer(x, y)
```

IR snippet:

```
%x = ...
%y = ...
%out.1 = ONNX::layer_op(...)  # part of Layer module
%out.2 = ONNX::layer_op2(...)  # part of Layer module
%out.3 = ONNX::layer_op3(...)  # part of Layer module
return (%out.3)
```

The current IR does not explicitly show that %x and %y are the Layer module inputs, or that %out.3 is the Layer module output. Graph analysis can tell whether a Value is produced in or used by an outer scope, but the original input/output ordering cannot be recovered. For our proposed approach we do not intend to preserve the IValue itself; the IValue is passed in through push_scope/pop_scope only to look up its IR Value from the tracer stash.
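A tiny plain-Python illustration of this point, using a hypothetical node representation rather than real TorchScript IR:

```python
# Graph analysis can recover WHICH values cross the Layer scope boundary,
# but only as an unordered set -- the original call order (x, y) is lost.

# Each node: (output name, op name, list of input names)
nodes = [
    ("x", "producer_a", []),
    ("y", "producer_b", []),
    ("out.1", "layer_op", ["y", "x"]),   # part of Layer
    ("out.2", "layer_op2", ["out.1"]),   # part of Layer
    ("out.3", "layer_op3", ["out.2"]),   # part of Layer
]

layer_nodes = [n for n in nodes if n[1].startswith("layer_op")]
defined_inside = {out for out, _, _ in layer_nodes}

# Values used inside the Layer scope but produced outside it:
crossing = {i for _, _, ins in layer_nodes for i in ins} - defined_inside
```

Here `crossing` comes out as the set `{"x", "y"}`: the boundary values are identifiable, but nothing records whether the module was called as `layer(x, y)` or `layer(y, x)`.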

BowenBao avatar Aug 23 '21 21:08 BowenBao

Hmm. I'm not sure I entirely understand, but would it be helpful if you could retrieve the IR values while you were tracing (or maybe, some sort of token that you could use later to correspond to the graph)? Then you'd be able to trace Python arguments.

ezyang avatar Aug 24 '21 13:08 ezyang

We discussed this in our meeting and the key constraint from ONNX's side is that they want to be able to access the module scope information from C++. I'm not too sure what to do about this, TBH. It will be helpful to know what kinds of optimizations would actually make use of this information.

ezyang avatar Aug 24 '21 19:08 ezyang

@ezyang the main goal is to be able to preserve the input/output order when grouping the nodes belonging to a certain nn.Module into a custom ONNX function.

The current ONNX pass is already able to group nodes by looking at scope information. Taking the above example:

```
%out.1 = ONNX::layer_op(...)  # scope: Model.layer
%out.2 = ONNX::layer_op2(...)  # scope: Model.layer
%out.3 = ONNX::layer_op3(...)  # scope: Model.layer
```

can be converted to

```
%out = custom_domain::Layer(%x, %y)  # custom op representing class Layer
```

However, the original input order of the nn.Module call is not tracked and cannot be recovered.

BowenBao avatar Aug 27 '21 20:08 BowenBao

Hmm... what if we add some __torch_function__ style functionality so that you can interpose on the nn.Module call at the time you are tracing, and immediately do the custom_domain replacement, so that is the only thing that ever shows up in the trace?

ezyang avatar Aug 30 '21 03:08 ezyang

what if we add some __torch_function__ style functionality so that you can interpose on the nn.Module call at the time you are tracing

Module hooks? It feels like a global module hook (one that fires for every Module that runs) is what you're looking for in this case.
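A sketch of what the global-hook approach could look like, assuming `torch.nn.modules.module.register_module_forward_pre_hook` (available in recent PyTorch releases); the model classes here are just the toy example from earlier in the thread:

```python
import torch

calls = []

def global_pre_hook(module, inputs):
    # Fires before every Module's forward, with the original positional
    # inputs in call order -- exactly the ordering the ONNX pass needs.
    calls.append((type(module).__name__, inputs))

# Register once, globally, for all Module instances.
handle = torch.nn.modules.module.register_module_forward_pre_hook(global_pre_hook)

class Layer(torch.nn.Module):
    def forward(self, x, y):
        return x + y

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = Layer()

    def forward(self, inp):
        return self.layer(inp, inp * 2)

out = Model()(torch.ones(2))
handle.remove()  # don't leak the hook into unrelated code
```

After the call, `calls` records the outer `Model` invocation followed by the nested `Layer` invocation, each with its positional inputs; note that, as discussed below, kwargs are not delivered to this hook.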

albanD avatar Aug 30 '21 17:08 albanD

Hmm... what if we add some __torch_function__ style functionality so that you can interpose on the nn.Module call at the time you are tracing, and immediately do the custom_domain replacement, so that is the only thing that ever shows up in the trace?

We'd also like to still trace the inner subgraph of the modules, since that will appear as the ONNX local function subgraph definition.

BowenBao avatar Aug 31 '21 21:08 BowenBao

I guess what is giving me a lot of trouble here (and why the responses take a while) is that I don't really want to add a new fundamental, cross-cutting change to TorchScript IR with only one motivating use case for it. I'd rather take a much more circuitous route; for example, the original proposal of doing things in Python: what if you did the same thing, but transmitted the arguments all the way to C++ so they were accessible from your TorchScript passes? It's still hacky, but it feels a lot better.

ezyang avatar Sep 08 '21 17:09 ezyang

@ezyang That's a valid point. I guess we can explore not changing TorchScript IR and try to manage things in Python. Instead of the push_scope/pop_scope proposal, we could possibly depend on the _get_value_trace API from TracingState to retrieve the IValue <-> Value mapping.

For the global module hook idea above, do you have some pointers or example snippets that we could give a try?

BowenBao avatar Sep 08 '21 23:09 BowenBao

^- @albanD

ezyang avatar Sep 13 '21 20:09 ezyang

@albanD @ezyang friendly ping for a response. Is this the API for

global module hook (that fire for every Module that run)

In the link, the comment states

The input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward.

Is it possible to relax the constraint to support kwargs as well?

BowenBao avatar Sep 29 '21 22:09 BowenBao

Is it possible to relax the constraint to support kwargs as well?

We should; IIRC the only reason it wasn't done is that it was BC-breaking.

ezyang avatar Sep 30 '21 13:09 ezyang

@BowenBao sorry I missed this message. Yes, this is the right API. Adding kwargs has been discussed and is indeed a bit tricky due to BC, but we're OK with adding a boolean flag when registering the hook to specify whether the kwargs should be passed: https://github.com/pytorch/pytorch/pull/61606 (we could drop the lazy-module part of that PR and keep just the hook improvement).
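For reference, later PyTorch releases exposed this via a `with_kwargs` flag on per-module hook registration; a sketch assuming PyTorch >= 2.0 (where `register_forward_pre_hook` accepts `with_kwargs`):

```python
import torch

seen = {}

def pre_hook(module, args, kwargs):
    # With with_kwargs=True the hook receives the kwargs dict alongside
    # the positional args, instead of kwargs being visible only to forward.
    seen["args"] = args
    seen["kwargs"] = kwargs

class Layer(torch.nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale

layer = Layer()
layer.register_forward_pre_hook(pre_hook, with_kwargs=True)
result = layer(torch.ones(3), scale=2.0)
```

After the call, `seen["kwargs"]` holds `{"scale": 2.0}`, so a tracer-side hook could record keyword arguments as well as positional ones.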

albanD avatar Sep 30 '21 14:09 albanD