import hook can poison the `pyc` cache
Hi,
As I understand it, the import hook's `modules` parameter is used to avoid adding the hook to transitive imports (as in the scipy example here), and might also work to filter submodules (as in the `bar.baz` example here). But I can't seem to get it to skip some files which I am not ready to typecheck yet. Here's a small example with 3 files:
- at the root: `go.py`
- `package_a/left.py` and `package_a/right.py`

These contain:
`go.py`:
```python
from jaxtyping import install_import_hook

with install_import_hook("package_a.left", "typeguard.typechecked"):
    from package_a.left import well_typed_function

print(well_typed_function(10))
```
`package_a/left.py`:
```python
from package_a.right import badly_typed_function
import inspect

def well_typed_function(a: int) -> int:
    x = badly_typed_function(a)
    print(f"badly_typed_function is in {inspect.getmodule(badly_typed_function)}")
    return x + 1
```
`package_a/right.py`:
```python
from typing import List

# This is purposefully incorrect so that runtime type checking will complain.
# It actually returns an int.
def badly_typed_function(a: int) -> List:
    return 2 * a
```
When I run this I get:
```
jaxtyping.TypeCheckError: Type-check error whilst checking the return value of package_a.right.badly_typed_function.
Actual value: 20
Expected type: typing.List.
```
Turning off the hook gives me:
```
badly_typed_function is in <module 'package_a.right' from '/home/richard/deep-affinity/tmp/package_a/right.py'>
21
```
So both jaxtyping and `inspect` seem to agree that the function is in `package_a.right`, but we are only asking for the import hook on `package_a.left`. I'm not sure what's happening here. Is this intended (maybe they aren't actually modules because there is no `__init__.py`)? Is there a way to get the prefix check to work?
I did a little more digging and monkey-patched `find_spec` to get debug prints:
```python
import jaxtyping._import_hook

def make_debug_print(f):
    def _new_f(self, fullname, path=None, target=None):
        rval = f(self, fullname, path, target)
        print(f"fullname is {fullname}, returning {rval}")
        return rval  # pass the spec through so the hook keeps working
    return _new_f

jaxtyping._import_hook._JaxtypingFinder.find_spec = make_debug_print(
    jaxtyping._import_hook._JaxtypingFinder.find_spec
)
```
And I see:
```
fullname is package_a, returning None
fullname is package_a.left, returning ModuleSpec(name='package_a.left', loader=<jaxtyping._import_hook._JaxtypingLoader object at 0x7965d991fe20>, origin='/home/richard/deep-affinity/tmp/package_a/left.py')
fullname is package_a.right, returning None
```
So once we visit the `package_a.right` module, `find_spec` is indeed correctly returning `None`, and I am not sure why that function is getting annotated.
This seems strange. Indeed, as I think you've already located, the logic for which modules to typecheck is here:
https://github.com/patrick-kidger/jaxtyping/blob/fe61644be8590cf0a89bacc278283e1e4b9ea3e4/jaxtyping/_import_hook.py#L265
This isn't behaviour we've seen issues with before; if you're able to identify what's going on then I'd be curious to know.
A little progress: I added this to the top of `right.py`:
```python
import inspect
import sys

print("=== LOADING package_a.right ===")
print(f"Current frame: {inspect.currentframe()}")
print("Stack trace:")
for frame_info in inspect.stack():
    print(f" {frame_info.filename}:{frame_info.lineno} in {frame_info.function}")

# Check how this module is being loaded
module = sys.modules.get(__name__)
if module:
    print(f"Module spec: {getattr(module, '__spec__', 'No __spec__')}")
    print(f"Module loader: {getattr(module, '__loader__', 'No __loader__')}")
    print(f"Module file: {getattr(module, '__file__', 'No __file__')}")
print("=== END DEBUG INFO ===")
```
and the output is:
```
=== LOADING package_a.right ===
Current frame: <frame at 0x759e34daed90, file '/home/richard/jt_example/package_a/right.py', line 5, code <module>>
Stack trace:
/home/richard/jt_example/package_a/right.py:7 in <module>
<frozen importlib._bootstrap>:241 in _call_with_frames_removed
<frozen importlib._bootstrap_external>:883 in exec_module
<frozen importlib._bootstrap>:688 in _load_unlocked
<frozen importlib._bootstrap>:1006 in _find_and_load_unlocked
<frozen importlib._bootstrap>:1027 in _find_and_load
/home/richard/jt_example/package_a/left.py:1 in <module>
<frozen importlib._bootstrap>:241 in _call_with_frames_removed
<frozen importlib._bootstrap_external>:883 in exec_module
/home/richard/.conda/envs/myenv/lib/python3.10/site-packages/jaxtyping/_import_hook.py:230 in exec_module
<frozen importlib._bootstrap>:688 in _load_unlocked
<frozen importlib._bootstrap>:1006 in _find_and_load_unlocked
<frozen importlib._bootstrap>:1027 in _find_and_load
/home/richard/jt_example/go.py:18 in <module>
Module spec: ModuleSpec(name='package_a.right', loader=<_frozen_importlib_external.SourceFileLoader object at 0x759e34a3b400>, origin='/home/richard/jt_example/package_a/right.py')
Module loader: <_frozen_importlib_external.SourceFileLoader object at 0x759e34a3b400>
Module file: /home/richard/jt_example/package_a/right.py
=== END DEBUG INFO ===
```
It seems that we are using the vanilla `importlib` `SourceFileLoader` here (good), but jaxtyping's `exec_module` from `_import_hook.py` is still on the stack. That kind of makes sense, because the patch here:
https://github.com/patrick-kidger/jaxtyping/blob/fe61644be8590cf0a89bacc278283e1e4b9ea3e4/jaxtyping/_import_hook.py#L223
is in effect for the whole time we are executing the module `package_a.left`, including during the import of `package_a.right`. That means that if I ended up with a right.....pyc written with the hook on first, it's going to stay in the cache.
Here's a more detailed repro to show you what I mean:
(1) If you run my initial bug report code after `rm -rf package_a/__pycache__`, then jaxtyping in fact behaves correctly (no error)!
(2) Then we do `rm -rf package_a/__pycache__` again and change `go.py` to add typechecking everywhere:
```python
with install_import_hook("package_a", "typeguard.typechecked"):  # was install_import_hook("package_a.left")
```
then we get an error (again, as expected!):
```
TypeError: type of the return value must be a list; got int instead
```
(3) Now, without clearing the pycache, if we change `go.py` back to how it was (`install_import_hook("package_a.left")`) and run, we still get the cached `right.py` and the `TypeError` above (unexpected behavior here!). If we look at `__pycache__` we see:
```
~/jt_example$ ls package_a/__pycache__
left.cpython-310.opt-jaxtyping918b7c48c88d742569d512b492c160aa9.pyc
right.cpython-310.opt-jaxtyping918b7c48c88d742569d512b492c160aa9.pyc
```
We are using the pyc with jaxtyping in its name, left over from step (2) when it was correctly generated. That makes some sense, since the file (`right.py`) hasn't changed.
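For reference, the `opt-jaxtyping…` part of those filenames is an `optimization` tag of the kind `importlib.util.cache_from_source` appends. A small stdlib-only sketch (the path and the tag `jaxtypingdemo` are made up for illustration) shows that the cache name depends only on the source path and the tag, not on which loader asked for it:

```python
import importlib.util
import sys

# A made-up source path: cache_from_source only computes a name, it never touches disk.
source = "/tmp/jt_example/package_a/right.py"

# Default name, e.g. /tmp/jt_example/package_a/__pycache__/right.cpython-310.pyc (no -O flag)
plain = importlib.util.cache_from_source(source)
# Tagged name, e.g. /tmp/jt_example/package_a/__pycache__/right.cpython-310.opt-jaxtypingdemo.pyc
tagged = importlib.util.cache_from_source(source, optimization="jaxtypingdemo")

print(sys.implementation.cache_tag)
print(plain)
print(tagged)
```

So any loader that computes its cache path while a tagged variant is in effect will read and write the same tagged file.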
To summarize: I think the issue is that `_JaxtypingLoader.exec_module()` monkey-patches `importlib._bootstrap_external.cache_from_source` globally during execution. That affects every module compiled during that window, not just the module being loaded by the `_JaxtypingLoader`, which can lead to loading from the wrong pyc file.
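A toy model of that leak, as a sketch only: `bootstrap`, `plain_exec` and `hooked_exec` are hypothetical stand-ins, not jaxtyping's or importlib's code. The point is that a patch applied for the whole of one module's execution is also seen by a vanilla loader running a nested import.

```python
import types
from unittest import mock

# Toy stand-in for importlib._bootstrap_external: loaders look the naming
# function up on this namespace at call time, so patching it affects everyone.
bootstrap = types.SimpleNamespace(
    cache_from_source=lambda path: path.replace(".py", ".pyc")
)

def tagged_name(path):
    return path.replace(".py", ".opt-jaxtyping.pyc")

def plain_exec(path, imports, seen):
    """A vanilla loader: record the cache file it would use, then run nested imports."""
    seen[path] = bootstrap.cache_from_source(path)
    for child in imports.get(path, []):
        plain_exec(child, imports, seen)

def hooked_exec(path, imports, seen):
    """Mimics the hook: the patch stays active for the *whole* module execution."""
    with mock.patch.object(bootstrap, "cache_from_source", tagged_name):
        plain_exec(path, imports, seen)

imports = {"left.py": ["right.py"]}  # left.py imports right.py
seen = {}
hooked_exec("left.py", imports, seen)  # only left.py is "hooked"
print(seen)
# {'left.py': 'left.opt-jaxtyping.pyc', 'right.py': 'right.opt-jaxtyping.pyc'}
```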
I'm not totally sure how to fix this; maybe a second Loader one level down the stack that restores the original `cache_from_source`?
Ah excellent, good sleuthing! The chain of events makes sense here.
Perhaps we could adjust the monkey-patched _optimized_cache_from_source to additionally consume a list of valid paths that it should apply to, and just use the default behaviour if the provided path isn't used?
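A minimal sketch of that idea, assuming the hook can collect the set of source paths it actually transformed. `make_scoped_cache_from_source` and `instrumented_paths` are hypothetical names, not jaxtyping's API, and the real `_optimized_cache_from_source` may take different arguments.

```python
import importlib.util
import os

def make_scoped_cache_from_source(instrumented_paths, tag):
    """Build a cache_from_source replacement that only tags allow-listed files.

    instrumented_paths: source files the hook actually transformed (hypothetical).
    tag: alphanumeric optimization tag, e.g. "jaxtyping" followed by a hash.
    """
    # Keep a reference to the stdlib function so the replacement can delegate to it.
    original = importlib.util.cache_from_source
    allowed = {os.path.abspath(p) for p in instrumented_paths}

    def scoped_cache_from_source(path, debug_override=None, *, optimization=None):
        if os.path.abspath(path) in allowed:
            return original(path, optimization=tag)
        # Anything else, e.g. package_a/right.py imported *from* left.py,
        # keeps the default .pyc name.
        return original(path, debug_override, optimization=optimization)

    return scoped_cache_from_source

scoped = make_scoped_cache_from_source(["/tmp/jt_example/package_a/left.py"], "jaxtypingdemo")
print(scoped("/tmp/jt_example/package_a/left.py"))   # tagged .opt-jaxtypingdemo.pyc name
print(scoped("/tmp/jt_example/package_a/right.py"))  # same as the default name
```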
> Ah excellent, good sleuthing! The chain of events makes sense here.

Thanks!

> Perhaps we could adjust the monkey-patched `_optimized_cache_from_source` to additionally consume a list of valid paths that it should apply to, and just use the default behaviour if the provided path isn't used?
That could work, but I am not sure where to get the paths. I don't think you can go straight from path to module name because of `__init__.py` shenanigans. We have a `path` in `find_spec`, which we could pass to the `_JaxtypingLoader` constructor, but that path can be `None`.
Just brainstorming here, another possibility is something like this:
- We could run `_JaxtypingFinder` everywhere, even on modules we don't intend to instrument. So `find_spec` never returns `None`, but instead passes the 'should instrument here' state into the `_JaxtypingLoader` it makes.
- `_JaxtypingLoader` then needs two changes: (a) `_source_to_code` becomes a no-op (doesn't call `JaxtypingTransformer`) when we don't need to be instrumenting, and (b) `exec_module` patches the original `cache_from_source` back in if we don't need to be instrumenting. So the `patch` context managers sort of form a stack that mirrors the module execution stack, if that makes sense.
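A toy sketch of part (b), assuming every module goes through a loader that knows whether it should instrument. `bootstrap`, `exec_module` and `should_instrument` here are hypothetical stand-ins (part (a), skipping the transformer, is not modelled). Each loader installs the naming that matches its own decision, so the nested `patch` contexts form a stack and unwind in import order.

```python
import types
from unittest import mock

def default_name(path):
    return path.replace(".py", ".pyc")

def tagged_name(path):
    return path.replace(".py", ".opt-jaxtyping.pyc")

# Toy stand-in for importlib._bootstrap_external (not the real module).
bootstrap = types.SimpleNamespace(cache_from_source=default_name)

def exec_module(path, imports, should_instrument, seen):
    """Toy loader that runs for *every* module and installs the naming that
    matches its own decision, so nested patches form a stack."""
    naming = tagged_name if should_instrument(path) else default_name
    with mock.patch.object(bootstrap, "cache_from_source", naming):
        seen[path] = bootstrap.cache_from_source(path)
        for child in imports.get(path, []):
            exec_module(child, imports, should_instrument, seen)
    # Leaving the `with` restores whatever the enclosing module had installed.

imports = {"left.py": ["right.py"], "right.py": ["util.py"]}
seen = {}
exec_module("left.py", imports, lambda p: p in {"left.py", "util.py"}, seen)
print(seen)
# {'left.py': 'left.opt-jaxtyping.pyc', 'right.py': 'right.pyc', 'util.py': 'util.opt-jaxtyping.pyc'}
```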
In the meantime, a workaround (which is perhaps a little dangerous, and only works if you are sure about where these pyc files are going to live) is to add this before the hook:
```python
from pathlib import Path

for f in Path(".").rglob("*jaxtyping*.pyc"):
    f.unlink()
```
Actually, a perhaps cleaner/safer one is:
```python
import sys
sys.dont_write_bytecode = True
```
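If turning off bytecode writing for the whole process feels too broad, a variant can limit it to the hooked import. This is a sketch; `no_bytecode_writes` is a made-up helper. As far as I know, `sys.dont_write_bytecode` only stops `.pyc` files from being written, and an already-poisoned `.pyc` on disk can still be read, so it would still need a one-off cleanup like the `rglob` loop above.

```python
import contextlib
import sys

@contextlib.contextmanager
def no_bytecode_writes():
    """Temporarily stop Python from writing .pyc files, then restore the old setting."""
    previous = sys.dont_write_bytecode
    sys.dont_write_bytecode = True
    try:
        yield
    finally:
        sys.dont_write_bytecode = previous

# Usage, with the same hook call as in go.py:
# with no_bytecode_writes(), install_import_hook("package_a.left", "typeguard.typechecked"):
#     from package_a.left import well_typed_function
```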
> That could work, but I am not sure where to get the paths
From some quick testing, I think this is available as just `module.__path__` when inside `_JaxtypingLoader.exec_module`.
I think there might still be some gotchas in there, like `.../foolib` -> `.../foolib/__init__.py` (maybe the logic should be around prefix-matching?), and also `__path__` not always being available (which perhaps we can just skip? I've not checked too carefully when this might be).
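To make the `.../foolib` vs `.../foolib/__init__.py` gotcha concrete, here is a stdlib-only sketch with a throwaway package (`foolib` and `sub` are illustrative names). For a package, `__path__` holds the directory while `__spec__.origin` is the file that actually gets compiled; plain modules such as `package_a.right` have no `__path__` at all, which may be one reason it is not always available.

```python
import importlib
import pathlib
import sys
import tempfile

# Build a throwaway package on disk: foolib/__init__.py and foolib/sub.py.
root = pathlib.Path(tempfile.mkdtemp())
(root / "foolib").mkdir()
(root / "foolib" / "__init__.py").write_text("")
(root / "foolib" / "sub.py").write_text("X = 1\n")
sys.path.insert(0, str(root))
importlib.invalidate_caches()

foolib = importlib.import_module("foolib")
sub = importlib.import_module("foolib.sub")

print(list(foolib.__path__))     # the package *directory*: [".../foolib"]
print(foolib.__spec__.origin)    # the compiled file: ".../foolib/__init__.py"
print(sub.__spec__.origin)       # ".../foolib/sub.py"
print(hasattr(sub, "__path__"))  # False: plain modules have no __path__
```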
> We have a path in find_spec, which we could pass to the _JaxtypingLoader constructor. But that path can be None.
I think this sounds like a reasonable alternative as well, and probably we can just ignore `None`s?
> Just brainstorming here, another possibility is something like this:
This also sounds totally reasonable. I think this might also be an opportunity to fix a small nit, which is that I think having multiple nested import hooks might result in all but one of them getting ignored. (As the first declares that it can import the module but not instrument it.) An untested claim on my part, just eyeballing the code.
WDYT?
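For reference, the importlib rule behind the nested-hook nit, as a stdlib-only sketch: finders on `sys.meta_path` are consulted in order and the first one to return a spec wins, so a finder that claims a module without instrumenting it would shadow any finder after it. `ClaimingFinder` and `LabelLoader` are toy classes, not jaxtyping's `_JaxtypingFinder`/`_JaxtypingLoader`, and whether nested hooks really end up ordered this way is still the untested part.

```python
import importlib
import importlib.abc
import importlib.util
import sys

class LabelLoader(importlib.abc.Loader):
    """Toy loader that records which finder produced the module."""

    def __init__(self, label):
        self.label = label

    def create_module(self, spec):
        return None  # use the default module object

    def exec_module(self, module):
        module.loaded_by = self.label

class ClaimingFinder(importlib.abc.MetaPathFinder):
    """Toy finder that claims one module name and returns None for everything else."""

    def __init__(self, label, name):
        self.label = label
        self.name = name

    def find_spec(self, fullname, path=None, target=None):
        if fullname != self.name:
            return None
        return importlib.util.spec_from_loader(fullname, LabelLoader(self.label))

name = "demo_nested_hooks"
first = ClaimingFinder("first", name)
second = ClaimingFinder("second", name)
sys.meta_path[:0] = [first, second]  # first is consulted before second
try:
    loaded_by = importlib.import_module(name).loaded_by
finally:
    sys.meta_path.remove(first)
    sys.meta_path.remove(second)
    sys.modules.pop(name, None)

print(loaded_by)  # first: the second finder is never asked
```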