v2 problems (NetKet)
@wesselb I just tried running the new plum with netket, and I'm seeing some issues. Do you have any intuition about what might be going wrong here?
To reproduce, run:
```
pip install netket
pip install --upgrade plum-dispatch
wget -qO- https://raw.githubusercontent.com/netket/netket/master/Examples/Ising1d/ising1d.py | python -
```
```
/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/vqs/mc/mc_state/state.py:58: UserWarning: n_samples=1000 (1000 per MPI rank) does not divide n_chains=16, increased to 1008 (1008 per MPI rank)
  warnings.warn(
0%| | 0/300 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "<stdin>", line 41, in <module>
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/driver/abstract_variational_driver.py", line 252, in run
    for step in self.iter(n_iter, step_size):
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/driver/abstract_variational_driver.py", line 168, in iter
    dp = self._forward_and_backward()
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/driver/vmc.py", line 132, in _forward_and_backward
    self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._ham)
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/vqs/mc/mc_state/state.py", line 595, in expect_and_grad
    return expect_and_grad(
  File "/Users/filippovicentini/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/plum/function.py", line 342, in __call__
    self._resolve_pending_registrations()
  File "/Users/filippovicentini/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/plum/function.py", line 220, in _resolve_pending_registrations
    signature = extract_signature(f, precedence=precedence)
  File "/Users/filippovicentini/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/plum/signature.py", line 187, in extract_signature
    for k, v in typing.get_type_hints(f).items():
  File "/Users/filippovicentini/.pyenv/versions/3.10.6/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/typing.py", line 1871, in get_type_hints
    value = _eval_type(value, globalns, localns)
  File "/Users/filippovicentini/.pyenv/versions/3.10.6/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/typing.py", line 329, in _eval_type
    ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
  File "/Users/filippovicentini/.pyenv/versions/3.10.6/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/typing.py", line 329, in <genexpr>
    ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
  File "/Users/filippovicentini/.pyenv/versions/3.10.6/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/typing.py", line 327, in _eval_type
    return t._evaluate(globalns, localns, recursive_guard)
  File "/Users/filippovicentini/.pyenv/versions/3.10.6/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/typing.py", line 694, in _evaluate
    eval(self.__forward_code__, globalns, localns),
  File "<string>", line 1, in <module>
NameError: name 'DenyList' is not defined
```
Reduced to an MWE in #81. I'll keep this issue open for other bugs that might arise.
@PhilipVinc, if I also import DenyList wherever CollectionFilter is imported, your example seems to run. Specifically, these additions seem to work:
```
vqs/mc/mc_state/expect_forces.py
19:from flax.core.scope import CollectionFilter, DenyList
vqs/mc/mc_state/expect_forces_chunked.py
20:from flax.core.scope import CollectionFilter, DenyList
vqs/mc/mc_state/expect_grad.py
19:from flax.core.scope import CollectionFilter, DenyList
vqs/mc/mc_state/expect_grad_chunked.py
19:from flax.core.scope import CollectionFilter, DenyList
vqs/mc/mc_state/state.py
23:from flax.core.scope import CollectionFilter, DenyList
vqs/exact/expect.py
19:from flax.core.scope import CollectionFilter, DenyList
vqs/exact/state.py
22:from flax.core.scope import CollectionFilter, DenyList
vqs/base.py
24:from flax.core.scope import CollectionFilter, DenyList
```
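Here is a minimal, stdlib-only sketch of why the extra import helps. All names are hypothetical stand-ins: `fake_scope` plays `flax.core.scope`, and `fake_consumer` / `fake_consumer_fixed` with their `expect_and_grad` function play the NetKet modules listed above. As far as I can tell, `get_type_hints` evaluates the string `"DenyList"` in the globals of the module where the annotated function is defined, so importing `DenyList` there is enough.

```python
import sys
import types
import typing

# Hypothetical in-memory stand-in for flax.core.scope: DenyList is defined here,
# and Filter refers to it only through the string "DenyList" (a forward reference).
SCOPE_SRC = '''
from typing import Union

class DenyList:
    pass

Filter = Union[bool, str, "DenyList"]
'''

# Hypothetical consumer modules, before and after adding the extra import.
CONSUMER_SRC = '''
from fake_scope import Filter

def expect_and_grad(mutable: Filter) -> None:
    pass
'''
CONSUMER_FIXED_SRC = '''
from fake_scope import Filter, DenyList

def expect_and_grad(mutable: Filter) -> None:
    pass
'''


def load(name: str, source: str) -> types.ModuleType:
    """Execute `source` as a module called `name` and register it in sys.modules."""
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module


load("fake_scope", SCOPE_SRC)
consumer = load("fake_consumer", CONSUMER_SRC)
consumer_fixed = load("fake_consumer_fixed", CONSUMER_FIXED_SRC)

# "DenyList" is looked up in fake_consumer's globals, where it does not exist.
try:
    typing.get_type_hints(consumer.expect_and_grad)
    plain_import = "resolved"
except NameError as err:
    plain_import = f"NameError: {err}"

# With DenyList imported next to Filter, the same lookup succeeds.
fixed_hints = typing.get_type_hints(consumer_fixed.expect_and_grad)
print(plain_import)
print(fixed_hints["mutable"])
```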
I think typing.get_type_hints fails to resolve the forward reference because there is no way to trace it back to DenyList. That is, given CollectionFilter, I'm not sure the type carries enough information to link the forward reference back to DenyList in flax.core.scope. One fix that might not be totally unreasonable would be to resolve the forward reference manually:
```python
from flax.core.scope import CollectionFilter, DenyList
CollectionFilter.__args__ = CollectionFilter.__args__[:-1] + (DenyList,)
```
Ideally, this would be done in flax.core.scope...
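A sketch of that manual patch on a stand-in alias instead of the real flax one (`DenyList`, `Filter` and `consumer` below are hypothetical). Calling `get_type_hints(..., globalns={})` simulates a module that never imported `DenyList`. Because the patch mutates the existing `Union` object in place, annotations that already captured it see the change. This leans on typing internals (a writable `__args__` on `Union` aliases, as in the Python 3.10 `typing` module from the traceback), so it may not carry over to other Python versions.

```python
import typing
from typing import Collection, ForwardRef, Union


class DenyList:
    """Hypothetical stand-in for flax.core.scope.DenyList."""


# Same shape as the flax alias: the last member is the string "DenyList",
# which typing stores as ForwardRef('DenyList').
Filter = Union[bool, str, Collection[str], "DenyList"]


def consumer(mutable: Filter) -> None:
    """Hypothetical function annotated with the alias before the patch runs."""


def hints_without_denylist(fn):
    # An empty globalns mimics a module that imported Filter but not DenyList.
    return typing.get_type_hints(fn, globalns={})


try:
    hints_without_denylist(consumer)
    before_patch = "resolved"
except NameError:
    before_patch = "NameError"

# The manual fix: swap the trailing ForwardRef for the real class, in place.
# The isinstance check makes the patch safe to run more than once.
if isinstance(Filter.__args__[-1], ForwardRef):
    Filter.__args__ = Filter.__args__[:-1] + (DenyList,)

# `consumer` still holds the same Union object, so it sees the patched args.
after_patch = hints_without_denylist(consumer)["mutable"]
print(before_patch, typing.get_args(after_patch))
```

One caveat, as far as I can tell from CPython's `typing` module: subscripted aliases such as `Union[...]` are cached, so other code that builds exactly the same `Union` would receive the same, now patched, object.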
(Sorry, disregard my previous message; it was wrong.)
I think I partly understand the problem.
So the issue is this: if you call typing.get_type_hints() on an object defined inside flax.core.scope, get_type_hints correctly resolves DenyList, because it evaluates the forward reference in that module's globals, where DenyList is defined.
But if you import Filter and use it in an annotation in another file, nothing in Filter records that it was originally defined in flax.core.scope. Its forward reference 'DenyList' is then evaluated in the other file's globals, where DenyList does not exist, and you get the NameError.
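A stdlib-only sketch of that mechanism, with an in-memory module standing in for `flax.core.scope` (`fake_flax_scope`, `defined_here` and `defined_elsewhere` are hypothetical names). The alias only stores the string `'DenyList'`; the lookup succeeds from the defining module and fails from another namespace. As a check on the diagnosis, it also passes the defining module's namespace to `get_type_hints` through its `localns` argument, after which the same lookup succeeds.

```python
import types
import typing

# Hypothetical in-memory stand-in for flax.core.scope.
SCOPE_SRC = '''
from typing import Union

class DenyList:
    pass

Filter = Union[str, bool, "DenyList"]

def defined_here(x: Filter) -> None:
    pass
'''
scope = types.ModuleType("fake_flax_scope")
exec(SCOPE_SRC, scope.__dict__)

# The alias itself only carries the *string* 'DenyList', not its module.
ref = typing.get_args(scope.Filter)[-1]
print(type(ref).__name__, ref.__forward_arg__)

# A separate namespace that imported Filter but not DenyList.
elsewhere = {"Filter": scope.Filter}
exec("def defined_elsewhere(x: Filter) -> None:\n    pass\n", elsewhere)

# 1) Used from another namespace: 'DenyList' is looked up in the wrong globals.
#    Run this first: a successful evaluation may be cached by the ForwardRef.
try:
    typing.get_type_hints(elsewhere["defined_elsewhere"])
    from_elsewhere = "resolved"
except NameError:
    from_elsewhere = "NameError"

# 2) Used inside the defining module: its globals contain DenyList.
from_scope = typing.get_type_hints(scope.defined_here)["x"]

# 3) Supplying the defining module's namespace explicitly also resolves it.
with_localns = typing.get_type_hints(
    elsewhere["defined_elsewhere"], localns=vars(scope)
)["x"]
print(from_elsewhere, from_scope, with_localns)
```

The failing call is made first on purpose: in CPython 3.10's `typing`, a `ForwardRef` can remember a successful evaluation, which may hide the error if the calls run in a different order.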
@PhilipVinc Yes, that’s my understanding of what’s happening. I’m not 100% sure because I haven’t yet looked in detail at the logic of get_type_hints, but that is my guess.
Hey @PhilipVinc! Did you manage to make any headway with this problem? Does adding all those from flax.core.scope import CollectionFilter, DenyList imports everywhere help?