
UserWarning: Attempted to insert a get_attr Node .. when `to_backend` is called

mhs4670go opened this issue 1 year ago · 3 comments

Hello.

I'm trying to use executorch to convert the torch module below into my backend binary. When I call to_backend, the warning message below is printed.

/home/seongwoo/.venv/lib/python3.10/site-packages/torch/export/_unlift.py:58: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  getattr_node = gm.graph.get_attr(lifted_node)
/home/seongwoo/.venv/lib/python3.10/site-packages/torch/fx/graph.py:1460: UserWarning: Node y target y y of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
class SimpleDiv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # This line triggers the UserWarning
        self.y = torch.randn(3, 3)

    def forward(self, x):
        z = x / self.y
        return (z,)

    def get_example_inputs(self):
        # Set seed to control a random divisor to be non-zero.
        torch.manual_seed(1)
        return (torch.randn(3, 3),)

# ..

# model: ExportedProgram
module_edge = to_edge(model)
# Calling this line prints the UserWarning above.
module_edge = module_edge.to_backend(MyPartitioner())
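
For reference, here is a minimal sketch of how I obtain the ExportedProgram above (assuming torch.export; get_example_inputs is the helper defined on SimpleDiv, and MyPartitioner is my backend's partitioner, omitted here):

import torch

module = SimpleDiv()
module.eval()

# torch.export lifts self.y out of the module; the result is the `model`
# that is passed to to_edge and to_backend above.
model = torch.export.export(module, module.get_example_inputs())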

Of course I can ignore the warning, but I want to check whether it is intended. It seems to me that, if the module is something that deserves a warning, it should have been printed when I called export in the first place, since the warning code lives in torch/export/_unlift.py:58. Is this the intended behavior of to_backend?

mhs4670go · Apr 24 '24

I think it's possible that the partitioner is a bit off, and the graph after to_backend is a bit off. For this example, the graph is

class GraphModule(torch.nn.Module):
    def forward(self, c_y: "f32[3, 3]", x: "f32[3, 3]"):
        # File: /tmp/ipykernel_894438/2453067640.py:8 in forward, code: z = x / self.y
        aten_div_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_div_Tensor(x, c_y);  x = c_y = None
        return (aten_div_tensor,)

Do you expect self.y to be owned by my_backend?

cccclai · Apr 25 '24

@cccclai

I think it's possible that the partitioner is a bit off, and the graph after to_backend is a bit off

What do you mean by the partitioner or the graph being off? Does that mean they are invalid?

Do you expect self.y to be owned by my_backend?

Yes. Then, should I set y as a buffer or parameter?

mhs4670go · Apr 26 '24

What do you mean by the partitioner or the graph being off? Does that mean they are invalid?

Yeah, I'd guess so.

Yes. Then, should I set y as a buffer or parameter?

Yeah, if it's a constant, probably call register_buffer and then tag this node in the partitioner, for example:
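
Something like this rough sketch, assuming the Partitioner / PartitionResult / DelegationSpec API from executorch.exir.backend.partitioner ("MyBackend" is a placeholder backend id, and a real partitioner would only tag ops your backend actually supports):

import torch

from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)


class SimpleDiv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # register_buffer gives y a real underlying reference in the module,
        # so the lifted node no longer triggers the get_attr UserWarning.
        self.register_buffer("y", torch.randn(3, 3))

    def forward(self, x):
        return (x / self.y,)


class MyPartitioner(Partitioner):
    def __init__(self):
        super().__init__()
        # "MyBackend" is a placeholder for whatever backend id you registered.
        self.delegation_spec = DelegationSpec("MyBackend", [])

    def partition(self, exported_program) -> PartitionResult:
        partition_tags = {}
        tag = "tag0"
        sig = exported_program.graph_signature
        for node in exported_program.graph_module.graph.nodes:
            # Tag every compute node; a real partitioner would check that the
            # op is supported by the backend (here it is just the single div).
            if node.op == "call_function":
                node.meta["delegation_tag"] = tag
                partition_tags[tag] = self.delegation_spec
            # Also tag the lifted buffer (self.y) so the backend owns it.
            elif node.op == "placeholder" and node.name in sig.inputs_to_buffers:
                node.meta["delegation_tag"] = tag
        return PartitionResult(
            tagged_exported_program=exported_program,
            partition_tags=partition_tags,
        )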

cccclai · Apr 27 '24

I get it :) Thank you for your comments. It would be better to use a buffer or a parameter for this.

mhs4670go · Apr 29 '24