RFC-0015-Tracer-module-scope-extension
Let's chat about this in VC some time; module scopes seem very similar to capabilities that FX provides, and I'm curious if we want to try hitting this with a more radical approach (e.g., rewriting ONNX to be an FX pass rather than a TorchScript trace based pass)
@ezyang I've been told you have been invited to our regular meeting on 09-02, and we can discuss it then. However, this scope feature is needed in the short term, and we don't want to block it on rewriting the entire ONNX export. Could we discuss the FX topic independently while moving forward with this for now?
I finally had a chance to think about this proposal. What I am primarily skeptical about is the part of the proposal that adds new attributes to Graph to record the scope information. If you're going to maintain a mapping from string scope name to extra args/kwargs, there's no reason this couldn't just be done from Python: while tracing, you maintain this mapping yourself, pushing a unique scope string and stashing in the dictionary the extra args/kwargs you want to record. An added benefit of this approach is that you don't have to deal with serializing Python arguments to IValues; you can record them verbatim (and store arbitrarily complicated stuff, like the nn.Module itself). WDYT? We can talk about this more on Tuesday too.
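A minimal sketch of this Python-side bookkeeping (plain Python, no torch dependency; `ScopeRecorder` and its `push`/`pop` names are hypothetical illustrations, not an existing API):

```python
class ScopeRecorder:
    """Trace-time stash: unique scope string -> the call's args/kwargs."""

    def __init__(self):
        self._stack = []    # active scope names, innermost last
        self._records = {}  # unique scope name -> (args, kwargs)
        self._counter = 0

    def push(self, name, args, kwargs):
        # Suffix a counter so repeated calls to the same module stay unique.
        unique = f"{name}::{self._counter}"
        self._counter += 1
        self._stack.append(unique)
        # Store the Python objects verbatim -- no IValue serialization needed.
        self._records[unique] = (args, kwargs)
        return unique

    def pop(self):
        return self._stack.pop()


recorder = ScopeRecorder()
scope = recorder.push("Model.layer", args=("x_tensor", "y_tensor"), kwargs={})
# ... trace the module body here ...
recorder.pop()
print(recorder._records[scope])  # (('x_tensor', 'y_tensor'), {})
```

Because the records live entirely in Python, arbitrarily complex objects (including the `nn.Module` itself) can be stashed without touching the IR.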
@ezyang the Python-side idea is appealing for its simplicity. We are considering something like this for the attribute side, especially if they are scalars or Python numbers. It is not immediately clear to me how to achieve the goal of identifying and preserving the mapping between Python args and IR Values with this approach.
Consider a simplified example for this purpose.
```python
class Layer(torch.nn.Module):
    def forward(self, x, y):
        return ...


class Model(torch.nn.Module):
    def __init__(self):
        ...
        self.layer = Layer()

    def forward(self, input):
        x = ...  # produced from some module
        y = ...  # produced from some module
        return self.layer(x, y)
```

```
# IR snippet:
%x = ...
%y = ...
%out.1 = ONNX::layer_op(...)   # part of Layer module
%out.2 = ONNX::layer_op2(...)  # part of Layer module
%out.3 = ONNX::layer_op3(...)  # part of Layer module
return (%out.3)
```
The current IR does not explicitly show that %x and %y are the Layer module inputs, or that %out.3 is the Layer module output. Graph analysis can tell whether a Value is produced by or used in an outer scope, but the original input/output ordering cannot be preserved. For our proposed approach we do not intend to preserve the IValue itself; the IValue is passed through push_scope/pop_scope only to look up its IR Value from the tracer stash.
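To make that lookup concrete, here is a hedged, torch-free sketch: the tracer keeps a stash mapping live Python objects to the IR Values it created for them, so push_scope uses the objects only to resolve %x and %y and never stores them. `TracerStash` and `push_scope` are illustrative names, not actual PyTorch APIs.

```python
class TracerStash:
    def __init__(self):
        self._value_map = {}  # id(python object) -> IR Value name

    def record(self, obj, ir_name):
        self._value_map[id(obj)] = ir_name

    def lookup(self, obj):
        return self._value_map[id(obj)]


stash = TracerStash()
x, y = object(), object()  # stand-ins for traced tensors
stash.record(x, "%x")
stash.record(y, "%y")


def push_scope(name, inputs, stash):
    # Resolve the Python call arguments to their IR Values, in call order.
    return {"scope": name, "inputs": [stash.lookup(i) for i in inputs]}


print(push_scope("Model.layer", [x, y], stash))
# {'scope': 'Model.layer', 'inputs': ['%x', '%y']}
```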
Hmm. I'm not sure I entirely understand, but would it be helpful if you could retrieve the IR values while you were tracing (or maybe, some sort of token that you could use later to correspond to the graph)? Then you'd be able to trace Python arguments.
We discussed this in our meeting and the key constraint from ONNX's side is that they want to be able to access the module scope information from C++. I'm not too sure what to do about this, TBH. It will be helpful to know what kinds of optimizations would actually make use of this information.
@ezyang the main goal is to be able to preserve the input/output order when grouping the nodes belonging to a certain nn.Module into a custom ONNX function.
The current onnx pass is able to group nodes by looking at scope information. Taking the above example:
```
%out.1 = ONNX::layer_op(...)   # scope: Model.layer
%out.2 = ONNX::layer_op2(...)  # scope: Model.layer
%out.3 = ONNX::layer_op3(...)  # scope: Model.layer
```

can be converted to

```
%out = custom_domain::Layer(%x, %y)  # custom op representing class Layer
```
However, the original input order of the nn.Module call is not tracked and cannot be recovered.
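A small, self-contained sketch of why this matters (illustrative dicts, not real TorchScript IR): boundary analysis over the scoped nodes can find which Values cross into the Layer scope, but only as an unordered set; a call-argument list recorded at trace time restores the original (x, y) ordering.

```python
nodes = [
    {"out": "%out.1", "ins": ["%y", "%x"], "scope": "Model.layer"},
    {"out": "%out.2", "ins": ["%out.1"],   "scope": "Model.layer"},
    {"out": "%out.3", "ins": ["%out.2"],   "scope": "Model.layer"},
]

produced = {n["out"] for n in nodes}
# Values used inside the scope but produced outside it: an unordered set.
boundary = {i for n in nodes for i in n["ins"]} - produced
print(sorted(boundary))  # ['%x', '%y'] after sorting; the set itself has no order

# With the forward() arguments recorded at trace time, the call order is known:
recorded_call_args = ["%x", "%y"]
custom_node = f"%out = custom_domain::Layer({', '.join(recorded_call_args)})"
print(custom_node)  # %out = custom_domain::Layer(%x, %y)
```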
Hmm... what if we add some __torch_function__ style functionality so that you can interpose on the nn.Module call at the time you are tracing, and immediately do the custom_domain replacement, so that is the only thing that ever shows up in the trace?
> what if we add some __torch_function__ style functionality so that you can interpose on the nn.Module call at the time you are tracing
Module hooks? Feels like a global module hook (one that fires for every Module that runs) is what you're looking for in this case.
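A torch-free mock of the global-hook idea, for illustration: one pre-hook, registered once, fires for every module call and records the positional inputs. The analogous real entry point in PyTorch is `torch.nn.modules.module.register_module_forward_pre_hook`; everything below is a simplified stand-in.

```python
_global_pre_hooks = []


def register_module_forward_pre_hook(hook):
    # Registered once, fires for every Module call (mock of the global hook).
    _global_pre_hooks.append(hook)


class Module:
    def __call__(self, *args):
        for hook in _global_pre_hooks:
            hook(self, args)  # hook sees the module and its positional inputs
        return self.forward(*args)


class Layer(Module):
    def forward(self, x, y):
        return x + y


calls = []
register_module_forward_pre_hook(
    lambda mod, inputs: calls.append((type(mod).__name__, inputs))
)

Layer()(1, 2)
print(calls)  # [('Layer', (1, 2))]
```

Note the hook only ever sees positional arguments here, which is exactly the kwargs limitation discussed below.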
> Hmm... what if we add some __torch_function__ style functionality so that you can interpose on the nn.Module call at the time you are tracing, and immediately do the custom_domain replacement, so that is the only thing that ever shows up in the trace?
We'd also like to still trace the inner subgraph of the modules, since that will appear as the onnx local function subgraph definition.
I guess what is giving me a lot of trouble here (and why the responses take a while) is that I don't really want to add a new fundamental, cross-cutting change to TorchScript IR with only one motivating use case for it. I'd rather take some much more circuitous route; for example, the original proposal of doing things in Python: what if you did the same thing, but transmitted the arguments all the way to C++ so they were accessible from your TorchScript passes? It's still hacky, but it feels a lot better.
@ezyang That's a valid point; I guess we can explore not changing the TorchScript IR and try to manage things in Python.
Instead of the push_scope/pop_scope proposal, we could possibly depend on the _get_value_trace API from TracingState to retrieve the IValue <-> Value mapping.
For the global module hook idea above, do you have some pointers or example snippets that we could give a try?
^- @albanD
@albanD @ezyang friendly ping for a response. Is this the API for

> global module hook (that fire for every Module that run)
In the link, the comment states:

> The input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward.
Is it possible to relax the constraint to support kwargs as well?
> Is it possible to relax the constraint to support kwargs as well?
We should; IIRC the only reason it wasn't done was that it was BC-breaking.
@BowenBao sorry I missed this message. Yes, this is the right API. Adding kwargs has been discussed and is indeed a bit tricky due to BC, but we're OK with adding a boolean flag when registering the hook to specify whether the kwargs should be passed or not: https://github.com/pytorch/pytorch/pull/61606 (we could remove the lazy module part of that PR and add just the hook improvement).
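A hedged, torch-free sketch of that boolean-flag idea: hooks registered with `with_kwargs=True` opt into the new signature, while previously registered hooks keep the old positional-only one, so nothing BC-breaking happens. The registry below is a mock for illustration, not the actual PyTorch implementation from the linked PR.

```python
_hooks = []  # list of (hook, with_kwargs) pairs


def register_module_forward_pre_hook(hook, with_kwargs=False):
    _hooks.append((hook, with_kwargs))


def fire(module, args, kwargs):
    for hook, wants_kwargs in _hooks:
        if wants_kwargs:
            hook(module, args, kwargs)  # new opt-in signature
        else:
            hook(module, args)          # legacy signature, positional args only


seen = []
register_module_forward_pre_hook(lambda m, a: seen.append(("old", a)))
register_module_forward_pre_hook(
    lambda m, a, kw: seen.append(("new", a, kw)), with_kwargs=True
)

fire("layer", (1,), {"y": 2})
print(seen)  # [('old', (1,)), ('new', (1,), {'y': 2})]
```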