functorch

Unable to use vmap atop torch.distribution functionality

Open · agiachris opened this issue 2 years ago · 4 comments

Hello! I'm working on an application that requires computing a neural net's weight Jacobians through a torch.distributions log probability. Minimal example code is shown below:

import torch
from torch.distributions import Independent, Normal
from functorch import make_functional_with_buffers, jacrev, vmap

def compute_fischer_stateless_model(fmodel, params, buffers, input, target):
    input = input.unsqueeze(0)
    target = target.unsqueeze(0)
    pred = fmodel(params, buffers, input)
    normal = Independent(Normal(loc=pred, scale=torch.ones_like(pred)), reinterpreted_batch_ndims=1)
    log_prob = normal.log_prob(target)
    return log_prob

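# Instantiate model, inputs, targets, etc. (hypothetical dummy values so the repro runs standalone)
model = torch.nn.Linear(4, 2)
inputs = torch.randn(8, 4)   # batch of 8 samples
targets = torch.randn(8, 2)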
fmodel, params, buffers = make_functional_with_buffers(model)
ft_compute_jac = jacrev(compute_fischer_stateless_model, argnums=1)
ft_compute_sample_jac = vmap(ft_compute_jac, in_dims=(None, None, None, 0, 0))
jac = ft_compute_sample_jac(fmodel, params, buffers, inputs, targets)

Executing my script raises a RuntimeError of the form:

RuntimeError: vmap: It looks like you're either (1) calling .item() on a Tensor or (2) attempting to use a Tensor in some data-dependent control flow or (3) encountering this error in PyTorch internals. For (1): we don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. For (2): If you're doing some control flow instead, we don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 . For (3): please file an issue.
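Cause (2) in that message can be reproduced without torch.distributions. The sketch below is illustrative only: it assumes PyTorch 2.x, where functorch's vmap is available as torch.func.vmap, and the function names are hypothetical. An `if` on a tensor implicitly calls bool() on it, which vmap cannot do because each batch element might take a different branch; the masked version expresses the same check with tensor operations instead.

```python
import torch
from torch.func import vmap

def checked_log(x):
    # `if` calls bool() on a tensor: data-dependent control flow, rejected by vmap.
    if not (x > 0).all():
        raise ValueError("expected positive input")
    return x.log()

def masked_log(x):
    # Same intent with tensor ops only: invalid entries become NaN instead of raising.
    return torch.where(x > 0, x, torch.full_like(x, float("nan"))).log()

xs = torch.rand(4, 3) + 0.1  # strictly positive batch of 4 samples

try:
    vmap(checked_log)(xs)
except RuntimeError as err:
    print("vmap rejected checked_log:", str(err).splitlines()[0][:60])

print(vmap(masked_log)(xs).shape)  # torch.Size([4, 3])
```

Outside vmap, checked_log works normally; the problem only appears once the `if` sees a batched tensor.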

Any help would be appreciated. Thanks in advance for your time!

agiachris, Jun 30 '22 22:06

Thanks for the issue, @agiachris! Is there a specific model or inputs/targets you're running the above with? If not, we can make some dummy models/inputs and try to reproduce the error.

zou3519, Jul 05 '22 15:07

Thanks for the quick response! The case I'm considering is in fact a bit more involved than the above. The model is a deep Q-network; code here: https://github.com/jhejna/research-lightning/blob/rl/research/networks/mlp.py#L29. The actual function is written as follows. It is wrapped with jacrev (has_aux=True) and then with vmap, whose in_dims are created dynamically to match the number of positional *inputs passed to the model (in this case two: the batched states and actions used to compute the Q-value). Thanks again!

from __future__ import annotations  # defer evaluation of the annotations below

from typing import Tuple

import torch
from torch import nn

# Method excerpt from the project's model class (class definition not shown).
# FunctionalModuleWithBuffers is the functional module type returned by
# functorch.make_functional_with_buffers.
def _compute_fischer_stateless_model(self,
                                     fmodel: FunctionalModuleWithBuffers,
                                     params: Tuple[nn.Parameter, ...],
                                     buffers: Tuple[torch.Tensor, ...],
                                     target: torch.Tensor,
                                     *input: torch.Tensor,
                                     ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the model's weight Fisher for a single sample. There are two cases:
    1) Empirical weight Fisher: contribution C = (J_l.T @ J_l), J_l = d(-log p(y|x))/dw
    2) True weight Fisher: contribution C = (J_f.T @ L_theta @ L_theta.T @ J_f), J_f = df(x)/dw

    args:
        fmodel: functional form of the model, cast from an nn.Module
        params: parameters of the functional model
        buffers: buffers of the functional model
        target: ground-truth target tensor
        *input: model input tensors

    returns:
        pre_jacobians: quantity whose weight Jacobian is computed, of size (d)
        outputs: model predictions parameterizing the output distribution, of size (d)
    """
    # Add a batch dimension of 1 to every input (vmap strips the per-sample batch dim).
    input = [x.unsqueeze(0) for x in input]
    outputs = fmodel(params, buffers, *input)
    # True Fisher: output-space square-root factor; empirical Fisher: negative log-likelihood.
    pre_jacobians = self._output_dist.apply_sqrt_F(outputs) if not self._use_empirical_fischer \
        else -self._output_dist.log_prob(target.unsqueeze(0))

    return pre_jacobians.squeeze(0), outputs.squeeze(0)
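The setup described above (several positional inputs, jacrev with has_aux=True, and vmap in_dims built from the number of inputs) can be sketched roughly as follows. Everything here is hypothetical: a tiny two-input Q-network stands in for the linked MLP, torch.func.functional_call (PyTorch 2.x) replaces functorch's make_functional_with_buffers, only the empirical-Fisher branch is shown (the negative log-likelihood path taken when _use_empirical_fischer is set, here with a unit-variance Gaussian), and passing validate_args=False is assumed to be an acceptable way to avoid the distribution checks that break under vmap. The sketch sums the per-sample contributions J_l.T @ J_l over the flattened parameters.

```python
import torch
from torch import nn
from torch.distributions import Independent, Normal
from torch.func import functional_call, jacrev, vmap

class TwoInputQ(nn.Module):
    """Hypothetical stand-in for the linked MLP: Q(state, action)."""
    def __init__(self, state_dim=3, action_dim=2, hidden=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1)
        )

    def forward(self, state, action):
        return self.net(torch.cat([state, action], dim=-1))

model = TwoInputQ()
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

def nll_single(params, buffers, target, *inputs):
    # Add a batch dim of 1 to every input, as in the method above.
    inputs = [x.unsqueeze(0) for x in inputs]
    outputs = functional_call(model, (params, buffers), tuple(inputs))
    # validate_args=False skips the constraint checks (assumed workaround).
    dist = Independent(
        Normal(outputs, torch.ones_like(outputs), validate_args=False),
        reinterpreted_batch_ndims=1,
        validate_args=False,
    )
    nll = -dist.log_prob(target.unsqueeze(0))
    # has_aux=True: jacrev differentiates the first output, passes the second through.
    return nll.squeeze(0), outputs.squeeze(0)

def empirical_fisher(states, actions, targets):
    inputs = (states, actions)
    # params and buffers are shared; the target and every input are batched on dim 0.
    in_dims = (None, None, 0) + (0,) * len(inputs)
    jac_fn = vmap(jacrev(nll_single, has_aux=True), in_dims=in_dims)
    jacs, outputs = jac_fn(params, buffers, targets, *inputs)
    # Flatten each per-sample gradient into one row, so J has shape (N, P).
    J = torch.cat([j.reshape(j.shape[0], -1) for j in jacs.values()], dim=1)
    # Sum over samples of J_l.T @ J_l, giving a (P, P) matrix.
    return J.T @ J, outputs
```

For the other branch (apply_sqrt_F, with J_f = df(x)/dw) the same vmap/jacrev structure would apply, but it depends on the project's own output distribution, so it is not sketched here.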

agiachris, Jul 05 '22 15:07

We haven't yet thought about how functorch composes with the torch.distributions package, but we do need to figure that out. Concretely, the problem here is a .all() call in torch.distributions: https://github.com/pytorch/pytorch/blob/9d20af50608b146fe1c3296210a05cd8e4c60af2/torch/distributions/distribution.py#L55
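One possible workaround, not confirmed in this thread, assumes the problematic `.all()` is part of the argument/sample validation that torch.distributions only runs when validation is enabled. Passing validate_args=False then keeps the log-probability computation free of Python control flow. The sketch below applies this to the original repro; it assumes PyTorch 2.x, swaps torch.func.functional_call in for functorch's make_functional_with_buffers, and uses a hypothetical model and data shapes. It returns per-sample Jacobians of log p(y|x) with respect to every parameter.

```python
import torch
from torch.distributions import Independent, Normal
from torch.func import functional_call, jacrev, vmap

# Hypothetical model and data, standing in for the ones in the report.
model = torch.nn.Linear(4, 2)
inputs = torch.randn(8, 4)
targets = torch.randn(8, 2)

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

def log_prob_single(params, buffers, x, y):
    # Evaluate the model on one sample, with a batch dim of 1.
    pred = functional_call(model, (params, buffers), (x.unsqueeze(0),))
    # validate_args=False skips the constraint checks whose `.all()` result
    # would otherwise be used in Python control flow under vmap.
    dist = Independent(
        Normal(loc=pred, scale=torch.ones_like(pred), validate_args=False),
        reinterpreted_batch_ndims=1,
        validate_args=False,
    )
    return dist.log_prob(y.unsqueeze(0)).squeeze(0)  # scalar log p(y|x)

# Gradient of log p w.r.t. params for one sample; vmap maps it over the batch.
per_sample_jac = vmap(jacrev(log_prob_single), in_dims=(None, None, 0, 0))
jac = per_sample_jac(params, buffers, inputs, targets)
# jac["weight"] has shape (8, 2, 4); jac["bias"] has shape (8, 2).
```

If changing every constructor call is inconvenient, `torch.distributions.Distribution.set_default_validate_args(False)` disables validation globally, at the cost of losing those checks everywhere.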

zou3519, Jul 06 '22 18:07

Thanks for tracking this down. Fortunately it isn't my primary use case, but support would be nice to have for completeness in the future!

agiachris, Jul 06 '22 18:07