Support vectorize/devectorize inside gradients
Thanks for making Nx!
I tried to use value_and_grad on a function that takes two inputs: a vectorized tensor and a non-vectorized tensor.
defmodule Foo do
  import Nx.Defn

  defn f(x, y) do
    x + y
  end

  defn f_and_grad(x, y) do
    value_and_grad(y, fn y -> Foo.f(x, y) end)
  end
end
x = ~VEC[0 1] |> Nx.vectorize(:bar)
Foo.f_and_grad(x, 1)
This evaluates to:
{#Nx.Tensor<
   vectorized[bar: 2]
   s64
   EXLA.Backend<host:0, 0.731981912.321781778.128426>
   [1, 2]
 >,
 #Nx.Tensor<
   f32
   EXLA.Backend<host:0, 0.731981912.321781778.128427>
   2.0
 >}
The value is correct and keeps the vectorized axis of the vectorized input x, but the gradient surprises me. I would have expected a vectorized rank-1, dimension-2 tensor with the same :bar axis, equal to 1 everywhere; instead it looks like Nx is summing the two gradients.
Is this behavior expected? If so, is there any way to make Nx return a vectorized gradient?
Thanks!
I know that I can use
y = ~VEC[1 1] |> Nx.vectorize(:bar)
Foo.f_and_grad(x, y)
to get the result I expect, but in practice y is actually quite large, so repeating it just so the gradient is computed properly seems wasteful. I will dig into that more though.
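A sketch of what I mean by avoiding the hand-written y (my assumption, not verified output): let Nx.broadcast_vectors/1 copy the vectorized :bar axis onto y, so the gradient is taken per entry instead of being summed.

```elixir
# Hypothetical workaround: broadcast the vectorized axis onto `y`
# instead of writing it out by hand.
x = Nx.tensor([0, 1]) |> Nx.vectorize(:bar)
y = Nx.tensor(1)

# After this, `y` carries the same vectorized[bar: 2] axis as `x`.
[x, y] = Nx.broadcast_vectors([x, y])

Foo.f_and_grad(x, y)
```

This still materializes a y as large as x, which is exactly the memory concern above.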
I think this makes sense because the grad is computed over y, but I would like to see if @polvalente has a different opinion.
I checked whether it would still be efficient to broadcast y to the size of x in order to get a gradient with the same dimensions as y; I wasn't sure whether Nx would create e.g. a vector with zero stride. However, it looks like the byte_size increases, at least with Nx.BinaryBackend and EXLA.Backend:
x = ~VEC[0 0] |> Nx.vectorize(:foo)
y = ~VEC[1]
[x, y] = Nx.broadcast_vectors([x, y])
y |> Nx.byte_size()
# 16
# if another element is added to `x`, this evaluates to 24, etc.
So I still would be interested if there is a way to get the non-summed gradient, although I understand if it's not possible with this API.
I agree with @jyc that the grad should have the same vectorized shape as the output. That is, the correct result for the example should be [1.0, 1.0] instead of 2.0.
The mental model I have is that fun(vectorized[foo: 2] [1, 2]) should yield the same output as vectorize(stack([fun(1), fun(2)]), :foo), which is not the case here.
Memory-wise, vectorization will end up doing the explicit broadcasting, if applicable, regardless of the backend (although some backends might end up fusing things).
Note: This specific comment is wrong and can be ignored; what I said earlier & what polvalente said is correct AFAIK. Sorry for the confusion!
~~@polvalente Thanks for the reply! Sorry but just to be clear, I checked after you mentioned the mental model and it looks like grad returns the same result even without vectorization, so my mentioning the vectorization was a red herring:~~
defmodule Foo do
  import Nx.Defn

  defn f(x, y) do
    x + y
  end

  defn f_and_grad(x, y) do
    value_and_grad(y, fn y -> Foo.f(x, y) end)
  end
end
x = ~VEC[0 1 2]
y = ~VEC[1]
Foo.f_and_grad(x, y)
# {~VEC[1 2 3], ~VEC[3.0]}
~~This is still surprising to me but at least it is consistent with and without vectorization. I will keep looking for a workaround.~~
Actually, I have confused myself! I don't believe it's a red herring because it's the other axis that is vectorized. I misunderstood. Please ignore my last comment, sorry for the noise. In other words, I agree with your comment here:
The mental model I have is that fun(vectorized[foo: 2] [1, 2]) should yield the same output as vectorize(stack([fun(1), fun(2)]), :foo), which is not the case here.
The problem here is that for that Foo module, this isn't true:
x = Nx.tensor([0, 1, 2])
y = 1
{_, grad0} = Foo.f_and_grad(x[0], y)
{_, grad1} = Foo.f_and_grad(x[1], y)
{_, grad2} = Foo.f_and_grad(x[2], y)
expected_result = Nx.stack([grad0, grad1, grad2]) |> Nx.vectorize(:foo)
actual_result = Foo.f_and_grad(Nx.vectorize(x, :foo), y)
iex(19)> expected_result = Nx.stack([grad0, grad1, grad2]) |> Nx.vectorize(:foo)
#Nx.Tensor<
  vectorized[foo: 3]
  f32
  [1.0, 1.0, 1.0]
>
iex(21)> actual_result = Foo.f_and_grad(Nx.vectorize(x, :foo), y)
{#Nx.Tensor<
   vectorized[foo: 3]
   s32
   [1, 2, 3]
 >,
 #Nx.Tensor<
   f32
   3.0
 >}
You are right! Sorry for the noise.
Reopening because we still need to support vectorize/devectorize inside the gradient. :)
Curious whether this behavior is expected, given the current vectorization support for gradients?
defmodule NXAutoGradTest do
  use ExUnit.Case
  import Nx.Defn

  defn sum_elems(tensor) do
    tensor[0] + tensor[1] + tensor[2]
  end

  defn sum_elems_grad(tensor) do
    grad(tensor, &sum_elems/1)
  end

  test "basic" do
    x = Nx.iota({3, 3})
    xv = Nx.vectorize(x, :elements)
    IO.inspect(sum_elems(x))
    IO.inspect(sum_elems(xv))
    IO.inspect(sum_elems_grad(x))
    IO.inspect(sum_elems_grad(xv))
  end
end
(please excuse the dumb example 😅 )
All of them print except the last, which yields this error:
1) test basic (NXAutoGradTest)
test/nx_autograd_test.exs:16
** (ArgumentError) expected length of axes (3) to match rank of shape (2)
stacktrace:
(nx 0.9.2) lib/nx/shape.ex:238: Nx.Shape.broadcast!/4
(nx 0.9.2) lib/nx.ex:3782: anonymous fn/5 in Nx.broadcast/3
(nx 0.9.2) lib/nx.ex:5408: Nx.apply_vectorized/2
(nx 0.9.2) lib/nx/defn/grad.ex:485: Nx.Defn.Grad.grad/4
(nx 0.9.2) lib/nx/defn/grad.ex:410: Nx.Defn.Grad.update_grads/6
(nx 0.9.2) lib/nx/defn/grad.ex:256: Nx.Defn.Grad.recur_to_grad/4
(elixir 1.16.3) lib/enum.ex:2528: Enum."-reduce/3-lists^foldl/2-0-"/3
(nx 0.9.2) lib/nx/defn/grad.ex:228: Nx.Defn.Grad.recur_to_grad/4
(elixir 1.16.3) lib/enum.ex:2528: Enum."-reduce/3-lists^foldl/2-0-"/3
(nx 0.9.2) lib/nx/defn/grad.ex:209: Nx.Defn.Grad.to_grad/4
(nx 0.9.2) lib/nx/defn/grad.ex:33: Nx.Defn.Grad.transform/3
(nx 0.9.2) lib/nx/defn.ex:640: anonymous fn/2 in Nx.Defn.grad/2
(nx 0.9.2) lib/nx/defn/compiler.ex:173: Nx.Defn.Compiler.runtime_fun/3
(exla 0.9.1) lib/exla/defn.ex:365: anonymous fn/4 in EXLA.Defn.compile/8
(exla 0.9.1) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
(stdlib 5.2) timer.erl:270: :timer.tc/2
(exla 0.9.1) lib/exla/defn.ex:363: anonymous fn/15 in EXLA.Defn.compile/8
(exla 0.9.1) lib/exla/defn.ex:229: EXLA.Defn.__compile__/4
(exla 0.9.1) lib/exla/defn.ex:219: EXLA.Defn.__jit__/5
(nx 0.9.2) lib/nx/defn.ex:452: Nx.Defn.do_jit_apply/3
Generally speaking, I run into this error whenever the tensor I pass to the gradient has a vectorized axis.
This is definitely part of the issue at hand. Unfortunately grad support for vectorized tensors doesn't cover all cases properly.
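Until that is supported, one possible fallback (a sketch, not an official API; it assumes the vectorized :elements axis corresponds to the leading axis of the devectorized tensor) is to take the gradient slice by slice through the plain, non-vectorized grad path and re-vectorize the result:

```elixir
# Hypothetical per-slice fallback: grad each row separately, then stack
# and re-vectorize, avoiding the vectorized grad path entirely.
x = Nx.iota({3, 3})

grads = for i <- 0..2, do: NXAutoGradTest.sum_elems_grad(x[i])

grads
|> Nx.stack()
|> Nx.vectorize(:elements)
# expected to match what sum_elems_grad(xv) should return once
# vectorized grad is fully supported: vectorized[foo... :elements, all ones
```

This is O(n) separate grad computations, so it is only a stopgap for small vectorized axes.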
I have the same problem when trying to compute gradients over vectorized tensors.