namedtensor
Reporting error messages with dimension indices
I'm not sure if this is the best place to ask questions, so please let me know if I should ask this elsewhere.
Since namedtensor sits on top of PyTorch tensors and uses PyTorch ops, what is the plan for handling error messages that explicitly refer to dimensions by their index?
I couldn't find an example of a function that does this right now (select would, but it is not implemented), but consider index_select:
In [11]: from namedtensor import ntorch
...: x = ntorch.randn(3, 3, names=['a', 'b'])
...: y = ntorch.randn(4, 4, names=['b', 'c'])
...: z = ntorch.tensor(0, names=[])
...: w = ntorch.tensor(6, names=[])
...:
In [12]: w
Out[12]:
NamedTensor(
tensor(6),
())
In [13]: x.index_select('a', w)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-13-7e2d8bd133cd> in <module>()
----> 1 x.index_select('a', w)
~/local/miniconda3/lib/python3.6/site-packages/namedtensor-0.0.2-py3.6.egg/namedtensor/torch_helpers.py in index_select(self, dim, index)
20 return NamedTensor(
21 self._tensor.index_select(
---> 22 self._schema.get(name), index._tensor.view(-1)
23 ).view(*sizes),
24 new_names,
RuntimeError: index out of range at /Users/soumith/mc3build/conda-bld/pytorch_1549310147607/work/aten/src/TH/generic/THTensorEvenMoreMath.cpp:191
The error message comes from PyTorch. This error message could be improved on the PyTorch side to say "index 6 out of range for dimension 1 with size 3", but this would be jarring to namedtensor users because the error message would report the dim as "dimension 1" and not as the named dimension.
What a fun question, I hadn't thought of this at all. What if we made a helper wrapper that caught RuntimeErrors and did a find/replace of "dimension \d" with the dimension's name?
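A rough sketch of what such a wrapper might look like, in plain Python so it runs without PyTorch. The names `rename_dims_in_message` and `with_named_errors` are made up for illustration and are not part of namedtensor, and the pattern assumes PyTorch would phrase the message as "dimension N", which it does not do today.

```python
import re
from functools import wraps

# Assumed message format: "... dimension <N> ...". Hypothetical, see above.
_DIM_PATTERN = re.compile(r"\bdimension (\d+)\b")


def rename_dims_in_message(message, names):
    """Replace each 'dimension N' with the name at position N, if N is in range."""
    def substitute(match):
        index = int(match.group(1))
        if 0 <= index < len(names):
            return "dimension '{}'".format(names[index])
        # Leave indices we cannot map untouched.
        return match.group(0)

    return _DIM_PATTERN.sub(substitute, message)


def with_named_errors(names):
    """Decorator: re-raise RuntimeErrors with dimension indices swapped for names."""
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            except RuntimeError as err:
                # Chain the original error so the PyTorch traceback is kept.
                raise RuntimeError(rename_dims_in_message(str(err), names)) from err
        return wrapper
    return decorator


# Stand-in for a PyTorch op that fails with an index-based message.
@with_named_errors(["a", "b"])
def failing_op():
    raise RuntimeError("index 6 out of range for dimension 1 with size 3")
```

Calling `failing_op()` here would raise a RuntimeError whose message mentions dimension 'b' instead of dimension 1. This only works when a single tensor's names are in play, so it would not cover messages that mention dimensions of several tensors.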
That could work, though I'm wondering if there are any error messages that refer to the dimensions of multiple tensors. torch.mm is an example, but it isn't implemented in namedtensor (there is namedtensor.dot which handles the case nicely already).
For cases like torch.mm, hopefully we can handle the error before it hits Torch. Indexing cases are a bit harder, although maybe we can check the index range in the wrapper.
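As a minimal sketch of the range-check idea, the wrapper could validate indices against the named dimension's size before calling into PyTorch. The helper `check_index_range` is hypothetical, and it takes plain Python integers so the example runs on its own; in a real wrapper the size and index values would come from the underlying tensors. It also assumes negative indices should be rejected.

```python
def check_index_range(dim_name, dim_size, indices):
    """Raise an IndexError that names the dimension if any index is out of range."""
    for value in indices:
        # Assumption: negative indices are not accepted, so valid is 0 <= value < size.
        if not 0 <= value < dim_size:
            raise IndexError(
                "index {} out of range for dimension '{}' with size {}".format(
                    value, dim_name, dim_size
                )
            )


# Mirrors the example above: dimension 'a' has size 3 and the index is 6.
try:
    check_index_range("a", 3, [6])
except IndexError as err:
    message = str(err)
```

The trade-off is an extra pass over the index values on every call, which may matter for large index tensors.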