mypy
"lax_numpy.float64" is not a valid type in python3.8
Bug Report
It doesn't seem possible to use jnp.float64 as a type hint on Python 3.8 (it works on Python 3.9 and 3.10).
See https://github.com/EmmanuelMess/ConstraintBasedSimulator/actions/runs/7805111999/job/21288560174
To Reproduce
First file that triggers the error (ConstraintFunctions.py):
```python
from typing import Callable, Tuple

import jax
import jax.numpy as jnp
from jax import jacfwd, grad, value_and_grad


class ConstraintFunctions:
    @staticmethod
    def computeDerivatives(
            constraintTime: Callable[[jnp.float64, jnp.ndarray, jnp.ndarray, jnp.ndarray, dict], jnp.float64]) \
            -> Tuple[
                Callable[[jnp.float64, jnp.ndarray, jnp.ndarray, jnp.ndarray, dict], jnp.float64],
                Callable[[jnp.float64, jnp.ndarray, jnp.ndarray, jnp.ndarray, dict], jnp.float64],
                Callable[[jnp.float64, jnp.ndarray, jnp.ndarray, jnp.ndarray, dict], jnp.float64]
            ]:
        """
        :param constraintTime: A function of time that can be derived on t=0 to obtain the constraints, the second
        parameter is getParticleMatrix() and the third is getArgs(). The function passed should be pure and precompiled
        """
        constraintAndDerivativeOfTime = jax.jit(value_and_grad(constraintTime, argnums=0))
        dConstraint = jax.jit(jacfwd(constraintTime, argnums=1))
        d2Constraint = jax.jit(jacfwd(grad(constraintTime, argnums=0), argnums=1))
        return constraintAndDerivativeOfTime, dConstraint, d2Constraint
```
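Every error in the output below points at an annotation that mentions jnp.float64, so a much smaller module may be enough to reproduce it. The sketch below assumes that the annotation alone is the trigger; the function name double is hypothetical. At runtime it behaves normally, because Python only evaluates the annotation and never checks it.

```python
import jax.numpy as jnp


# Hypothetical reduced reproduction: the only suspicious part is using
# jnp.float64 in an annotation, which is what mypy flagged as [valid-type]
# in the reported Python 3.8 run.
def double(x: jnp.float64) -> jnp.float64:
    # Runs fine: annotations are not enforced at runtime.
    return x * 2
```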
Expected Behavior
The same mypy check that passes on Python 3.9 should also pass on Python 3.8.
Actual Behavior
```
Run mypy --disallow-untyped-defs constraint_based_simulator tests
constraint_based_simulator/simulator/constraints/functions/ConstraintFunctions.py:11: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type [valid-type]
constraint_based_simulator/simulator/constraints/functions/ConstraintFunctions.py:11: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/functions/ConstraintFunctions.py:12: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type [valid-type]
constraint_based_simulator/simulator/constraints/functions/ConstraintFunctions.py:12: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/Constraint.py:17: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type [valid-type]
constraint_based_simulator/simulator/constraints/Constraint.py:17: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/Constraint.py:19: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type [valid-type]
constraint_based_simulator/simulator/constraints/Constraint.py:19: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/Constraint.py:20: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type [valid-type]
constraint_based_simulator/simulator/constraints/Constraint.py:20: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/Constraint.py:42: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type [valid-type]
constraint_based_simulator/simulator/constraints/Constraint.py:42: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/functions/DistanceConstraintFunctions.py:16: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type [valid-type]
constraint_based_simulator/simulator/constraints/functions/DistanceConstraintFunctions.py:16: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/functions/CircleConstraintFunctions.py:11: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type [valid-type]
constraint_based_simulator/simulator/constraints/functions/CircleConstraintFunctions.py:11: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/functions/CircleConstraintFunctions.py:16: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type [valid-type]
constraint_based_simulator/simulator/constraints/functions/CircleConstraintFunctions.py:16: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
Found 9 errors in 4 files (checked 82 source files)
```
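mypy's note points at its "variables vs type aliases" section: it sees jnp.float64 as a variable rather than a class, so it refuses to use it in an annotation. One possible workaround is to annotate with jax.Array, which is a class. This assumes a JAX version that provides jax.Array; recent releases do. Calling jnp.float64(x) produces a JAX array anyway, so jax.Array may also describe the runtime values more accurately. The alias name ConstraintFn is hypothetical, and this is only a sketch of the annotation change. The function body is unchanged from the file above.

```python
from typing import Callable, Tuple

import jax
import jax.numpy as jnp
from jax import grad, jacfwd, value_and_grad

# Hypothetical alias: jax.Array replaces jnp.float64, because mypy treats
# jnp.float64 as a variable, not a type (see the [valid-type] errors above).
ConstraintFn = Callable[[jax.Array, jax.Array, jax.Array, jax.Array, dict], jax.Array]


class ConstraintFunctions:
    @staticmethod
    def computeDerivatives(
            constraintTime: ConstraintFn) -> Tuple[ConstraintFn, ConstraintFn, ConstraintFn]:
        # Value of the constraint and its derivative with respect to time (argument 0).
        constraintAndDerivativeOfTime = jax.jit(value_and_grad(constraintTime, argnums=0))
        # Jacobian of the constraint with respect to argument 1.
        dConstraint = jax.jit(jacfwd(constraintTime, argnums=1))
        # Jacobian (w.r.t. argument 1) of the time derivative.
        d2Constraint = jax.jit(jacfwd(grad(constraintTime, argnums=0), argnums=1))
        return constraintAndDerivativeOfTime, dConstraint, d2Constraint
```

Only the annotations differ from the original, so the runtime behaviour of computeDerivatives should stay the same.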
Your Environment
- Mypy version used: 1.8.0
- Mypy command-line flags: --disallow-untyped-defs constraint_based_simulator tests
- Mypy configuration options from mypy.ini (and other config files):
- Python version used: 3.8
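The environment list does not include the JAX version. JAX releases drop support for older Python versions over time, so the Python 3.8 and 3.9 jobs may well have installed different JAX versions, and recording it could help narrow the problem down. The helper below is a hypothetical, stdlib-only sketch (collect_versions is not part of mypy or JAX). It relies on importlib.metadata, which is available from Python 3.8.

```python
import platform
from importlib import metadata
from typing import Dict, Iterable, Optional


def collect_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the interpreter version plus each package's installed version (None if missing)."""
    versions: Dict[str, Optional[str]] = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Prints lines in the same "- name: value" style as the environment list above.
    for name, version in collect_versions(["mypy", "jax", "jaxlib"]).items():
        print(f"- {name}: {version or 'not installed'}")
```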