nx
Support :axes option properly in Nx.broadcast
Example of failing code:
iex(1)> t = Nx.tensor([1, 2, 3])
#Nx.Tensor<
  s64[3]
  Torchx.Backend(cpu)
  [1, 2, 3]
>
iex(2)> Nx.broadcast(t, {2, 3, 2}, axes: [1], names: [:x, :y, :z])
** (RuntimeError) Torchx: The expanded size of the tensor (2) must match the existing size (3) at non-singleton dimension 2. Target sizes: [2, 3, 2]. Tensor sizes: [3] in NIF.broadcast_to/2
    (torchx 0.1.0-dev) lib/torchx.ex:343: Torchx.unwrap!/1
    (torchx 0.1.0-dev) lib/torchx.ex:346: Torchx.unwrap_tensor!/2
    (torchx 0.1.0-dev) lib/torchx/backend.ex:220: Torchx.Backend.broadcast/4
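The error message suggests the rank-1 tensor reaches `broadcast_to` unchanged. That call aligns dimensions from the right, so the input's size 3 is compared against the target's last dimension (2) and `:axes` is effectively ignored. One possible way to honor `:axes` is to reshape the input first, placing each input dimension at the target axis listed in `axes` and size-1 dimensions everywhere else, and only then call `broadcast_to`. The sketch below computes that intermediate shape in plain Elixir. `BroadcastAxes`, `expanded_shape/3` and `compatible?/2` are hypothetical names, not part of Nx or Torchx, and the reshape-then-broadcast approach is an assumption about how a fix could look:

```elixir
defmodule BroadcastAxes do
  # Hypothetical helper (not Nx/Torchx API): returns the shape the input would
  # be reshaped to so that input dimension i sits at target axis Enum.at(axes, i)
  # and every other target axis has size 1.
  def expanded_shape(shape, target, axes) do
    dims = Tuple.to_list(shape)

    if length(dims) != length(axes) do
      raise ArgumentError, "expected exactly one axis per input dimension"
    end

    # target axis => input dimension size placed on that axis
    placed = Map.new(Enum.zip(axes, dims))

    target
    |> Tuple.to_list()
    |> Enum.with_index()
    |> Enum.map(fn {_size, axis} -> Map.get(placed, axis, 1) end)
    |> List.to_tuple()
  end

  # Right-aligned compatibility check in the style of broadcast_to: shapes are
  # aligned from the last dimension, and each input dimension must equal the
  # target size or be 1.
  def compatible?(shape, target) do
    s = Tuple.to_list(shape)
    t = Tuple.to_list(target)

    if length(s) > length(t) do
      false
    else
      # Left-pad the lower-rank shape with 1s so dimensions line up from the right.
      padded = List.duplicate(1, length(t) - length(s)) ++ s
      Enum.all?(Enum.zip(padded, t), fn {d, size} -> d == size or d == 1 end)
    end
  end
end
```

For the failing call above, `BroadcastAxes.expanded_shape({3}, {2, 3, 2}, [1])` returns `{1, 3, 1}`. `compatible?/2` returns `false` for the raw `{3}` against `{2, 3, 2}`, which matches the Torchx error, and `true` for `{1, 3, 1}`. This means a reshape to `{1, 3, 1}` before `broadcast_to` would give the shape that `axes: [1]` asks for.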