Deeperwin builtin ansatz fails to run

Open bkincaid256 opened this issue 4 months ago • 2 comments

Hello,

I was trying out the various ansatzes available in the prebuilt recipes to see how they differ. While testing, I noticed that the deeperwin ansatz fails with a strange error. My original case used an atom with an ECP, but I was able to reproduce the issue with the default LiH system. To reproduce, simply run:

ansatz-testing> HYDRA_FULL_ERROR=1 deepqmc ansatz=deeperwin
[2025-08-07 10:57:58.211] INFO:deepqmc.app: Entering application
[2025-08-07 10:58:01.660] INFO:deepqmc.app: Process 0 running on nid001005
[2025-08-07 10:58:01.660] INFO:deepqmc.app: Running on 4 NVIDIA A100-SXM4-40GBs with 1 process
[2025-08-07 10:58:01.660] INFO:deepqmc.app: Will work in /pscratch/sd/b/bkincaid/DeepQMC/F-tests/ansatz-testing/outputs/2025-08-07/10-57-58
[2025-08-07 10:58:01.953] DEBUG:deepqmc.app: Running with code version: bc9b15a
[2025-08-07 10:58:08.025] DEBUG:deepqmc.train: Setting up metric_logger...
[2025-08-07 10:58:08.037] DEBUG:deepqmc.train: Setting up h5_logger...
[2025-08-07 10:58:18.485] INFO:deepqmc.wf.base: Number of model parameters: 788292
[2025-08-07 10:58:19.020] INFO:deepqmc.train: Pretraining wrt. baseline wave function
[2025-08-07 10:58:19.316] INFO:deepqmc.pretrain.pyscfext: Running HF...
[2025-08-07 10:58:19.651] INFO:deepqmc.pretrain.pyscfext: HF energy: -7.9519715386625505
[2025-08-07 10:58:19.651] INFO:deepqmc.pretrain.pyscfext: Dump PySCF checkpoint to /pscratch/sd/b/bkincaid/DeepQMC/F-tests/ansatz-testing/outputs/2025-08-07/10-57-58/training/pyscf_chkpts/mol_0.pyscf_chkpt
pretrain: 100%|███████████████████████████████████████████████████████████████████████████| 100/100 [00:17<00:00,  5.76it/s, MSE=(1.03e-02)]
[2025-08-07 10:58:43.968] INFO:deepqmc.train: Pretraining completed with MSE = (1.03e-02)
[2025-08-07 10:58:50.104] INFO:deepqmc.train: Equilibrating sampler...
equilibrate sampler:   5%|███▎                                                               | 50/1000 [00:07<02:15,  7.02it/s, tau=(0.191)]
[2025-08-07 10:58:57.641] INFO:deepqmc.train: Start training
training:   0%|                                                                                                    | 0/1000 [00:12<?, ?it/s]
Error executing job with overrides: ['ansatz=deeperwin']
Traceback (most recent call last):
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/bin/deepqmc", line 8, in <module>
    sys.exit(cli())
             ~~~^^
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
    ~~~~~~~~~~^
        args=args,
        ^^^^^^^^^^
    ...<3 lines>...
        config_name=config_name,
        ^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
    ~~~~~~~~^
        run=args.run,
        ^^^^^^^^^^^^^
    ...<5 lines>...
        overrides=overrides,
        ^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
    ~~~~~~~~~~~~~~^
        lambda: hydra.run(
        ^^^^^^^^^^^^^^^^^^
    ...<3 lines>...
        )
        ^
    )
    ^
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
            ~~~~~~~~~^
        config_name=config_name,
        ^^^^^^^^^^^^^^^^^^^^^^^^
        task_function=task_function,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        overrides=overrides,
        ^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
        ^^^^^^^^^^^^^^^^
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
                       ~~~~~~~~~~~~~^^^^^^^^^^
  File "/global/u1/b/bkincaid/DeepQMC/deepqmc/src/deepqmc/app.py", line 194, in cli
    raise e.__cause__ from None  # type: ignore
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 92, in _call_target
    return _target_(*args, **kwargs)
  File "/global/u1/b/bkincaid/DeepQMC/deepqmc/src/deepqmc/app.py", line 71, in train_from_factories
    return train(hamil, ansatz, **kwargs)
  File "/global/u1/b/bkincaid/DeepQMC/deepqmc/src/deepqmc/train.py", line 305, in train
    ) in fit_wf(  # noqa: B007
         ~~~~~~^^^^^^^^^^^^^^^
        rng,
        ^^^^
    ...<11 lines>...
        ],
        ^^
    ):
    ^
  File "/global/u1/b/bkincaid/DeepQMC/deepqmc/src/deepqmc/fit.py", line 120, in fit_wf
    train_state, mol_idxs, stats = train_step(rng, step, data, train_state)
                                   ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/u1/b/bkincaid/DeepQMC/deepqmc/src/deepqmc/fit.py", line 69, in train_step
    params, opt_state, E_loc, ratios, stats = opt.step(
                                              ~~~~~~~~^
        rng_kfac,
        ^^^^^^^^^
    ...<2 lines>...
        (phys_conf, weight, data),
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/global/u1/b/bkincaid/DeepQMC/deepqmc/src/deepqmc/optimizer.py", line 167, in step
    params_list, opt_state, opt_stats = self.kfac.step(
                                        ~~~~~~~~~~~~~~^
        self.pmap_tree_unstack(params),
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<3 lines>...
        momentum=0,
        ^^^^^^^^^^^
    )
    ^
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/optimizer.py", line 1336, in step
    return self._step(
           ~~~~~~~~~~^
        params, state, rng, batch, func_state, learning_rate, momentum, damping,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        should_update_estimate_curvature, should_update_damping)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/utils/staging.py", line 297, in decorated
    outs = func(instance, *args, **kwargs)
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/utils/misc.py", line 342, in wrapped
    return method(instance, *args, **kwargs)
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/optimizer.py", line 1087, in _step
    state.estimator_state = self._update_estimator_curvature(
                            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        state.estimator_state,
        ^^^^^^^^^^^^^^^^^^^^^^
    ...<5 lines>...
        sync=self.should_sync_estimator(state),
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/optimizer.py", line 741, in _update_estimator_curvature
    state = self.estimator.update_curvature_matrix_estimate(
        state=estimator_state,
    ...<6 lines>...
        func_args=func_args,
    )
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/utils/misc.py", line 342, in wrapped
    return method(instance, *args, **kwargs)
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/curvature_estimator/block_diagonal.py", line 752, in update_curvature_matrix_estimate
    state = self._update_blocks(
        losses_vjp(tuple(vjp_vec)),
    ...<4 lines>...
        batch_size=batch_size,
    )
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/curvature_estimator/block_diagonal.py", line 557, in _update_blocks
    block.update_curvature_matrix_estimate(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        block_state,
        ^^^^^^^^^^^^
    ...<4 lines>...
        batch_size=batch_size,
        ^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/utils/misc.py", line 342, in wrapped
    return method(instance, *args, **kwargs)
  File "/global/homes/b/bkincaid/.conda/envs/deepqmc/lib/python3.13/site-packages/kfac_jax/_src/curvature_blocks/kronecker_factored.py", line 449, in update_curvature_matrix_estimate
    assert utils.first_dim_is_size(batch_size, x, dy)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
AssertionError

All the other ansatzes worked out of the box for me, so deeperwin seems to be the only one affected.
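For context on the last frame of the traceback: the failing assertion appears to check that the layer input `x` and the output tangent `dy` both have a leading dimension equal to `batch_size`. The snippet below is a standalone sketch of that kind of check, not the kfac_jax implementation. The function body and the array shapes are assumptions chosen for illustration, and NumPy is used instead of JAX to keep it self-contained.

```python
import numpy as np


def first_dim_is_size(size, *arrays):
    """Hypothetical stand-in: True only if every array's leading dim equals `size`."""
    return all(np.shape(a)[0] == size for a in arrays)


batch_size = 8

# Layer input and output tangent that both carry the full batch dimension.
x = np.zeros((batch_size, 16))
dy_ok = np.zeros((batch_size, 32))

# A tangent whose leading dimension differs from the batch size, for example
# because the layer was applied to a different number of rows (illustrative shape).
dy_bad = np.zeros((4, 32))

check_ok = first_dim_is_size(batch_size, x, dy_ok)
check_bad = first_dim_is_size(batch_size, x, dy_bad)
print(check_ok, check_bad)
```

If the real check works along these lines, the error would suggest that some registered layer in the deeperwin ansatz sees inputs or gradients whose leading dimension is not the batch size the optimizer expects, though I have not confirmed this.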

Sincerely, Ben Kincaid

bkincaid256, Aug 07 '25 18:08