CLMM
CLMM copied to clipboard
Improve validation tests for numpy.ndarray
When a variable is a 0-d numpy.ndarray (usually produced by something like np.array(2)), the checks in validate_argument fail, because a 0-d array is not iterable yet is not a plain int/float either.
This could be fixed by improving:
def _is_valid(arg, valid_type):
if valid_type == "function":
return callable(arg)
if (valid_type in ("int_array", "float_array") and np.iterable(arg)):
return isinstance(arg[0], _valid_types[valid_type])
if isinstance(arg, np.ndarray):
if (valid_type in (int, "int_array")):
return arg.dtype.char in np.typecodes['AllInteger']
if (valid_type in (float, "float_array")):
return arg.dtype.char in np.typecodes['AllFloat']
return False
return isinstance(arg, _valid_types.get(valid_type, valid_type))
Will isinstance(obj, collections.abc.Iterable) help? It behaves differently from np.iterable(obj) for the 0-d array case. See https://numpy.org/doc/stable/reference/generated/numpy.iterable.html.
@hsinfan1996, thanks for the suggestion. But I think there is an even simpler solution:
def _is_valid(arg, valid_type):
if valid_type == "function":
return callable(arg)
if valid_type == "int_array":
return np.array(arg).dtype.char in np.typecodes['AllInteger']
if valid_type == "float_array":
return np.array(arg).dtype.char in np.typecodes['AllFloat']+np.typecodes['AllInteger']
return isinstance(arg, _valid_types.get(valid_type, valid_type))