
Support statically shaped dynamic operations (nonzero, unique, etc)

Open · zou3519 opened this issue 4 years ago · 13 comments

🚀 Feature

Cross-post of https://github.com/pytorch/pytorch/issues/62320.

torch.nonzero, torch.unique, etc. should accept an optional size argument so that they produce statically shaped output. We should identify all the PyTorch operations that need this treatment.

Motivation

(1) Make it possible to vmap over them. (2) Make it easier to lower them to backends.

Pitch

Either:

  • add a variant of these operations that accepts a size=None argument
  • add size=None to these operations.
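A minimal sketch of the proposed behavior for a 1-D input, built only from operations that already batch under vmap (sort, index_select, three-argument where). The name static_nonzero and its size/fill_value parameters are hypothetical, chosen for illustration; this is not an existing PyTorch API and not necessarily how PyTorch would implement it. The output always has exactly size entries: the nonzero indices in increasing order, truncated if there are too many and padded with fill_value if there are too few.

```python
import torch


def static_nonzero(x: torch.Tensor, size: int, fill_value: int = -1) -> torch.Tensor:
    """Hypothetical statically shaped nonzero for a non-empty 1-D tensor."""
    n = x.shape[0]
    idx = torch.arange(n, device=x.device)
    mask = x != 0
    # Nonzero positions keep their index; the others are pushed past n,
    # so sorting moves every nonzero index to the front, in order.
    keys = torch.where(mask, idx, idx + n)
    sorted_keys = torch.sort(keys).values
    count = mask.sum()
    # Read the first `size` sorted keys, clamping so we never index past n - 1.
    pos = torch.arange(size, device=x.device)
    gathered = torch.index_select(sorted_keys, 0, pos.clamp(max=n - 1))
    # Slots beyond the number of nonzero entries get the fill value.
    return torch.where(pos < count, gathered, torch.full_like(gathered, fill_value))


X = torch.tensor([[0, 3, 0, 5, 7],
                  [1, 0, 0, 0, 2]])
out = torch.vmap(lambda row: static_nonzero(row, size=3))(X)
# rows: [1, 3, 4] and [0, 4, -1]
```

The padding and truncation are what make the shape static, and they are also why code consuming the result has to know how to treat the fill value.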

Alternatives

N/A

Additional context

N/A

zou3519, Nov 09 '21

This is actionable: we should (1) identify all of the dynamically shaped operations in PyTorch, (2) propose a modification to them and demonstrate that it enables us to vmap over them, (3) get buy-in, and (4) implement and ship.

zou3519, Nov 30 '21

Hello, I've just been testing out a model of mine in functorch and ran into this error, so I'm here to report it!

RuntimeError: vmap: We do not support batching operators that can output dynamic shape. Attempted to vmap over aten::nonzero.Please voice your support in https://github.com/pytorch/functorch/issues/256

My script is laid out roughly as follows:

import torch
from torch import nn
from functorch import make_functional, vmap, grad, jacrev

from Models import vmapModel  # my model (not shown)

# args for data
nsamples = 4096
ndim = 1

# args for model
num_input = 2
num_hidden = 32
num_layers = 1
num_dets = 1
func = nn.Tanh()
pretrain = False

net = vmapModel(num_input=num_input,
                num_hidden=num_hidden,
                num_layers=num_layers,
                num_dets=num_dets,
                func=func,
                pretrain=pretrain)

x = torch.randn(nsamples, num_input, ndim)

func_model, params = make_functional(net)

def compute_loss(params, x):
    sgn, logabs = func_model(params, x)
    return logabs

def compute_hess(params, x):
    # note: without argnums, jacrev differentiates w.r.t. argnums=0 (params)
    return jacrev(jacrev(compute_loss))(params, x).diagonal().sum()

# per-sample: params shared (None), x batched along dim 0
per_sample_grads = vmap(jacrev(compute_hess), (None, 0))(params, x)

What I'm attempting to do is compute per-sample gradients of the Laplacian with respect to the parameters (where the Laplacian is of my model's output w.r.t. the inputs). The model (which I can share, or reduce to a smaller reproducer if needed) is basically a feed-forward network that outputs the sign and logabs via a torch.slogdet call.

I don't understand the term "dynamic shape", but I assume it means a function whose output can change depending on some conditional? Is this correct?
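To illustrate the term: an operator has a dynamic (data-dependent) output shape when two inputs of the same shape can produce outputs of different shapes. torch.nonzero, the operator named in the error message above, is the standard example. Since vmap has to stack per-example results into one regular tensor, such operators can't be batched directly.

```python
import torch

# Two inputs with the same shape, (4,)
a = torch.tensor([0.0, 1.0, 0.0, 2.0])
b = torch.tensor([3.0, 1.0, 4.0, 2.0])

# nonzero returns one row per nonzero element, so the output shape
# depends on the values in the input, not just on its shape.
print(torch.nonzero(a).shape)  # torch.Size([2, 1])
print(torch.nonzero(b).shape)  # torch.Size([4, 1])
```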

AlphaBetaGamma96, Feb 10 '22

Also hitting the error that points to this ticket when applying vmap to a function that uses torch.where
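A likely reason torch.where triggers this error: when called with only a condition, torch.where(cond) is documented as equivalent to torch.nonzero(cond, as_tuple=True), so its output shape is data-dependent. The three-argument form, torch.where(cond, x, y), is elementwise and always returns cond's (broadcast) shape. A quick check:

```python
import torch

cond = torch.tensor([True, False, True, True])

# One-argument form: indices where cond is True (data-dependent shape).
idx_where = torch.where(cond)
idx_nonzero = torch.nonzero(cond, as_tuple=True)
print(idx_where[0])  # tensor([0, 2, 3])

# Three-argument form: elementwise selection, fixed shape.
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
selected = torch.where(cond, x, torch.zeros_like(x))
print(selected)  # tensor([1., 0., 3., 4.])
```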

sergiynesterenko90, Mar 22 '23

+1, trying to vmap() over unique()

enijkamp, Mar 29 '23

We were discussing this issue in issue triage, and one thing that wasn't clear to us was whether a size=... argument would actually help. It would change what you get back: you would have to handle the resulting fill value if there weren't enough entries, or be OK with truncating elements if there were too many. So it's not a drop-in replacement; you have to know a bit about what exactly you are using nonzero/unique/where for.
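A small illustration of why a padded result isn't a drop-in replacement: if a size-limited nonzero padded with -1, the padding would itself be a valid index. The padded indices below are written by hand (no such PyTorch call is made) and the values are made up.

```python
import torch

x = torch.tensor([0.0, 3.0, 0.0, 5.0])
# Hypothetical output of a size=3 nonzero: x has only 2 nonzeros, so one slot is padded with -1.
idx = torch.tensor([1, 3, -1])

# Naive use: -1 wraps around, so the padding silently reads x[-1].
naive = x[idx].sum()  # 3 + 5 + 5 = 13

# Correct use: mask out the padded slots explicitly.
valid = idx >= 0
vals = x[idx.clamp(min=0)]
safe = torch.where(valid, vals, torch.zeros_like(vals)).sum()  # 3 + 5 = 8
```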

Can folks here describe in more detail what they are using these operations for / post their model code? That would help us a lot in prioritizing this!

ezyang, Apr 03 '23

@ezyang for torch.where: my goal was basically to support branching logic inside vmap.

A toy but representative example would be something like:

import torch

# suppose we have a tensor of random numbers between 0 and 1
R = torch.rand(100)

indices_small = torch.where(R < 0.5)[0]
indices_large = torch.where(R >= 0.5)[0]

R_small = R[indices_small]
R_large = R[indices_large]

# the (expensive) per-branch computations
R_small = R_small * 2
R_large = R_large * 0.5

# note: cat groups results by branch instead of keeping R's original order
R_result = torch.cat([R_small, R_large])

In my actual application, the problem is that the operations in each branch are expensive. My current workaround is to apply both operations to all of the numbers and then mask out the correct result at the end.

Another example would be computing Collatz sequences (as in the Collatz conjecture) with the help of vmap.
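A minimal sketch of the masking workaround applied to the Collatz example: run a fixed number of steps, compute both branches for every element, and use three-argument torch.where to pick one. collatz_steps and max_steps=200 are illustrative assumptions; inputs are assumed to be positive integers, and an element that hasn't reached 1 after max_steps simply stops being counted. Every element pays for all max_steps iterations, which is the cost described above.

```python
import torch


def collatz_steps(n: torch.Tensor, max_steps: int = 200) -> torch.Tensor:
    """Steps needed to reach 1, with no data-dependent control flow."""
    steps = torch.zeros_like(n)
    for _ in range(max_steps):
        active = n != 1
        # Compute both branches for every element, then select.
        nxt = torch.where(n % 2 == 0, n // 2, 3 * n + 1)
        n = torch.where(active, nxt, n)
        steps = steps + active.long()
    return steps


start = torch.tensor([1, 6, 7, 27])
result = torch.vmap(collatz_steps)(start)
# expected step counts: 0, 8, 16, 111
```

Because the function is elementwise, calling it directly on the whole tensor gives the same result; vmap is used here only to match the thread's setting.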

Thinking about this more, though, I'm not sure that size= will help. If R_small and R_large are still the same size as R, I don't actually save any computation time, and I don't think I could guarantee the proportion of elements in each branch either. Maybe this really needs either full dynamic tensor size support or an equivalent of jax.lax.cond. I can guarantee that my output (R_result) always has a static size (the same as the input). Maybe vmap could be relaxed to require only a static input/output size, but not static internals, with some AOT compilation? This doesn't sound easy...

For unique, I think size would actually help. I have a use case where I need to count how many unique entries I have in a tensor. If unique had a static size parameter, and I could choose to fill the remaining entries with some known quantity, I could do something like this:

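import torch

# Proposed API only: torch.unique has no size= or fill= arguments today.
def test(x: torch.Tensor):
    y = torch.unique(x, size=x.size(), fill=99)  # hypothetical: pad with 99 up to a static size
    z = torch.where(y < 90, torch.tensor(1), torch.tensor(0))  # 1 for real entries, 0 for filler
    return torch.sum(z)

torch.vmap(test)(torch.tensor([[1, 2, 2, 3, 3], [1, 2, 3, 4, 5]]))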
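If the goal is only the count, a fixed-shape computation already works under vmap without any new argument: sort each row and count the positions where neighboring values differ. count_unique is a hypothetical helper name, and it assumes non-empty 1-D rows.

```python
import torch


def count_unique(x: torch.Tensor) -> torch.Tensor:
    """Number of distinct values in a non-empty 1-D tensor, using fixed shapes only."""
    s = torch.sort(x).values
    # After sorting, each new distinct value starts where neighbors differ.
    return 1 + (s[1:] != s[:-1]).sum()


X = torch.tensor([[1, 2, 2, 3, 3], [1, 2, 3, 4, 5]])
print(torch.vmap(count_unique)(X))  # tensor([3, 5])
```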

sergiynesterenko90, Apr 04 '23

Yeah, so for your first case, size doesn't help. If the computations on the branches were cheap, you could just compute R * 2 and R * 0.5 and then use torch.where to select between the two results. Even if they're not that cheap, with torch.compile we could fuse these into a single kernel to save memory bandwidth.
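A minimal sketch of the compute-both-branches approach for the toy example, applied under vmap. Because three-argument torch.where is elementwise, each output has the same shape as its input, which is what vmap needs; the original element order is also preserved. branch_select is a hypothetical name.

```python
import torch


def branch_select(r: torch.Tensor) -> torch.Tensor:
    # Evaluate both branches on every element, keep the one the condition picks.
    return torch.where(r < 0.5, r * 2, r * 0.5)


R = torch.rand(8, 100)  # a batch of 8 rows
out = torch.vmap(branch_select)(R)
```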

But if the compute on R is expensive enough that it is worth avoiding running it on the entire batch, then I think you have no choice but to do a DtoH (device-to-host) sync. Maybe that's fine: we aren't here to avoid DtoH syncs, we just want to make vmap work. I think we can make it work by augmenting torch.cond to do what you want! If torch.cond has a data-dependent condition, it should desugar into a DtoH-syncing operation that partitions the input into the left and right sides, then runs batched operations on each branch, and then cats the results back together (we probably should preserve ordering, which means you need a more expensive interleaving operation). But in any case, size doesn't help.
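To make the proposed desugaring concrete, here is an eager sketch (outside vmap) of partitioning by a data-dependent predicate, running each branch only on its own elements, and putting the results back. partitioned_cond is a hypothetical helper, not torch.cond. Instead of cat followed by an interleave, it scatters each branch's results back through the boolean mask, which keeps the original order. It assumes both branch functions are elementwise and preserve dtype.

```python
import torch
from typing import Callable

Fn = Callable[[torch.Tensor], torch.Tensor]


def partitioned_cond(x: torch.Tensor, pred: Fn, true_fn: Fn, false_fn: Fn) -> torch.Tensor:
    """Run each branch only on the elements that take it, then restore original order."""
    mask = pred(x)
    out = torch.empty_like(x)
    # Boolean indexing has a data-dependent shape (on GPU this implies a DtoH sync).
    out[mask] = true_fn(x[mask])
    out[~mask] = false_fn(x[~mask])
    return out


R = torch.rand(100)
result = partitioned_cond(R, lambda t: t < 0.5, lambda t: t * 2, lambda t: t * 0.5)
```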

I agree that a size argument for unique seems perfect for the example you gave.

cc @zou3519 @Chillee

ezyang, Apr 05 '23

At least three folks on this thread have asked for unique, so that's something we can tackle first since we agree that a size= argument makes sense for it.

I'm not sure that vmap can always support the first case, even with the torch.cond extension. Here's an example:

def f(R):
    # pseudocode for the proposal: cond with a data-dependent (elementwise) predicate
    return cond(R < 0.5, lambda x: x * 2, lambda x: x * 0.5, (R,))

vmap(f)(R) would work under the cond proposal: we can partition, send some elements to the left branch, and send the rest to the right branch.

However, vmap(vmap(f))(R) might not be happy: for each element of the outer batch, a variable number of elements would need to flow to each branch, and we don't have the ability to construct a Tensor with variable sizes (unless we start thinking about NestedTensor).
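A small illustration of why nesting is harder, using made-up values: under vmap(vmap(f)), each row of R would route a different number of elements to the left branch, so the per-row partitions are ragged and cannot be stacked into one regular Tensor.

```python
import torch

R = torch.tensor([[0.1, 0.7, 0.2, 0.9],
                  [0.6, 0.8, 0.3, 0.95]])

# Elements each row would send to the left branch (R < 0.5).
left = [row[row < 0.5] for row in R]
sizes = [t.numel() for t in left]  # [2, 1]

try:
    torch.stack(left)
    stackable = True
except RuntimeError:
    stackable = False  # rows of different lengths can't form one regular tensor
```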

zou3519, Apr 05 '23

Well, vmap should be expressible as a source-to-source transform, right? So we just need to make sure that it transforms into something we can keep going with. If that's NestedTensor, then fine. But maybe there's something easier; it doesn't feel like this example should be worse with two batch dims than with one.

ezyang, Apr 05 '23

Hi all, I just had another look at my example above and I've resolved the issue. The problem was that I didn't specify argnums in the jacrev calls; jacrev defaults to argnums=0 (the parameters), so the inner derivatives weren't taken with respect to my input as intended. If I specify argnums, it works fine. For reference, I'm using torch version 2.0.0.
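For anyone landing here with the same goal, a minimal sketch of the computation described in the original report, written with the torch.func API (which supersedes functorch in PyTorch 2.x). The small nn.Sequential model is a stand-in assumption, since vmapModel isn't shown. The key point is argnums=1 for the two inner jacrev calls (derivatives w.r.t. the input) and argnums=0 for the outer one (gradient w.r.t. the parameters).

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)
num_input, num_hidden, nsamples = 2, 32, 16
# Stand-in for the reporter's model: a small MLP with a scalar output.
net = nn.Sequential(nn.Linear(num_input, num_hidden), nn.Tanh(), nn.Linear(num_hidden, 1))
params = {name: p.detach() for name, p in net.named_parameters()}


def scalar_out(params, x):
    # x is a single sample of shape (num_input,)
    return functional_call(net, params, (x,)).squeeze(-1)


def laplacian(params, x):
    # argnums=1: differentiate twice with respect to the input x.
    hess = jacrev(jacrev(scalar_out, argnums=1), argnums=1)(params, x)
    return hess.diagonal().sum()


x = torch.randn(nsamples, num_input)
# Per-sample gradient of the Laplacian with respect to the parameters (argnums=0).
per_sample_grads = vmap(jacrev(laplacian, argnums=0), in_dims=(None, 0))(params, x)
print({k: tuple(v.shape) for k, v in per_sample_grads.items()})
```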

AlphaBetaGamma96, Apr 17 '23

Just a thought: there might be a temporary workaround for unique, which is to implement your own "unique" function. The downside is that it involves a sort. It assumes a 2D tensor X whose first column is all zeros (X[:, 0] = 0); the F.pad call below prepends such a column.

import torch

def handmade_unique(x):
    sorted_x, _ = torch.sort(x)
    # keep a value only where it differs from its predecessor; repeats become 0
    mask = (sorted_x[1:] - sorted_x[:-1]) > 0
    return sorted_x[1:] * mask

This works well with vmap for me.

import torch.nn.functional as F

padded_X = F.pad(X, (1, 0))  # prepend a column of zeros to every row
new_X = torch.vmap(handmade_unique)(padded_X)
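A worked example of the workaround above, restated so it runs on its own with a concrete X. It also shows the assumption behind using 0 as the fill value: counting the nonzero outputs only gives the number of distinct values when the data is strictly positive, as the last (made-up) row shows.

```python
import torch
import torch.nn.functional as F


def handmade_unique(x):
    sorted_x, _ = torch.sort(x)
    # True at the first occurrence of each distinct value.
    mask = (sorted_x[1:] - sorted_x[:-1]) > 0
    # Repeats are replaced by 0, so every row keeps the same (static) length.
    return sorted_x[1:] * mask


X = torch.tensor([[3, 1, 3, 2],
                  [5, 5, 5, 1],
                  [0, 4, 4, 0]])  # last row contains 0, breaking the assumption
padded_X = F.pad(X, (1, 0))  # prepend a 0 to every row
new_X = torch.vmap(handmade_unique)(padded_X)
# rows: [1, 2, 3, 0], [1, 5, 0, 0], [0, 0, 4, 0]
counts = (new_X != 0).sum(dim=1)
# counts: [3, 2, 1]; the last row really has 2 distinct values (0 and 4)
```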

kanbei7, Jul 31 '24

Would love to see this supported for torchvision's NMS: https://pytorch.org/vision/main/generated/torchvision.ops.nms.html#torchvision.ops.nms

aboubezari, Aug 15 '24

@aboubezari put this request in the vision repo!

ezyang, Aug 17 '24