
Equivalent of `register_forward_hook`

cpcdoy opened this issue on Jun 20 '20 · 5 comments

Hi,

I've been using hooks and especially register_forward_hook in PyTorch and wanted to be able to do the same in Rust.

Unless I missed something (I checked both torch-sys and tch-rs for a similar function), there doesn't seem to be one. Is there a way to emulate a hook using your API?

Thanks a lot

cpcdoy · Jun 20 '20 17:06

I don't think there is any equivalent for this at the moment: the current API is mostly generated automatically from the Declarations.yaml file and I haven't found anything related in it. Do you know if this could be done with the C++ API? Also, giving an example of your typical use case could help us see if there is a way to achieve the same kind of thing in tch-rs.

LaurentMazare · Jun 20 '20 19:06

I haven't used the C++ API before, so I'm not quite sure how to replicate it there.

Here's how I've been using it in Python:

Let's say I have a DistilBert architecture:

...
      (0): TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
...

Note: The architecture is cut to keep it brief

Now, let's say I want to run my model using the library in a standard way, like this:

model.encode('whatever')

But at the same time, while that line runs, I want to save the outputs of the k_lin and q_lin layers that you can see in the architecture summary above.

The final goal is to store them in a list so that, for each layer, I can study the k_lin and q_lin outputs independently, without having to download the library's source code, modify it to expose those outputs, and then maintain that fork on the side. That doesn't seem like a good idea.

In PyTorch, you'd do it using hooks this way:

  1. Create a dictionary to keep the layer activations:

name_list = ['k_lin', 'q_lin']
NUM_LAYERS = 6
activations = {i: {name_: [] for name_ in name_list} for i in range(NUM_LAYERS)}

  2. Create a hook and register it on each layer's q_lin and k_lin:

def create_hook(layer_, name_):
    def hook(model, input_, output_):
        activations[layer_][name_].append(output_.detach())
    return hook

# Access the q_lin and k_lin layers and register a forward hook on each
for i in range(NUM_LAYERS):
    model.transformer.layer[i].attention.q_lin.register_forward_hook(create_hook(i, 'q_lin'))
    model.transformer.layer[i].attention.k_lin.register_forward_hook(create_hook(i, 'k_lin'))

  3. Use the library normally, and activations will then contain the output tensors of the forward passes of q_lin and k_lin:

model.encode('whatever')
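
In tch-rs terms, the closest workaround I can think of without hooks is to write the forward pass by hand and return the intermediate tensors explicitly. Below is a rough sketch of what I mean (purely illustrative: the Attention struct, its forward_with_activations method, the layer sizes and the variable paths are all made up for the example, and a real DistilBert model would of course not be rebuilt like this). It is exactly the kind of model surgery I'd like to avoid:

use tch::nn::Module;
use tch::{nn, Device, Kind, Tensor};

// Hypothetical stand-in for one attention block: only the two projections
// we want to inspect are modelled here.
struct Attention {
    q_lin: nn::Linear,
    k_lin: nn::Linear,
}

impl Attention {
    fn new(p: &nn::Path) -> Self {
        Attention {
            q_lin: nn::linear(p / "q_lin", 768, 768, Default::default()),
            k_lin: nn::linear(p / "k_lin", 768, 768, Default::default()),
        }
    }

    // Instead of hooking into the forward pass, return the intermediate
    // activations explicitly (detached, like output_.detach() above).
    fn forward_with_activations(&self, xs: &Tensor) -> (Tensor, Tensor) {
        let q = self.q_lin.forward(xs);
        let k = self.k_lin.forward(xs);
        // ... the rest of the attention computation would go here ...
        (q.detach(), k.detach())
    }
}

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    let root = vs.root();
    let layers: Vec<Attention> = (0..6).map(|i| Attention::new(&(&root / i))).collect();

    // Fake input of shape (batch, hidden), just for the example.
    let xs = Tensor::randn(&[1, 768], (Kind::Float, Device::Cpu));

    // Equivalent of the `activations` dictionary above: (q_lin, k_lin) per layer.
    let activations: Vec<(Tensor, Tensor)> = layers
        .iter()
        .map(|layer| layer.forward_with_activations(&xs))
        .collect();
    println!("captured activations for {} layers", activations.len());
}

This only works if you control the whole forward pass yourself, which is precisely what register_forward_hook lets you avoid.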

Hope this helps

cpcdoy · Jun 21 '20 14:06

Just pinging the thread to see if there's any update on this

cpcdoy · Jun 30 '20 10:06

@LaurentMazare It seems that there is a register_hook function in the C++ API: https://github.com/pytorch/pytorch/blob/115494b00bf31549aa5227068bd66a3da9de469b/test/cpp/api/autograd.cpp#L547. Not sure whether it can be ported to tch-rs though.

NOBLES5E · Jul 24 '21 02:07

It boils down to this: https://github.com/pytorch/pytorch/blob/f52e2028409cd4bd23c31f64d0617703af31b1dc/aten/src/ATen/core/VariableHooksInterface.h#L63

NOBLES5E · Jul 24 '21 02:07