
"lax_numpy.float64" is not a valid type in python3.8

Open EmmanuelMess opened this issue 1 year ago • 0 comments

Bug Report

It doesn't seem possible to use jnp.float64 as a type hint on Python 3.8 (it works on Python 3.9 and 3.10).

See https://github.com/EmmanuelMess/ConstraintBasedSimulator/actions/runs/7805111999/job/21288560174

To Reproduce

First file that throws the error:

from typing import Callable, Tuple

import jax
import jax.numpy as jnp
from jax import jacfwd, grad, value_and_grad


class ConstraintFunctions:
    @staticmethod
    def computeDerivatives(
            constraintTime: Callable[[jnp.float64, jnp.ndarray, jnp.ndarray, jnp.ndarray, dict], jnp.float64]) \
            -> Tuple[
                Callable[[jnp.float64, jnp.ndarray, jnp.ndarray, jnp.ndarray, dict], jnp.float64],
                Callable[[jnp.float64, jnp.ndarray, jnp.ndarray, jnp.ndarray, dict], jnp.float64],
                Callable[[jnp.float64, jnp.ndarray, jnp.ndarray, jnp.ndarray, dict], jnp.float64]
            ]:
        """
        :param constraintTime: A function of time that can be differentiated at t=0 to obtain the constraints; the second
        parameter is getParticleMatrix() and the third is getArgs(). The function passed should be pure and precompiled
        """
        constraintAndDerivativeOfTime = jax.jit(value_and_grad(constraintTime, argnums=0))
        dConstraint = jax.jit(jacfwd(constraintTime, argnums=1))
        d2Constraint = jax.jit(jacfwd(grad(constraintTime, argnums=0), argnums=1))
        return constraintAndDerivativeOfTime, dConstraint, d2Constraint

Expected Behavior

The same check that passes on Python 3.9 should also pass on Python 3.8.

Actual Behavior

Run mypy --disallow-untyped-defs constraint_based_simulator tests
constraint_based_simulator/simulator/constraints/functions/ConstraintFunctions.py:11: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type  [valid-type]
constraint_based_simulator/simulator/constraints/functions/ConstraintFunctions.py:11: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/functions/ConstraintFunctions.py:12: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type  [valid-type]
constraint_based_simulator/simulator/constraints/functions/ConstraintFunctions.py:12: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/Constraint.py:17: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type  [valid-type]
constraint_based_simulator/simulator/constraints/Constraint.py:17: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/Constraint.py:19: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type  [valid-type]
constraint_based_simulator/simulator/constraints/Constraint.py:19: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/Constraint.py:20: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type  [valid-type]
constraint_based_simulator/simulator/constraints/Constraint.py:20: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/Constraint.py:42: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type  [valid-type]
constraint_based_simulator/simulator/constraints/Constraint.py:42: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/functions/DistanceConstraintFunctions.py:16: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type  [valid-type]
constraint_based_simulator/simulator/constraints/functions/DistanceConstraintFunctions.py:16: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/functions/CircleConstraintFunctions.py:11: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type  [valid-type]
constraint_based_simulator/simulator/constraints/functions/CircleConstraintFunctions.py:11: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
constraint_based_simulator/simulator/constraints/functions/CircleConstraintFunctions.py:16: error: Variable "jax._src.numpy.lax_numpy.float64" is not valid as a type  [valid-type]
constraint_based_simulator/simulator/constraints/functions/CircleConstraintFunctions.py:16: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
Found 9 errors in 4 files (checked 82 source files)
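
The note in the output above points at mypy's distinction between variables and type aliases. Here is a minimal sketch of that distinction that doesn't depend on JAX. It assumes (I have not confirmed this against the JAX release that gets installed on Python 3.8) that float64 there is a module-level variable assigned from a function call rather than a class definition. make_scalar_type, halve, Scalar and double are hypothetical names used only for illustration.

```python
from typing import Union


def make_scalar_type(name: str) -> type:
    """Hypothetical stand-in for a factory that builds a scalar class at runtime."""
    return type(name, (), {})


# A module-level variable whose value comes from a function call.
# mypy treats this as an ordinary variable, not as a type alias.
float64 = make_scalar_type("float64")


# This runs fine, but mypy is expected to report [valid-type] on the
# annotations, the same error code as in the output above.
def halve(x: float64) -> float64:
    return x


# An explicit alias built from real types is accepted by mypy.
Scalar = Union[int, float]


def double(x: Scalar) -> Scalar:
    return x * 2
```

If that assumption holds, an explicit alias like Scalar could serve as a workaround in the annotations, at the cost of no longer naming jnp.float64 directly.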

Your Environment

  • Mypy version used: 1.8.0
  • Mypy command-line flags: --disallow-untyped-defs constraint_based_simulator tests
  • Mypy configuration options from mypy.ini (and other config files):
  • Python version used: 3.8

EmmanuelMess, Feb 06 '24 19:02