cobaya
Enhance type checking
The validate_info method of CobayaComponent only checks bool annotations.
This PR extends it to check every relevant type, including generic types (List[...], Dict[...], Tuple[...], etc.).
The new code raises a TypeError when max_samples="bad_value"; however, test_mcmc.py (MPI case) still fails as if the error were not being caught.
Do you have any idea why this could be happening?
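For reference, the intended behaviour (recursing into generic annotations and raising a TypeError for max_samples="bad_value") can be sketched in standalone Python with typing.get_origin and typing.get_args. This is a minimal sketch, not cobaya's implementation. matches and ToyComponent are hypothetical names, only list, dict, tuple and Union are unpacked, and the sketch does not address the MPI test failure.

```python
from typing import Any, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints


def matches(expected: Any, value: Any) -> bool:
    """Recursively check a value against a (possibly generic) type hint."""
    if expected is Any:
        return True
    origin = get_origin(expected)
    if origin is None:  # plain class such as int or str
        return isinstance(value, expected)
    args = get_args(expected)
    if origin is Union:  # includes Optional[X], i.e. Union[X, None]
        return any(matches(t, value) for t in args)
    if not isinstance(value, origin):
        return False
    if not args:
        return True
    if origin is list:
        return all(matches(args[0], v) for v in value)
    if origin is dict:
        return all(matches(args[0], k) and matches(args[1], v) for k, v in value.items())
    if origin is tuple:
        return len(args) == len(value) and all(matches(t, v) for t, v in zip(args, value))
    return True  # other generics: only the container type is checked


class ToyComponent:
    """Hypothetical stand-in for a CobayaComponent subclass."""

    max_samples: int = 100
    seed: Optional[int] = None
    output_params: List[str]
    limits: Dict[str, Tuple[float, float]]

    def __init__(self, **info):
        hints = get_type_hints(type(self))
        for name, value in info.items():
            if name in hints and not matches(hints[name], value):
                raise TypeError(f"Attribute '{name}' must be of type {hints[name]}, "
                                f"not {type(value)} (value={value!r})")
            setattr(self, name, value)
```

With this sketch, ToyComponent(max_samples="bad_value") raises a TypeError, while ToyComponent(limits={"x": (0.0, 1.0)}) is accepted.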
Codecov Report
Attention: Patch coverage is 26.31579% with 42 lines in your changes missing coverage. Please review.
Project coverage is 74.25%. Comparing base (735f7a8) to head (19e4143).
| Files with missing lines | Patch % | Lines |
|---|---|---|
| cobaya/component.py | 22.64% | 41 Missing |
| cobaya/samplers/polychord/polychord.py | 0.00% | 1 Missing |
```diff
@@            Coverage Diff             @@
##           master     #382      +/-   ##
==========================================
- Coverage   74.57%   74.25%   -0.33%
==========================================
  Files         147      147
  Lines       11200    11247      +47
==========================================
- Hits         8352     8351       -1
- Misses       2848     2896      +48
```
I left the validation of bools as it was and added an enforce_types attribute that triggers the new code.
This way, one can force type checking by setting enforce_types=True in any class descending from CobayaComponent, without touching the old validation.
I am not sure how best to test the robustness of this. One idea would be to switch the default of _enforce_types to True, let the tests here run and, if they pass, set it back to False. That way we at least know that everything internal works as expected. What do you think?
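The opt-in design above can be sketched as a class-level flag that the base class consults before running the stricter check. This is a minimal sketch under assumptions: Component, Loose and Strict are hypothetical classes, and the check only covers plain (non-generic) annotations, unlike the PR code.

```python
from typing import get_type_hints


class Component:
    """Hypothetical base class: strict type checking is opt-in via a class attribute."""

    _enforce_types: bool = False  # subclasses set this to True to opt in

    def __init__(self, **info):
        for name, value in info.items():
            setattr(self, name, value)
        if self._enforce_types:
            self.validate_attributes()

    def validate_attributes(self):
        # Only plain class annotations are checked in this sketch; private names are skipped
        for name, expected in get_type_hints(type(self)).items():
            if name.startswith("_") or not isinstance(expected, type):
                continue
            value = getattr(self, name, None)
            if value is not None and not isinstance(value, expected):
                raise TypeError(f"Attribute '{name}' must be of type {expected.__name__}, "
                                f"not {type(value).__name__}")


class Loose(Component):
    max_samples: int = 100


class Strict(Component):
    _enforce_types = True
    max_samples: int = 100
```

Because Loose inherits the flag, temporarily setting Component._enforce_types = True turns the check on for every subclass that does not override it. That is what makes the "run the whole test suite once with it enabled" idea cheap to try.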
The best I can come up with that works with empty_dict, Sequence, Tuple[float] and TypedDicts, and allows numpy arrays for Sequence[float] and Tuple[float], is something like this:
```python
from collections.abc import Iterable, Mapping, Sequence
from numbers import Integral, Real
from typing import Any, ClassVar, Union, get_type_hints, is_typeddict  # is_typeddict: Python >= 3.10

import numpy as np

# Methods of CobayaComponent. NumberWithUnits is a cobaya type whose import is not shown here.


def validate_info(self, name: str, value: Any, annotations: dict):
    if name in annotations:
        expected_type = annotations[name]
        if not self._validate_type(expected_type, value):
            msg = (f"Attribute '{name}' must be of type {expected_type}, "
                   f"not {type(value)} (value={value})")
            raise TypeError(msg)


def _validate_composite_type(self, expected_type, value):
    origin = expected_type.__origin__
    # Bare aliases such as Callable or Sequence have no __args__
    args = getattr(expected_type, "__args__", ())
    if origin is Union:  # also covers Optional[X], which is Union[X, None]
        return any(self._validate_type(t, value) for t in args)
    if origin is ClassVar:  # must come before issubclass(): ClassVar is not a class
        return self._validate_type(args[0], value)
    if not isinstance(origin, type):
        return True  # other special forms are not checked
    if not args:
        return isinstance(value, origin)
    if issubclass(origin, Sequence):
        if len(args) == 1 or (len(args) == 2 and args[1] is Ellipsis):
            # Homogeneous, e.g. Sequence[float], Tuple[float] or Tuple[float, ...];
            # any iterable (such as a numpy array) is accepted
            return isinstance(value, Iterable) and all(
                self._validate_type(args[0], item) for item in value)
        # Fixed-length tuple, e.g. Tuple[int, str]
        return isinstance(value, Sequence) and len(args) == len(value) and all(
            self._validate_type(t, v) for t, v in zip(args, value))
    if issubclass(origin, Mapping) and len(args) == 2:
        return isinstance(value, Mapping) and all(
            self._validate_type(args[0], k) and self._validate_type(args[1], v)
            for k, v in value.items())
    return isinstance(value, origin)


def _validate_type(self, expected_type, value):
    if value is None or expected_type is Any:  # unset values and Any are always valid
        return True
    if hasattr(expected_type, "__origin__"):
        return self._validate_composite_type(expected_type, value)
    # Exceptions for some types
    if is_typeddict(expected_type):
        type_hints = get_type_hints(expected_type)
        if not isinstance(value, Mapping) or not set(value).issubset(type_hints):
            return False
        for key, item in value.items():  # raises TypeError on the first invalid entry
            self.validate_info(key, item, type_hints)
        return True
    if expected_type is int:
        # inf is also accepted as an int; the Real check avoids comparing arrays to inf
        return isinstance(value, Integral) or (isinstance(value, Real) and value == float("inf"))
    if expected_type is float:
        # 0-d numpy arrays are accepted as floats
        return isinstance(value, Real) or (isinstance(value, np.ndarray) and value.ndim == 0)
    if expected_type is NumberWithUnits:
        return isinstance(value, (Real, str))
    return isinstance(value, expected_type)


def validate_attributes(self):
    annotations = self.get_annotations()
    for name in annotations:
        self.validate_info(name, getattr(self, name, None), annotations)
```
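The branch order in the code above depends on how typing represents these annotations. The standalone check below (standard library only, assuming Python 3.9+) shows the relevant attributes. It is illustrative and independent of cobaya.

```python
import collections.abc
from typing import ClassVar, Dict, Optional, Sequence, Tuple, Union, get_args, get_origin

# Generic aliases expose an origin and their type arguments
assert Sequence[float].__origin__ is collections.abc.Sequence
assert Sequence[float].__args__ == (float,)
assert Tuple[float].__args__ == (float,)                # a single argument
assert Tuple[float, ...].__args__ == (float, Ellipsis)  # variable-length tuple
assert Dict[str, int].__origin__ is dict

# Optional[X] is Union[X, None]: its origin is Union, so an `origin is Optional` test never matches
assert get_origin(Optional[int]) is Union
assert get_args(Optional[int]) == (int, type(None))

# Bare aliases have an origin but no type arguments
assert Sequence.__origin__ is collections.abc.Sequence
assert get_args(Sequence) == ()

# ClassVar[X] has origin ClassVar, which is not a class, so issubclass(origin, ...) raises
assert ClassVar[int].__origin__ is ClassVar
assert not isinstance(ClassVar, type)
try:
    issubclass(ClassVar, collections.abc.Sequence)
    raise AssertionError("expected TypeError")
except TypeError:
    pass
```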
However, is_typeddict is only available in the standard typing module from Python 3.10 onwards.
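For older interpreters, the third-party typing_extensions package also exports is_typeddict. If adding a dependency is not wanted, one option is a version-guarded fallback like the one below. This is a sketch: the fallback is a hypothetical duck-typed check (TypedDict classes are built with dict as their base and carry a __total__ attribute), not the standard-library implementation.

```python
import sys
from typing import Any, TypedDict

if sys.version_info >= (3, 10):
    from typing import is_typeddict
else:
    def is_typeddict(tp: Any) -> bool:
        """Hypothetical fallback for Python < 3.10: duck-typed TypedDict check."""
        # TypedDict classes are created with dict in their MRO and a __total__ attribute
        return isinstance(tp, type) and dict in tp.__mro__ and hasattr(tp, "__total__")


class ExampleOptions(TypedDict):  # hypothetical TypedDict used for illustration
    name: str
    accuracy: float
```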
My attempt to generalize and refactor this a bit is now in https://github.com/CobayaSampler/cobaya/pull/388. @ggalloni, did you have a SOLikeT build to test against? Did I miss anything?
Hello @cmbant, thanks for your help with this! Yes, I was using SOLikeT/#192 to test this, so pointing it to the #388 branch should be sufficient. That should also tell us whether anything is missing, since it was passing all tests with #382.
OK, great; let me know if there are any problems. I also just pushed a change that should make it work with deferred types too.
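Some background on deferred types: with postponed evaluation (from __future__ import annotations) or quoted forward references, annotations are stored as strings, so isinstance-based checks need them resolved first. Below is a minimal sketch, assuming resolution through typing.get_type_hints. This is not necessarily how #388 handles it, and Deferred and resolved_annotations are hypothetical names.

```python
from typing import Dict, List, get_type_hints


class Deferred:
    """Hypothetical component with deferred (string) annotations.

    `from __future__ import annotations` would store every annotation as a string like these.
    """

    max_samples: "int"
    params: "Dict[str, List[float]]"


def resolved_annotations(cls: type) -> dict:
    # cls.__annotations__ holds the raw strings; get_type_hints evaluates them in the
    # defining module's namespace and also merges annotations from base classes
    return get_type_hints(cls)
```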
Currently, all non-Windows builds are failing because CCL does not build correctly... Still, Windows is passing all tests, which is reassuring.
Except that you don't have _enforce_types=True, only enforce_types...
I fixed that (I thought I already had...) and am now getting an error when handling ClassVar.
This seems to happen because ClassVar is only dealt with when origin and args are defined for expected_type.
Instead, some checks skip that part entirely and fail at line 248 of typing.py when trying to execute:
```python
isinstance(value, typing.ClassVar)
```
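One way to reproduce a failure of this kind without cobaya: a bare ClassVar annotation has neither __origin__ nor __args__, so it falls through to a plain isinstance() call, which typing rejects with a TypeError. The case reported above may differ. A possible workaround, sketched here with hypothetical helper names, is to unwrap ClassVar before any isinstance() check.

```python
from typing import Any, ClassVar, get_args, get_origin


def unwrap_classvar(expected_type: Any) -> Any:
    """Hypothetical helper: strip a ClassVar wrapper before any isinstance() check.

    ClassVar[int] -> int; a bare ClassVar (no type argument) -> Any.
    """
    if expected_type is ClassVar:
        # Bare ClassVar has no __origin__/__args__, and isinstance(x, ClassVar) raises TypeError
        return Any
    if get_origin(expected_type) is ClassVar:
        args = get_args(expected_type)
        return args[0] if args else Any
    return expected_type


def is_valid_simple(expected_type: Any, value: Any) -> bool:
    """Hypothetical check for non-generic annotations, possibly wrapped in ClassVar."""
    expected_type = unwrap_classvar(expected_type)
    return expected_type is Any or isinstance(value, expected_type)
```

Mapping a bare ClassVar to Any simply skips the check when no type argument is given. A stricter variant could reject such annotations instead.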
Can you give a specific example?
I made a fix; it looks like it's running OK on Windows.
I merged, thanks!