Support value promotion (implicit `Constant`)
This is a feature idea that comes up in alternative libraries. It would be useful to be able to express things like `op.add(x, 1)` or even `op.reshape(x, [-1, 2])` without the need for excess baggage like `op.const` or even `op.constant`.
This raises the following questions:
- opset-independence - should this be done with a `Constant` operator or something else?
- signatures - should we support `npt.ArrayLike` (everything `numpy` takes), `npt.NDArray`/`np.number` (explicit typing only), or maybe just Python types?
- opset generation - can we handle this in a cleaner way than just generating code that constructs the constant?
- typing - in a case like `op.add(x, 1)`, should `1` just become `int64`, or be coerced to whatever `x` is? Should `op.add(x, 1.0)` have `1.0` be `float32` or `float64` (related: Quantco/spox-foss#24)? It is technically possible to implement the coercion with schema info in most cases.
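To make the typing question concrete, here is a minimal sketch of the two behaviours using only NumPy. `promote_scalar` is a hypothetical helper (not part of Spox), and rejecting float-to-integer casts is just one possible policy chosen for illustration.

```python
import numpy as np


def promote_scalar(value, like_dtype=None):
    """Hypothetical helper: turn a Python/NumPy value into an array.

    With ``like_dtype=None`` NumPy picks the dtype (non-coercing variant).
    Otherwise the value is cast to ``like_dtype`` (coercing variant).
    """
    arr = np.asarray(value)
    if like_dtype is None:
        # Non-coercing: keep whatever dtype NumPy inferred.
        return arr
    target = np.dtype(like_dtype)
    # Assumed policy: never silently truncate a float into an integer type.
    if arr.dtype.kind == "f" and target.kind in "iu":
        raise TypeError(f"refusing to cast {value!r} to {target}")
    return arr.astype(target)


# Non-coercing: a Python float becomes float64, regardless of the other operand.
print(promote_scalar(1.0).dtype)
# Coercing: the constant follows the dtype of the other operand (here float32).
print(promote_scalar(1, np.float32).dtype)
```

In the non-coercing variant, `op.add(x, 1.0)` with a `float32` input `x` would produce mismatched input types, which is why the coercion question matters for usability.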
I think it's useful to think about these, as value promotion could be important for usability long-term.
A considerable fraction of this use case overlaps with operator overloading. It is also important to remember that this only applies to `Var` arguments. Attributes already provide this user experience.
An important difference with respect to operator overloading is that constructor functions of the `ai.onnx` namespace do have a known opset generation. In those cases, we could simply throw every input that is not a `Var` onto `op.const`.
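A minimal sketch of that idea follows. `Var` and `const` here are stand-ins defined locally so the example runs without Spox; the real `spox.Var` and `op.const` would take their place, and `promote` is a hypothetical helper name.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(frozen=True)
class Var:
    """Stand-in for spox.Var, only so this sketch runs without Spox."""

    name: str


def const(value: Any) -> Var:
    """Stand-in for op.const: pretend to build a Constant node."""
    return Var(f"const({value!r})")


def promote(*args: Any) -> Tuple[Var, ...]:
    """Pass Var inputs through unchanged and wrap everything else in a constant."""
    return tuple(a if isinstance(a, Var) else const(a) for a in args)


def add(a: Any, b: Any) -> Var:
    """Hypothetical constructor that promotes its inputs before building the node."""
    pa, pb = promote(a, b)
    return Var(f"Add({pa.name}, {pb.name})")


x = Var("x")
print(add(x, 1).name)
```

Because `promote` takes `*args`, the same helper would also cover variadic inputs, which is where the per-argument type variable approach runs into trouble.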
I would be happy with a non-coercing implementation of this for a start, if only a nice type hint were possible. `npt.ArrayLike` is a `Union` and would expand into its whole definition in a rather unwieldy way (similarly to how `npt.DTypeLike` does right now, for instance in `Tensor`).
The only workaround is using a type variable, for which names aren't expanded in signatures, but we would need one for every argument.
Maybe this sphinx feature fits the bill? It seems to be used in other projects for this very purpose. E.g. https://github.com/google/jax/blob/main/docs/conf.py#L293-L297
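Assuming the feature in question is Sphinx's `autodoc_type_aliases` option (that is my assumption, not something stated above), a `conf.py` fragment could look roughly like this. The alias names and targets are illustrative. Per the Sphinx documentation, the option only applies to modules that enable postponed evaluation of annotations via `from __future__ import annotations`.

```python
# docs/conf.py (fragment, illustrative)
extensions = ["sphinx.ext.autodoc"]

# Map alias names to fully-qualified targets so autodoc keeps the short
# name in rendered signatures instead of expanding the underlying Union.
autodoc_type_aliases = {
    "ArrayLike": "numpy.typing.ArrayLike",
    "DTypeLike": "numpy.typing.DTypeLike",
}
```

This would only address the rendered documentation; IDEs and type checkers would still show the expanded type.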
Indeed, I think we can fix this in Sphinx. It would also be nice to find something for IDEs/checkers, although PyCharm doesn't seem to have resolved this (https://youtrack.jetbrains.com/issue/PY-42486), and for mypy the issue seems to be stale, though there was a PR last year (https://github.com/python/mypy/issues/2968).
I would like to avoid impairing the user experience by making our signatures almost unreadable for users who don't have the right IDE.
The two main workarounds I had in mind don't (really) work - `NewType` doesn't work on `Union` types (since you can't subclass from them), and `TypeVar` would require having a separate one for every argument there (and it would also not work for variadics). I'll try and look into this further as it would also fix the existing issue of messy `dtype` attribute hints.
This would be a proof-of-concept:
from typing import Union
from typing_extensions import TypeAlias
import numpy as np
import numpy.typing as npt
from spox import argument, Var, Tensor
import spox.opset.ai.onnx.v17 as op
ArrayLike: TypeAlias = npt.ArrayLike
def add(
    A: Union[Var, ArrayLike],
    B: Union[Var, ArrayLike],
) -> Var:
    return op._Add(
        op._Add.Attributes(),
        op._Add.Inputs(
            A=A if isinstance(A, Var) else op.constant(value=np.array(A)),
            B=B if isinstance(B, Var) else op.constant(value=np.array(B)),
        ),
    ).outputs.C
if __name__ == '__main__':
    x, y = argument(Tensor(float, ('N',))), argument(Tensor(float, (None, None)))
    print(add(x, y))
    print(add(x, 1.))
    print(add(x, np.float64(1.)))
    # print(add(x, object()))  # raises TypeError
The issue I am talking about is that mypy prints this signature:
note: Revealed type is "Union[spox._var.Var, Union[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], builtins.bool, builtins.int, builtins.float, builtins.complex, builtins.str, builtins.bytes, numpy._typing._nested_sequence._NestedSequence[Union[builtins.bool, builtins.int, builtins.float, builtins.complex, builtins.str, builtins.bytes]]]]"
And PyCharm doesn't do much better.

The type variable workaround would almost work (not for variadics, though), with some slightly more involved generation:
from typing import TypeVar, Union
ArrayLike1 = TypeVar('ArrayLike1', bound=npt.ArrayLike)
ArrayLike2 = TypeVar('ArrayLike2', bound=npt.ArrayLike)
def add(
    A: Union[Var, ArrayLike1],
    B: Union[Var, ArrayLike2],
) -> Var:
    ...
After this, both mypy and PyCharm are forced to display only the type variable name (and the user would have to look up the definition of the bound, `npt.ArrayLike`, themselves).