
25% Performance regression from v0.1.1 to 0.2.0 when calculating hessian

Open yueyericardo opened this issue 1 year ago • 40 comments

Hi developers,

After I upgraded functorch from v0.1.1 to 0.2.0, I noticed a 25% performance regression when calculating the Hessian. Please see the benchmark results and the attached benchmark script below.

Please let me know if I did anything wrong, and also whether the perf regression could be fixed. Thanks!

Benchmark result

Benchmark result on NVIDIA A100

# torch 111 and functorch 0.1.1
===== benchmark without backward =====
max pred       error: functorch: 0.00e+00
max hessian    error: functorch: 0.00e+00
reference_hessian: 61.837 ms
functorch_hessian: 29.474 ms

# torch 112 and functorch 0.2.0
===== benchmark without backward =====
max pred       error: functorch: 1.49e-08
max hessian    error: functorch: 0.00e+00
reference_hessian: 62.519 ms
functorch_hessian: 39.666 ms  (0.75 X)

Benchmark result on NVIDIA A6000

# torch 111 and functorch 0.1.1
===== benchmark without backward =====
max pred       error: functorch: 1.49e-08
max hessian    error: functorch: 0.00e+00
reference_hessian: 65.984 ms
functorch_hessian: 33.662 ms

# torch 112 and functorch 0.2.0
===== benchmark without backward =====
max pred       error: functorch: 1.86e-08
max hessian    error: functorch: 0.00e+00
reference_hessian: 67.285 ms
functorch_hessian: 49.723 ms (0.68 X)

benchmark script

benchmark.py

import time
import argparse
from functorch import vmap, jacrev, jacfwd
import torch
import torch.nn as nn

torch.backends.cuda.matmul.allow_tf32 = False


_ = torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
D1 = 2  # x, y
D2 = 3  # u, v, p
B = 10000
x = torch.randn(B, D1).to(device)
run_backward = False

model = nn.Sequential(
    nn.Linear(D1, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, D2),
).to(device)


def predict(x):
    torch.cuda.nvtx.range_push("forward")
    out = model(x)
    torch.cuda.nvtx.range_pop()
    return out, out  # the second output is the auxiliary value returned by jacrev/jacfwd (has_aux=True)


def reference_hessian():
    x_ = x.clone().requires_grad_()
    ones = torch.ones(B, device=x.device)
    pred, _ = predict(x_)
    jacobian_rows = [None] * D2
    hessian_rows = [None] * (D2 * D1)
    for i in range(D2):
        torch.cuda.nvtx.range_push("autograd jacobian")
        jacobian_rows[i] = torch.autograd.grad(pred[:, i], x_, ones, create_graph=True)[
            0
        ]
        torch.cuda.nvtx.range_pop()

    for i in range(D2):
        for j in range(D1):
            torch.cuda.nvtx.range_push("autograd hessian")
            hessian_rows[i * D1 + j] = torch.autograd.grad(
                jacobian_rows[i][:, j], x_, ones, create_graph=True
            )[0]
            torch.cuda.nvtx.range_pop()

    jacobian = torch.stack(jacobian_rows)  # [D2, B, D1]
    hessian = torch.stack(hessian_rows)  # [D2 * D1, B, D1]
    if run_backward:
        l = hessian.sum()
        l.backward()
    return hessian.transpose(0, 1), pred


def functorch_hessian():
    x_ = x.clone().requires_grad_()
    hessian, pred = vmap(
        jacfwd(jacrev(predict, argnums=0, has_aux=True), argnums=0, has_aux=True),
        in_dims=0,
    )(
        x_
    )  # [B, D2, D1, D1]
    if run_backward:
        l = hessian.sum()
        l.backward()
    return hessian, pred


def validate_result():
    # test functorch result
    ref_hes, ref_pred = reference_hessian()
    ft_hes, ft_pred = functorch_hessian()
    ref_hes = ref_hes.view_as(ft_hes)
    # use the absolute difference so negative errors are not hidden by .max()
    print(f"max pred       error: functorch: {(ref_pred - ft_pred).abs().max():.2e}")
    print(f"max hessian    error: functorch: {(ref_hes - ft_hes).abs().max():.2e}")


def benchmark(func):
    N = 20

    torch.cuda.synchronize()
    start = time.time()

    for i in range(N):
        torch.cuda.nvtx.range_push(func.__name__)
        _ = func()
        torch.cuda.nvtx.range_pop()

    torch.cuda.synchronize()
    time_ms = ((time.time() - start) / N) * 1000
    print(f"{func.__name__}: {time_ms:.3f} ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--backward", default=False, action="store_true")
    args = parser.parse_args()
    if args.backward:
        run_backward = True
        print("===== benchmark with backward =====")
    else:
        print("===== benchmark without backward =====")

    validate_result()

    # warm up
    for i in range(10):
        reference_hessian()
        functorch_hessian()

    # benchmark hessian
    benchmark(reference_hessian)
    benchmark(functorch_hessian)
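The vmap(jacfwd(jacrev(...))) composition in functorch_hessian() computes one Hessian per sample. A quick way to sanity-check that composition is to apply it to a function whose Hessian is known in closed form. This is a sketch that assumes PyTorch 2.x, where the functorch transforms are exposed as torch.func; the toy function cubic is hypothetical and not part of the benchmark.

```python
import torch
from torch.func import jacfwd, jacrev, vmap


def cubic(x):
    # Scalar function of a single sample x with shape [D]: f(x) = sum_i x_i^3
    return (x ** 3).sum()


# Forward-over-reverse Hessian, batched over the leading dim (as in functorch_hessian)
per_sample_hessian = vmap(jacfwd(jacrev(cubic)))

x = torch.randn(4, 2, dtype=torch.float64)
hess = per_sample_hessian(x)  # one [2, 2] Hessian per sample -> shape [4, 2, 2]

# Closed form: the Hessian of sum(x_i^3) is diag(6 * x)
expected = torch.diag_embed(6 * x)
print(hess.shape, torch.allclose(hess, expected))
```

The same check can be pointed at the benchmark's predict() by comparing against reference_hessian(), which is what validate_result() already does.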

yueyericardo, Jul 28 '22 08:07

ping @samdow @zou3519

yueyericardo, Jul 28 '22 21:07

Thanks for the report, we'll take a look soon

zou3519, Jul 28 '22 21:07

Bisected to https://github.com/pytorch/pytorch/pull/75195/files. That PR by itself may not be the problem; perhaps the problem is our batching rule for mv.

@yueyericardo is the repro you provided the entire model, or is it a subset of some model that you're running?

zou3519, Jul 29 '22 01:07

cc @lezcano @ezyang for https://github.com/pytorch/pytorch/pull/75195 -- this led to a performance regression in functorch. I'm not sure what the original intent of the PR is (there are no tests). I'm still trying to root cause this, but it is a bit difficult to visualize. What are the chances we could revert that PR?

(I've confirmed that reverting that single PR on pytorch/pytorch master makes the performance regression go away)

zou3519, Jul 29 '22 01:07

Many thanks to @zou3519 for the quick debugging!! I'm working on the NVIDIA Modulus project; we use functorch because it provides significant performance gains for the Jacobian and Hessian calculations. The minimal repro I provided is only a subset of our model, meant to demonstrate the performance regression.

Thanks again!

yueyericardo, Jul 29 '22 01:07

@yueyericardo - for the original model itself, is the performance regression also 25%, or is it a smaller number? Is the original model public? One thing we can do to prevent future regressions is to check the original model into https://github.com/pytorch/benchmark.

I've noticed a lot of other similar models where folks have an nn.Sequential made up only of nn.Linear layers and activations and need to compute vmap(jacrev(...)) or vmap(hessian(...)) of the quantity, so otherwise we could also potentially just check your script into torchbench.

zou3519, Jul 29 '22 02:07

Hi @zou3519, our source code is free to download from the website, but it is not developed on GitHub, and our code base is probably too large to put into pytorch/benchmark.

Yes, exactly! I believe the minimal repro I provided is enough to prevent future regression for our model. Thanks!

yueyericardo, Jul 29 '22 02:07

I think that rather than blindly reverting, we should get to the root of the problem, as it is very weird to get such a regression when dispatching from a more general function to a more concrete one (which was the purpose of that PR).

Things that come to mind are:

  • Is this regression architecture- or cuBLAS-version-dependent?
  • Is this regression also happening on regular matmul for that codepath?

If the answer to both of the above is no, then this performance issue is likely on the functorch end and should be fixed there. Otherwise, it's on the cuBLAS end and should be reported to NVIDIA.
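The second question can be probed outside functorch by comparing the two formulations of a vector-matrix product directly in eager mode. The sketch below is an assumption about how such a check might look (the sizes are arbitrary, and it does not reproduce the PR's exact dispatch logic): it verifies that unsqueeze + mm and mv agree numerically and times both with torch.utils.benchmark.

```python
import torch
from torch.utils import benchmark

torch.manual_seed(0)
v = torch.randn(512)       # the vector (first matmul argument)
m = torch.randn(512, 512)  # the matrix (second matmul argument)


def via_mm(v, m):
    # General path: promote v to a [1, N] matrix, use mm, then drop the extra dim
    return v.unsqueeze(0).mm(m).squeeze(0)


def via_mv(v, m):
    # Specialized path: v @ m == m^T @ v, computed as a matrix-vector product
    return m.t().mv(v)


assert torch.allclose(via_mm(v, m), via_mv(v, m), rtol=1e-4, atol=1e-4)

for fn in (via_mm, via_mv):
    t = benchmark.Timer(stmt="fn(v, m)", globals={"fn": fn, "v": v, "m": m})
    print(fn.__name__, t.blocked_autorange().median)
```

Running the same comparison on the target GPU (moving v and m to "cuda") would address the architecture / cuBLAS-version question as well.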

cc @ngimel @xwang233 @ivanyashchuk

lezcano, Jul 29 '22 07:07

@Lezcano It's fair to submit the upstream bugs, but if we know that our upstream library's general kernel has better perf than a specialized one, we might as well use it.

ezyang, Jul 29 '22 12:07

@Lezcano @ezyang let's say that we did revert the PR (because we're trying to release PyTorch 1.12.1 as soon as possible). Would it cause any other problems?

Because the motivation was "dispatching from a more general function to a more concrete [function]", it sounds like this wouldn't change very much else.

zou3519, Jul 29 '22 14:07

No, I don't think so. The PR is supposed to make the kernel run faster.

ezyang, Jul 29 '22 14:07

fwiw, I think this may be related to the open PR I have to avoid copies in matmul. Could you check whether https://github.com/pytorch/pytorch/pull/76828 fixes this?

In any case, I'm fine with reverting, but we should investigate what's causing this regardless

lezcano, Jul 29 '22 16:07

> @Lezcano @ezyang let's say that we did revert the PR (because we're trying to release PyTorch 1.12.1 as soon as possible). Would it cause any other problems?
>
> Because the motivation was "dispatching from a more general function to a more concrete [function]", it sounds like this wouldn't change very much else.

That PR fixes spurious resize warnings that were previously generated, and by itself is supposed to speed things up by avoiding squeeze/unsqueeze calls, which are not free (especially not when autograd is needed). As for the more general vs. more concrete function performance, we should investigate this, but I doubt that's the case.

ngimel, Jul 29 '22 16:07

@zou3519 can you by any chance collect profiling results for old and new versions?

ngimel, Jul 29 '22 16:07

> In any case, I'm fine with reverting, but we should investigate what's causing this regardless

I agree this warrants more investigation. We've got a problem in that there is a timeline for 1.12.1, and I am not sure how long it is going to take to actually get to the bottom of this.

> @zou3519 can you by any chance collect profiling results for old and new versions?

I can try but it's been a long time since I touched nvprof or nsight profiler, so I will need to relearn the magic invocations.

> but we should investigate what's causing this regardless

Since we changed the mm to mv, functorch generates different code for vmap(jacrev(jacfwd(mm))) as opposed to vmap(jacrev(jacfwd(mv))). It's plausible that the problem is that "functorch should generate better code"; we're still digging into it.

zou3519, Jul 29 '22 17:07

Isn't 1.12.1 done already? No Nsight needed; the torch profiler should be enough (wrap the code in a with torch.profiler.profile() as p: block, then print the key averages and export a Chrome trace at the end).
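A minimal sketch of that profiler workflow, assuming a CPU-only run (add ProfilerActivity.CUDA to activities when profiling on a GPU). The toy nn.Sequential and the trace filename are placeholders rather than the benchmark's actual setup; in practice one would wrap a few iterations of functorch_hessian() instead.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile

# Placeholder workload standing in for functorch_hessian()
model = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 3))
x = torch.randn(128, 2)

with profile(activities=[ProfilerActivity.CPU]) as p:
    for _ in range(5):
        model(x).sum().backward()

# Per-operator summary, sorted by the time spent in each op itself
print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

# Chrome trace, viewable in chrome://tracing or Perfetto
trace_path = os.path.join(tempfile.gettempdir(), "hessian_trace.json")
p.export_chrome_trace(trace_path)
```

Profiling the old and new versions with the same script and diffing the two tables is one way to spot extra ops such as copies.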

ngimel, Jul 29 '22 17:07

FYI, the Nsight profiling results before the regression:

Time	Total Time	Instances	Avg	Med	Min	Max	StdDev	Name
64.7%	928.624 ms	1000	928.623 μs	867.621 μs	297.345 μs	1.686 ms	493.406 μs	ampere_sgemm_128x64_nn
10.3%	147.304 ms	250	589.217 μs	589.187 μs	588.643 μs	590.435 μs	275 ns	ampere_sgemm_128x128_tn
9.4%	134.966 ms	1300	103.819 μs	70.432 μs	5.344 μs	194.209 μs	75.470 μs	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::BinaryFunctor<float, float, float, at::native::AddFunctor<float>>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
6.8%	98.146 ms	900	109.050 μs	94.848 μs	70.081 μs	177.921 μs	43.651 μs	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::BinaryFunctor<float, float, float, void at::native::threshold_kernel_impl<float>(at::TensorIteratorBase &, T1, T1)::[lambda(float, float) (instance 1)]>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
5.4%	77.550 ms	350	221.572 μs	300.321 μs	18.080 μs	302.050 μs	124.914 μs	ampere_sgemm_128x64_tn
0.8%	10.981 ms	1450	7.572 μs	3.552 μs	2.880 μs	26.112 μs	7.973 μs	void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, at::detail::Array<char *, (int)1>>(int, T2, T3)
0.7%	9.949 ms	150	66.324 μs	82.928 μs	27.968 μs	88.672 μs	26.863 μs	ampere_sgemm_32x32_sliced1x4_nn
0.6%	8.074 ms	300	26.912 μs	26.976 μs	25.696 μs	27.776 μs	390 ns	void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(at::TensorIteratorBase &, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::[lambda() (instance 1)]::operator ()() const::[lambda() (instance 8)]::operator ()() const::[lambda(float) (instance 1)], at::detail::Array<char *, (int)2>>(int, T2, T3)
0.5%	7.713 ms	50	154.257 μs	154.177 μs	153.184 μs	155.905 μs	634 ns	ampere_sgemm_32x128_nn
0.3%	4.847 ms	100	48.468 μs	47.664 μs	34.688 μs	63.393 μs	13.397 μs	ampere_sgemm_32x32_sliced1x4_tn
0.2%	3.342 ms	650	5.141 μs	5.344 μs	4.256 μs	5.728 μs	423 ns	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::<unnamed>::direct_copy_kernel_cuda(at::TensorIteratorBase &)::[lambda() (instance 2)]::operator ()() const::[lambda() (instance 8)]::operator ()() const::[lambda(float) (instance 1)]>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
0.1%	2.028 ms	50	40.569 μs	40.512 μs	40.032 μs	41.536 μs	272 ns	ampere_sgemm_64x64_nn
0.1%	994.305 μs	50	19.886 μs	19.856 μs	19.649 μs	20.896 μs	169 ns	ampere_sgemm_128x32_nn
0.0%	435.235 μs	100	4.352 μs	4.384 μs	4.160 μs	4.576 μs	100 ns	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::FillFunctor<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)

After the regression (note the third row: 18.4% of the time is spent in a copy kernel):

Time	Total Time	Instances	Avg	Med	Min	Max	StdDev	Name
21.8%	433.409 ms	500	866.818 μs	866.756 μs	865.028 μs	870.116 μs	881 ns	ampere_sgemm_32x128_nt
21.2%	420.235 ms	250	1.681 ms	1.681 ms	1.678 ms	1.685 ms	813 ns	ampere_sgemm_128x64_nn
18.4%	365.797 ms	2850	128.349 μs	83.809 μs	4.096 μs	608.931 μs	173.023 μs	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::<unnamed>::direct_copy_kernel_cuda(at::TensorIteratorBase &)::[lambda() (instance 2)]::operator ()() const::[lambda() (instance 14)]::operator ()() const::[lambda(float) (instance 1)]>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
10.8%	213.750 ms	300	712.501 μs	588.802 μs	588.195 μs	1.335 ms	277.174 μs	ampere_sgemm_128x128_tn
8.1%	160.586 ms	1300	123.527 μs	93.057 μs	5.089 μs	266.497 μs	92.181 μs	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::CUDAFunctor_add<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
7.9%	156.152 ms	500	312.304 μs	312.290 μs	310.145 μs	314.689 μs	781 ns	ampere_sgemm_128x32_nn
5.9%	117.864 ms	900	130.960 μs	129.824 μs	77.504 μs	203.681 μs	49.220 μs	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::BinaryFunctor<float, float, float, void at::native::<unnamed>::threshold_kernel_impl<float>(at::TensorIteratorBase &, T1, T1)::[lambda(float, float) (instance 1)]>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
3.3%	66.223 ms	50	1.324 ms	1.324 ms	1.324 ms	1.326 ms	397 ns	ampere_sgemm_128x128_tt
1.1%	21.094 ms	1850	11.402 μs	3.584 μs	2.816 μs	65.857 μs	12.662 μs	void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, at::detail::Array<char *, (int)1>>(int, T2, T3)
0.4%	8.071 ms	300	26.904 μs	26.976 μs	25.792 μs	28.128 μs	392 ns	void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(at::TensorIteratorBase &, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::[lambda() (instance 1)]::operator ()() const::[lambda() (instance 14)]::operator ()() const::[lambda(float) (instance 1)], at::detail::Array<char *, (int)2>>(int, T2, T3)
0.4%	7.903 ms	50	158.063 μs	158.049 μs	156.609 μs	159.617 μs	634 ns	ampere_sgemm_32x128_nn
0.4%	7.359 ms	100	73.594 μs	73.600 μs	72.000 μs	78.368 μs	1.223 μs	ampere_sgemm_32x32_sliced1x4_nt
0.2%	3.242 ms	50	64.844 μs	64.928 μs	62.464 μs	65.761 μs	528 ns	ampere_sgemm_32x32_sliced1x4_tn
0.1%	2.255 ms	100	22.551 μs	22.688 μs	20.865 μs	24.192 μs	1.106 μs	void gemmSN_NN_kernel<float, (int)256, (int)4, (int)2, (int)8, (int)3, (int)4, (bool)0, cublasGemvTensorStridedBatched<const float>, cublasGemvTensorStridedBatched<const float>, cublasGemvTensorStridedBatched<float>>(cublasGemmSmallNParams<T9, T10, T11, T1>)
0.1%	2.043 ms	100	20.427 μs	20.416 μs	20.160 μs	22.656 μs	257 ns	ampere_sgemm_128x32_tn
0.0%	432.292 μs	100	4.322 μs	4.336 μs	4.128 μs	4.640 μs	97 ns	void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::FillFunctor<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)

yueyericardo, Jul 29 '22 17:07

@ngimel the PR that fixed the warnings was already merged; this one is just concerned with avoiding copies. One of the cases where it elides a copy is when you multiply a matrix by a batch of matrices. This is exactly the batched version of the vector-matrix product that's causing the regression. That's why I think it may fix it.
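The general idea of eliding a copy in a batched matrix product can be illustrated in isolation. This is a sketch of the technique, assumed to be close to one of the cases the PR targets, not its actual implementation: multiplying a contiguous batch of matrices [B, K, N] by a single matrix [N, M] does not require expanding the single matrix into a batch for bmm, because the batch can be folded into the rows of one mm through a view.

```python
import torch

torch.manual_seed(0)
B, K, N, M = 8, 6, 32, 16
batch = torch.randn(B, K, N)  # batch of matrices
mat = torch.randn(N, M)       # single matrix shared across the batch

# Option 1: materialize the shared matrix for every batch element, then bmm
out_bmm = torch.bmm(batch, mat.expand(B, N, M).contiguous())

# Option 2: fold the batch dim into the rows; for a contiguous tensor this
# reshape is a view, so the whole product is a single mm with no input copies
folded = batch.reshape(B * K, N)
assert folded.data_ptr() == batch.data_ptr()  # shares storage with batch
out_mm = folded.mm(mat).view(B, K, M)

print(torch.allclose(out_bmm, out_mm, atol=1e-5))
```

If the batch were not contiguous, the reshape in option 2 could itself trigger a copy, which is why layout matters for this path.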

lezcano, Jul 29 '22 17:07

That must be coming from functorch, as that PR doesn't introduce any additional copies.

ngimel, Jul 29 '22 17:07

yup, I think https://github.com/pytorch/pytorch/pull/76828 would fix the regression

lezcano, Jul 29 '22 17:07

I patched in https://github.com/pytorch/pytorch/pull/76828 and the above script ends up OOM-ing :/

zou3519, Jul 29 '22 17:07

> Isn't 1.12.1 done already?

Not yet, we have a chance to change it (or request a 1.12.2 if necessary since this regression is large and these types of models are typical functorch usage)

zou3519, Jul 29 '22 17:07

Re. OOM: wow, that's certainly unexpected. I'm not sure what the best way to follow up on that is; @ngimel probably has a better idea of how to proceed.

Regardless, landing that PR (the avoid-copies-in-matmul one) will be tricky due to some discrepancies in the accuracy of mm vs mv that we found for float16. As such, I think the most realistic way forward would be to revert the offending PR and investigate what's causing that OOM after 1.12.1 is released.

lezcano, Jul 29 '22 18:07

> the PR that fixed the warnings was already merged

I don't think that's true, #75195 itself is fixing a warning (otherwise a user-supplied correct 1d out was sent to mm, and mm complained about resizing it to 2d).

ngimel, Jul 29 '22 18:07

I think the one that fixed the out version was https://github.com/pytorch/pytorch/pull/75197 and a previous one in that stack, but it may be the case that 75195 also fixed an out warning, I don't remember now.

lezcano, Jul 29 '22 19:07

Here is my attempt at a smaller repro that runs on a single version of pytorch. It computes a different quantity, vmap(jacrev(smaller_predict)), but also exhibits performance differences (1.3ms vs 7ms on my machine).

https://gist.github.com/zou3519/3869d460f8bcb12799967e08a5998d9c

In the gist there's also a trace (acquired via make_fx) of what the graph looks like, using the "old linear" (aka the implementation of matmul in 1.11) vs. the "new linear" (the implementation of matmul in 1.12), if anyone's feeling ambitious about reading traces.
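For readers who want to produce this kind of trace themselves, here is a minimal sketch using make_fx on a toy two-layer MLP. It assumes PyTorch 2.x, where make_fx lives in torch.fx.experimental.proxy_tensor and the functorch transforms are exposed as torch.func; the function smaller_predict below is a hypothetical stand-in, not the one from the gist. The weights are passed as explicit (unbatched) inputs so they show up as graph placeholders.

```python
import torch
import torch.nn.functional as F
from torch.func import jacrev, vmap
from torch.fx.experimental.proxy_tensor import make_fx

torch.manual_seed(0)
w1, w2 = torch.randn(16, 2), torch.randn(3, 16)


def smaller_predict(x, w1, w2):
    # Hypothetical per-sample MLP: x has shape [2], output has shape [3]
    return F.linear(F.relu(F.linear(x, w1)), w2)


x = torch.randn(5, 2)
# Per-sample Jacobian w.r.t. x; weights are shared (in_dims=None)
per_sample_jac = vmap(jacrev(smaller_predict), in_dims=(0, None, None))

# Trace the composed transform down to ATen ops; gm.code shows the graph
gm = make_fx(per_sample_jac)(x, w1, w2)
print(gm.code)
```

Tracing once per PyTorch version and diffing the printed code is one way to see how a change to matmul's dispatch alters the ops functorch emits.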

zou3519, Jul 29 '22 22:07

@yueyericardo I got access to the modulus repo -- could you point me to which model in the modulus repo contains the nn.Sequential above please? We're figuring out how to check the above code into torchbench and I'm looking for some more context -- is there a name for computing the hessian for the nn.Sequential? Do you generally run .backward() after computing the hessian? What are some representative input sizes (are the sizes in the benchmark script representative?)

zou3519, Aug 01 '22 21:08

Hi @zou3519, thanks for following up! We are having some internal discussions regarding this and will get back to you tomorrow.

yueyericardo, Aug 01 '22 23:08

From looking at this a bit, I think what happened is:

  • PyTorch changed unsqueeze + mm -> mv in the case where the first argument is a vector
  • this example hits that codepath under vmap => functorch's mv kernel
  • the vector is batched and the matrix is not => functorch moves the vector's batch dim to the end and matmuls. The batch dim is now the 1st dimension, not the 0th
  • relu's output is now BatchedTensor([512, 10000], lvl=0, dim=1). This needs to be saved as an activation for the backward pass
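The forward-pass layout described above can be emulated with plain tensors. This is a hypothetical re-creation of what the mv batch rule produces (not functorch's actual C++ code), assuming a batched vector and an unbatched matrix: moving the vector's batch dim to the end and doing one matmul leaves the batch dim at position 1 of the output, so viewing the result batch-first gives a non-contiguous tensor.

```python
import torch

B, N, M = 10000, 512, 512
weight = torch.randn(M, N)  # unbatched matrix (e.g. a Linear weight)
vecs = torch.randn(B, N)    # one vector per sample; batch dim 0

# Batch dim moved to the end: [M, N] @ [N, B] -> [M, B]; batch dim is now dim 1
result = weight.matmul(vecs.movedim(0, -1))
print(result.shape)  # torch.Size([512, 10000])

# Moving the batch dim back to the front is a view with swapped strides
batch_first = result.movedim(1, 0)  # shape [10000, 512]
print(batch_first.is_contiguous())  # False
```

Any op that later receives batch_first (for example relu's saved activation in the backward pass) sees a strided, non-contiguous input.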

During the backward pass:

  • grad_output is a BatchedTensor(BatchedTensor([10000, 512, 6], lvl=0, dim=0), lvl=1, dim=2)
  • threshold_backward (relu's backward) takes in grad_output and relu (binary pointwise batch rule). It starts with the lvl 1 batch rule:
    • grad_output is the only value batched at the highest level, so we move its batch dim to the front with a movedim. grad_output is no longer contiguous
    • redispatch on threshold_backward for lvl 0
    • move relu's bdim to the front with a movedim and pad it to be 3D. relu is no longer contiguous. grad_output does not need to move, since its lvl 0 bdim is already 0; it is still contiguous
    • the view on relu is not contiguous and the view on grad_output is not contiguous => threshold_backward triggers a copy. Locally, I've tested the case where only one of the inputs is not contiguous and haven't seen the same slowdown.

We can also validate that this is our issue since we hit pre-regression performance numbers by changing functorch's mv batch rule here:

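  // mv(self, other): self is the unbatched [M, N] matrix, other the batched vector
  auto other_ = moveBatchDimToFront(other, other_bdim);  // [B, N]
  auto self_ = at::movedim(self, 0, 1);                  // [N, M]
  // [B, N] @ [N, M] -> [B, M]: the batch dim stays at dim 0
  auto result = at::matmul(other_, self_);
  return std::make_tuple(std::move(result), 0);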

This doesn't trigger the copy, since the batch dimension for the saved relu activation is now the first dimension. However, this may cause other perf issues from transposing both self and result.
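For comparison, the alternative batch rule can be mimicked in Python. Again this is a sketch with plain tensors, not functorch internals: keeping the vector's batch dim at the front and multiplying by the transposed matrix produces a batch-first, contiguous result directly, and it matches the batch-dim-at-the-end layout once that one is viewed batch-first. The transpose of the matrix is itself a strided view, which is one place the possible extra cost mentioned above could come from.

```python
import torch

B, N, M = 10000, 512, 512
weight = torch.randn(M, N)  # `self` in mv(self, other): unbatched [M, N] matrix
vecs = torch.randn(B, N)    # `other`: batched vector, batch dim already at 0

# Alternative rule: [B, N] @ [N, M] -> [B, M]; batch dim stays at dim 0
result_alt = vecs.matmul(weight.movedim(0, 1))
print(result_alt.is_contiguous())  # True

# Original layout (batch dim at the end), viewed batch-first for comparison
result_old = weight.matmul(vecs.movedim(0, -1)).movedim(1, 0)
print(torch.allclose(result_alt, result_old, rtol=1e-4, atol=1e-3))
```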


This is the smallest subset of @zou3519's trace where I can see the perf differences:

import torch

# Inputs mirroring the "old linear" graph
mm_2 = torch.randn(60000, 512)
relu = torch.randn(10000, 512)
def old_linear_faster():
    mm_2_view = mm_2.view([10000, 6, 1, 512])
    mm_2_view_squeeze = mm_2_view.squeeze(2)

    relu_view = relu.view([10000, 1, 512])

    threshold_backward = torch.ops.aten.threshold_backward(mm_2_view_squeeze, relu_view, 0)
    return threshold_backward

# Inputs mirroring the "new linear" graph (both permuted views below are non-contiguous)
bmm = torch.randn(10000, 512, 6)
other_relu = torch.randn(512, 10000)
def new_linear_faster():
    bmm_view = bmm.view([10000, 512, 6])
    bmm_view_permute = bmm_view.permute([0, 2, 1])

    other_relu_permute = other_relu.permute([1, 0])
    other_relu_permute_view = other_relu_permute.view([10000, 1, 512])

    threshold_backward = torch.ops.aten.threshold_backward(bmm_view_permute, other_relu_permute_view, 0)
    return threshold_backward

Notably, if I change the views on either bmm or other_relu to be contiguous constants, performance is much faster. So it seems like threshold_backward doesn't copy if only one of its inputs is non-contiguous, but does if both are.
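That observation can be checked with a small timing harness. The sketch below assumes reduced shapes (smaller than the trace, to keep it quick) and uses torch.utils.benchmark.Timer to time aten.threshold_backward with both inputs non-contiguous versus only one non-contiguous, after checking that the values agree. Absolute timings depend on the machine, and the harness does not itself prove where a copy happens.

```python
import torch
from torch.utils import benchmark

torch.manual_seed(0)
B, K, N = 1000, 6, 512  # reduced from [10000, 6, 512] for a quick run

# Both inputs non-contiguous, mirroring new_linear_faster()
grad_nc = torch.randn(B, N, K).permute(0, 2, 1)          # [B, K, N]
relu_nc = torch.randn(N, B).permute(1, 0).view(B, 1, N)  # [B, 1, N]

# Same values, but with grad_output made contiguous
grad_c = grad_nc.contiguous()


def both_noncontig():
    return torch.ops.aten.threshold_backward(grad_nc, relu_nc, 0)


def one_noncontig():
    return torch.ops.aten.threshold_backward(grad_c, relu_nc, 0)


assert torch.equal(both_noncontig(), one_noncontig())

for fn in (both_noncontig, one_noncontig):
    t = benchmark.Timer(stmt="fn()", globals={"fn": fn})
    print(fn.__name__, t.blocked_autorange().median)
```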

samdow, Aug 02 '22 13:08

Thanks @samdow, the copy from relu definitely seems to affect perf; however, there's also another copy coming from MvBackward. [Screenshot omitted]

Also, why does threshold_backward on non-contiguous inputs trigger a copy? In eager mode, threshold_backward should be able to handle them via TensorIterator.

ngimel, Aug 02 '22 17:08