deepqmc
Deeperwin built-in ansatz fails to run
Hello,
I was trying out the various ansätze available as prebuilt recipes to see how they differed. While testing, I noticed that deeperwin leads to a strange error. My specific case used an atom with an ECP, but I was able to reproduce the issue with just the default LiH system. To reproduce, one need only run:
ansatz-testing> HYDRA_FULL_ERROR=1 deepqmc ansatz=deeperwin
[2025-08-07 10:57:58.211] INFO:deepqmc.app: Entering application
[2025-08-07 10:58:01.660] INFO:deepqmc.app: Process 0 running on nid001005
[2025-08-07 10:58:01.660] INFO:deepqmc.app: Running on 4 NVIDIA A100-SXM4-40GBs with 1 process
[2025-08-07 10:58:01.660] INFO:deepqmc.app: Will work in /pscratch/sd/b/bkincaid/DeepQMC/F-tests/ansatz-testing/outputs/2025-08-07/10-57-58
[2025-08-07 10:58:01.953] DEBUG:deepqmc.app: Running with code version: bc9b15a
[2025-08-07 10:58:08.025] DEBUG:deepqmc.train: Setting up metric_logger...
[2025-08-07 10:58:08.037] DEBUG:deepqmc.train: Setting up h5_logger...
[2025-08-07 10:58:18.485] INFO:deepqmc.wf.base: Number of model parameters: 788292
[2025-08-07 10:58:19.020] INFO:deepqmc.train: Pretraining wrt. baseline wave function
[2025-08-07 10:58:19.316] INFO:deepqmc.pretrain.pyscfext: Running HF...
[2025-08-07 10:58:19.651] INFO:deepqmc.pretrain.pyscfext: HF energy: -7.9519715386625505
[2025-08-07 10:58:19.651] INFO:deepqmc.pretrain.pyscfext: Dump PySCF checkpoint to /pscratch/sd/b/bkincaid/DeepQMC/F-tests/ansatz-testing/outputs/2025-08-07/10-57-58/training/pyscf_chkpts/mol_0.pyscf_chkpt
pretrain: 100%|███████████████████████████████████████████████████████████████████████████| 100/100 [00:17<00:00, 5.76it/s, MSE=(1.03e-02)]
[2025-08-07 10:58:43.968] INFO:deepqmc.train: Pretraining completed with MSE = (1.03e-02)
[2025-08-07 10:58:50.104] INFO:deepqmc.train: Equilibrating sampler...
equilibrate sampler: 5%|███▎ | 50/1000 [00:07<02:15, 7.02it/s, tau=(0.191)]
[2025-08-07 10:58:57.641] INFO:deepqmc.train: Start training
training: 0%| | 0/1000 [00:12<?, ?it/s]
Error executing job with overrides: ['ansatz=deeperwin']
Traceback (most recent call last):
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/bin/deepqmc", line 8, in <module>
sys.exit(cli())
~~~^^
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/main.py", line 94, in decorated_main
_run_hydra(
~~~~~~~~~~^
args=args,
^^^^^^^^^^
...<3 lines>...
config_name=config_name,
^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
_run_app(
~~~~~~~~^
run=args.run,
^^^^^^^^^^^^^
...<5 lines>...
overrides=overrides,
^^^^^^^^^^^^^^^^^^^^
)
^
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/utils.py", line 457, in _run_app
run_and_report(
~~~~~~~~~~~~~~^
lambda: hydra.run(
^^^^^^^^^^^^^^^^^^
...<3 lines>...
)
^
)
^
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
raise ex
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
return func()
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
lambda: hydra.run(
~~~~~~~~~^
config_name=config_name,
^^^^^^^^^^^^^^^^^^^^^^^^
task_function=task_function,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
overrides=overrides,
^^^^^^^^^^^^^^^^^^^^
)
^
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/hydra.py", line 132, in run
_ = ret.return_value
^^^^^^^^^^^^^^^^
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/core/utils.py", line 260, in return_value
raise self._return_value
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/core/utils.py", line 186, in run_job
ret.return_value = task_function(task_cfg)
~~~~~~~~~~~~~^^^^^^^^^^
File "/global/u1/b/bkincaid/DeepQMC/deepqmc/src/deepqmc/app.py", line 194, in cli
raise e.__cause__ from None # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 92, in _call_target
return _target_(*args, **kwargs)
File "/global/u1/b/bkincaid/DeepQMC/deepqmc/src/deepqmc/app.py", line 71, in train_from_factories
return train(hamil, ansatz, **kwargs)
File "/global/u1/b/bkincaid/DeepQMC/deepqmc/src/deepqmc/train.py", line 305, in train
) in fit_wf( # noqa: B007
~~~~~~^^^^^^^^^^^^^^^
rng,
^^^^
...<11 lines>...
],
^^
):
^
File "/global/u1/b/bkincaid/DeepQMC/deepqmc/src/deepqmc/fit.py", line 120, in fit_wf
train_state, mol_idxs, stats = train_step(rng, step, data, train_state)
~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/u1/b/bkincaid/DeepQMC/deepqmc/src/deepqmc/fit.py", line 69, in train_step
params, opt_state, E_loc, ratios, stats = opt.step(
~~~~~~~~^
rng_kfac,
^^^^^^^^^
...<2 lines>...
(phys_conf, weight, data),
^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/global/u1/b/bkincaid/DeepQMC/deepqmc/src/deepqmc/optimizer.py", line 167, in step
params_list, opt_state, opt_stats = self.kfac.step(
~~~~~~~~~~~~~~^
self.pmap_tree_unstack(params),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...<3 lines>...
momentum=0,
^^^^^^^^^^^
)
^
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/optimizer.py", line 1336, in step
return self._step(
~~~~~~~~~~^
params, state, rng, batch, func_state, learning_rate, momentum, damping,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
should_update_estimate_curvature, should_update_damping)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/utils/staging.py", line 297, in decorated
outs = func(instance, *args, **kwargs)
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/utils/misc.py", line 342, in wrapped
return method(instance, *args, **kwargs)
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/optimizer.py", line 1087, in _step
state.estimator_state = self._update_estimator_curvature(
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
state.estimator_state,
^^^^^^^^^^^^^^^^^^^^^^
...<5 lines>...
sync=self.should_sync_estimator(state),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/optimizer.py", line 741, in _update_estimator_curvature
state = self.estimator.update_curvature_matrix_estimate(
state=estimator_state,
...<6 lines>...
func_args=func_args,
)
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/utils/misc.py", line 342, in wrapped
return method(instance, *args, **kwargs)
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/curvature_estimator/block_diagonal.py", line 752, in update_curvature_matrix_estimate
state = self._update_blocks(
losses_vjp(tuple(vjp_vec)),
...<4 lines>...
batch_size=batch_size,
)
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/curvature_estimator/block_diagonal.py", line 557, in _update_blocks
block.update_curvature_matrix_estimate(
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
block_state,
^^^^^^^^^^^^
...<4 lines>...
batch_size=batch_size,
^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/utils/misc.py", line 342, in wrapped
return method(instance, *args, **kwargs)
File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/curvature_blocks/kronecker_factored.py", line 449, in update_curvature_matrix_estimate
assert utils.first_dim_is_size(batch_size, x, dy)
~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
AssertionError
All the other ansätze worked out of the box for me, so deeperwin seems to be the only one affected.
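For context, the assertion that fires at the bottom of the traceback checks that every array handed to the curvature block has the batch size as its leading dimension, so the failure suggests deeperwin produces an activation or tangent whose leading dimension disagrees with the batch. A minimal sketch of that invariant (the helper below is a hypothetical stand-in, not the actual kfac_jax implementation, and the shapes are made up for illustration):

```python
import numpy as np

def first_dim_is_size(batch_size, *arrays):
    # True only if every array's leading dimension equals batch_size.
    return all(a.shape[0] == batch_size for a in arrays)

x = np.zeros((128, 32))   # activations: leading dim matches the batch size
dy = np.zeros((64, 32))   # tangents with a mismatched leading dim

print(first_dim_is_size(128, x))      # True: leading dim matches
print(first_dim_is_size(128, x, dy))  # False: dy's leading dim is 64, not 128
```

If that is what is happening here, the mismatch would originate in how the deeperwin ansatz shapes its intermediate arrays before they reach the KFAC curvature estimator, rather than in kfac_jax itself.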
Sincerely, Ben Kincaid