
JAX-accelerated Ising Hamiltonian

wdphy16 opened this issue 3 years ago · 20 comments

This is based on #968 . Thanks a lot @inailuig !

The design will be discussed in #1342

wdphy16 avatar Sep 23 '22 09:09 wdphy16

this is great, but I am a bit worried about generality here. It's clear that we should try to implement LocalOperator (hard) or PauliString (easier) in JAX and then specialize them

gcarleo avatar Sep 23 '22 09:09 gcarleo

It's straightforward to support Hamiltonians like a ZZ + b (XX + YY) + c X, but we need to write a lot more if we want to support a more general one like a XX + b YY + c ZZ + d X + e Y + f Z. For more than two-body interactions, I don't think we can implement it generally in JAX. Also, we need to manually handle a lot of special cases like h = 0 for best speed.
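A minimal sketch of what fixed-shape connected elements for the simplest case, a ZZ + c X, could look like in pure jax (hypothetical helper and names, not NetKet's actual API; assumes ±1 spins):

```python
import jax.numpy as jnp

# Hypothetical helper, not NetKet's API: connected configurations and matrix
# elements of H = J * sum_<ij> Z_i Z_j - h * sum_i X_i for one configuration
# sigma in {-1, +1}^N. `edges` is an (n_edges, 2) integer array.
def ising_conns_and_mels(sigma, edges, J, h):
    n = sigma.shape[0]
    # diagonal matrix element: J * sum over edges of s_i * s_j
    diag = J * jnp.sum(sigma[edges[:, 0]] * sigma[edges[:, 1]])
    # off-diagonal part: flipping any single spin i gives matrix element -h
    flips = jnp.tile(sigma, (n, 1)).at[jnp.diag_indices(n)].multiply(-1)
    conns = jnp.concatenate([sigma[None, :], flips], axis=0)  # (n + 1, N)
    mels = jnp.concatenate([diag[None], -h * jnp.ones(n)])    # (n + 1,)
    return conns, mels
```

Note the h = 0 special case mentioned above: the n flip rows would then be pure padding, but with static shapes they cannot be dropped without recompiling.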

wdphy16 avatar Sep 23 '22 09:09 wdphy16

I would like to understand this better though; I believe PauliString is relatively straightforward, but maybe I am missing something. @PhilipVinc ?

gcarleo avatar Sep 23 '22 09:09 gcarleo

Codecov Report

Merging #1335 (0cbf37c) into master (5103f6b) will increase coverage by 0.05%. The diff coverage is 93.84%.

@@            Coverage Diff             @@
##           master    #1335      +/-   ##
==========================================
+ Coverage   83.39%   83.45%   +0.05%     
==========================================
  Files         210      210              
  Lines       12986    13041      +55     
  Branches     2008     2014       +6     
==========================================
+ Hits        10830    10883      +53     
- Misses       1659     1660       +1     
- Partials      497      498       +1     
Impacted Files Coverage Δ
netket/operator/_discrete_operator.py 67.36% <0.00%> (ø)
netket/operator/_ising.py 86.02% <95.16%> (+6.76%) :arrow_up:
netket/operator/_local_operator_helpers.py 94.28% <100.00%> (+0.05%) :arrow_up:


codecov[bot] avatar Sep 23 '22 10:09 codecov[bot]

I had started porting paulistrings to jax some time ago. It uses the existing numba code to precompute some masks in the constructor, but the operator itself is in jax. See e.g. this commit https://github.com/inailuig/netket/commit/c28645de8cd8a0eb326089f442a6756e36ef8382

inailuig avatar Sep 23 '22 10:09 inailuig

Thanks, now I see how to do PauliStrings in general. I'll try to put it together with what I've written (like adding the tests), and I'll make a separate PR for it

wdphy16 avatar Sep 23 '22 12:09 wdphy16

I would like to understand this better though; I believe PauliString is relatively straightforward, but maybe I am missing something. @PhilipVinc ?

We cannot really implement LocalOperator efficiently in Jax right now. Or we can, but then we lose the ability to use get_conn_flattened, which cannot be implemented in Jax right now.

While get_conn_flattened is not used in netket now, I have it in some personal code where it considerably reduces the computational cost with complex hamiltonians. The Jax bugs that blocked it have been fixed, so I could contribute it to NetKet. Mainly, someone should propose a way to minimise recompilation if we use the flattened version (the shape would change all the time, so we'd need to pad a bit, but how much? With what rule? Where is it specified in the API?)

I am against writing LocalOperator in jax until we can support the above, because it really makes a difference in complex calculations (or we should have a way to have both). Once jax supports dynamic shapes, we can implement local operators and keep get_conn_flattened.
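One possible padding rule, purely an assumption on my side rather than anything specified by NetKet's API: bucket the flattened size up to the next power of two, which bounds the number of recompilations.

```python
import math

# Hypothetical bucketing rule, not something specified by NetKet's API: pad
# the flattened number of connected configurations up to the next power of
# two, so jit sees only O(log(max_conns)) distinct shapes instead of one
# shape per sample.
def padded_n_conn(n_conn: int) -> int:
    return 1 if n_conn <= 1 else 2 ** math.ceil(math.log2(n_conn))
```

Sections shorter than the padded size would be filled with copies of the input configuration and zero matrix elements, which leave expectation values unchanged.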

PhilipVinc avatar Sep 23 '22 13:09 PhilipVinc

Wouldn't it make sense to write a fully performant PauliString and then a generic wrapper converting LocalOperator into PauliString?


gcarleo avatar Sep 23 '22 13:09 gcarleo

Wouldn't it make sense to write a fully performant PauliString and then a generic wrapper converting LocalOperator into PauliString?

What do you mean?

PhilipVinc avatar Sep 23 '22 13:09 PhilipVinc

You can convert an arbitrary LocalOperator Hamiltonian into a PauliString, you just need to map the operators onto Pauli operators and the local space onto qubits


gcarleo avatar Sep 23 '22 13:09 gcarleo

Not efficiently, you can't. LocalOperator stores the matrix representation of XX, which is a 4x4 matrix, or of ZXYZ, which is 16x16. Forgetting that the matrix is exponentially large, how do you take the matrix and identify the corresponding Pauli string? I think it's an (NP?) hard problem.

To do what you say we'd need to have some tools to keep everything symbolically and do symbolic manipulation. But we don't have those yet.

PhilipVinc avatar Sep 23 '22 13:09 PhilipVinc

This is interesting. I think it can be done efficiently using some matrix decompositions, but let's discuss this at a blackboard. I am almost positive this is not NP-hard, unless you look for the optimal decomposition maybe, but that might not be necessary


gcarleo avatar Sep 23 '22 13:09 gcarleo

Still, I don't see why we need this...

I think the actually good thing would be to have some decent symbolic CAS within netket (we already have two half-sketched ones, both written by @jwnys , for PauliStrings and FermionicOperators2nd). Let's unify them, make them work with Fock spaces as well, and then give a way to convert a SymbolicOperator to LocalOperator or PauliString (maybe even warning, if the operator is K-local with K=6 and non-diagonal, that it's going to be very inefficient).

PhilipVinc avatar Sep 23 '22 13:09 PhilipVinc

I agree a symbolic tool is needed; I am just arguing that maybe we don't need to write two hard-coded fully performant Hamiltonians, we just need to write one and convert everything we can into that format. Of course the issue is whether this can be done efficiently, but I believe it should be possible


gcarleo avatar Sep 23 '22 13:09 gcarleo

I think we can efficiently convert any practical local operator to Pauli strings. Each term in the local operator is a 2^N * 2^N Hermitian matrix with 4^N independent real parameters. To convert it to a sum of some Pauli strings, we first list all 4^N Pauli strings of length N, which form a complete basis for the space of all 2^N * 2^N Hermitian matrices. Then we solve a linear system with 4^N real parameters to determine the coefficients of those Pauli strings. Although 4^N is exponentially large, for really 'local' operators with small N we can still efficiently do it.

If the operator is not Hermitian, there will be 4^N independent complex parameters, and in that case we allow the coefficients of Pauli strings to be complex.
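A minimal sketch of the linear-system approach for N = 1 (the basis and example matrix here are illustrative assumptions): write a Hermitian 2x2 matrix A as c0·I + c1·X + c2·Y + c3·Z and solve for the coefficients.

```python
import numpy as np

# Sketch of the linear-system approach for N = 1: the 4^N Pauli strings form
# the columns of the system, and we solve for their coefficients.
I = np.eye(2)
X = np.array([[0, 1], [1, 0]])
Y = np.array([[0, -1j], [1j, 0]])
Z = np.array([[1, 0], [0, -1]])

A = 2.0 * X + 0.5 * Z                  # example Hermitian operator
# each flattened basis matrix is one column of the linear system
M = np.stack([I, X, Y, Z]).reshape(4, -1).T
coeffs, *_ = np.linalg.lstsq(M, A.flatten(), rcond=None)
```

For this Hermitian example the coefficients come out real; a non-Hermitian operator would yield complex ones, as described above.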

wdphy16 avatar Sep 23 '22 15:09 wdphy16

I did some benchmarks but I cannot see any significant time or memory difference between HeisenbergJax and PauliStringsJax. Maybe we should really move everything to Pauli strings? (But I'm not sure how Pauli strings work with more than 2 local states)

In actual use cases I think the time for evaluating the neural network is always much larger than the time for get_conn_padded (as long as memory allocation is working normally); the main reason I wanted the JAX version is the memory bottleneck.

wdphy16 avatar Sep 23 '22 15:09 wdphy16

I think we can efficiently convert any practical local operator to Pauli strings. Each term in the local operator is a 2^N * 2^N Hermitian matrix with 4^N independent real parameters. To convert it to a sum of some Pauli strings, we first list all 4^N Pauli strings of length N, which form a complete basis for the space of all 2^N * 2^N Hermitian matrices. Then we solve a linear system with 4^N parameters to determine the coefficients of those Pauli strings. Although 4^N is exponentially large, for really 'local' operators with small N we can still efficiently do it.

Indeed this seems a possible way to go. I wonder if we can further optimize this procedure to remove some strings, but otherwise this strategy is already usable!

gcarleo avatar Sep 23 '22 15:09 gcarleo

I did some benchmarks but I cannot see any significant time or memory difference between HeisenbergJax and PauliStringsJax. Maybe we should really move everything to Pauli strings?

Maybe let's do some more tests, but if we can remove these specialized hamiltonians it would just be much more elegant

gcarleo avatar Sep 23 '22 15:09 gcarleo

Indeed this seems a possible way to go, I wonder if we can further optimize this procedure to remove some strings, but otherwise this is a strategy already usable !

We can optimize the strings (like combining like terms) after solving all the linear systems, which I think Jannes already implemented

wdphy16 avatar Sep 23 '22 16:09 wdphy16

@wdphy16 , of course I'm in favour of this, but I want it to be implemented correctly.

To better understand the implications of the changes you are trying to make, could you please make a separate PR with only the changes to Ising in it, essentially a cleanup of the original PR by @inailuig?

Also, could you please document how the API would work out for general operators? Would we end up having a single Ising or BoseHubbard class in Jax, or two of them (jax and numba)? How does the user pick one and switch from one to the other?

Is this flexible enough?

Moreover, in order to keep the interface consistent, I think it would be better to create a new abstract mixin class DiscreteJaxOperator that all 'jax' operators should inherit from as well, and it would be useful to see a proposal for what this interface would entail. Also, all Jax-aware operators should be declared as pytrees so that they can be passed as jax arguments.
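A sketch of the pytree part of this, with a hypothetical IsingJaxSketch class standing in for whatever the actual Jax operators would look like:

```python
import jax

# Hypothetical sketch of the pytree idea (IsingJaxSketch is invented here;
# the real DiscreteJaxOperator interface would come out of the proposal):
@jax.tree_util.register_pytree_node_class
class IsingJaxSketch:
    def __init__(self, J, h, edges):
        self.J, self.h, self.edges = J, h, edges

    def tree_flatten(self):
        # couplings are dynamic leaves; the hashable edge tuple is static aux data
        return (self.J, self.h), self.edges

    @classmethod
    def tree_unflatten(cls, edges, children):
        return cls(*children, edges)

# being a pytree, the operator can be passed directly as a jit argument:
op = IsingJaxSketch(1.0, 0.5, ((0, 1), (1, 2)))
scale = jax.jit(lambda o: o.J + o.h)(op)
```

The leaves/aux-data split is the design decision here: couplings get traced by jit, while the static edge list participates in the cache key.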

PhilipVinc avatar Sep 24 '22 10:09 PhilipVinc

I was just curious: what is the outcome then? is there a speed up or a slow down when using this Jax-aware implementation of the Ising hamiltonian?

gcarleo avatar Oct 29 '22 16:10 gcarleo

Ising yes; all the others, no.

PhilipVinc avatar Oct 29 '22 16:10 PhilipVinc

By all the others you mean Heisenberg is slower?

gcarleo avatar Oct 29 '22 16:10 gcarleo

yes. The exchange term in the Heisenberg Hamiltonian,

                [0, 0, 0, 0],
                [0, 0, 2, 0],
                [0, 2, 0, 0],
                [0, 0, 0, 0],

has a dynamic number of connected elements depending on the configuration it acts on. In Numba we manage to prune some of those; in jax we have to go with the worst-case scenario, which leads to quite a bit of wasted computation.
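The worst case can be sketched as follows (hypothetical helper with assumed ±1 spins, not NetKet code): an edge only connects to a new configuration when the two spins differ, but under jit every edge must produce a row.

```python
import jax.numpy as jnp

# Hypothetical sketch of the masking issue: the exchange term connects
# (s_i, s_j) -> (s_j, s_i) only when s_i != s_j, but with static shapes we
# must emit one row per edge and zero out the inactive ones.
def exchange_conns_and_mels(sigma, edges, J=2.0):
    i, j = edges[:, 0], edges[:, 1]
    active = sigma[i] != sigma[j]
    rows = jnp.arange(edges.shape[0])
    swapped = jnp.tile(sigma, (edges.shape[0], 1))
    swapped = swapped.at[rows, i].set(sigma[j]).at[rows, j].set(sigma[i])
    mels = jnp.where(active, J, 0.0)   # wasted rows wherever mels == 0
    return swapped, mels
```

In Numba the inactive rows can simply be skipped; here they are computed and carried through the network evaluation regardless.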

Once Jax gives us dynamic shapes (they say by the end of the year) we can get back to Heisenberg...

PhilipVinc avatar Oct 29 '22 16:10 PhilipVinc

In many of my use cases the jax version of Heisenberg is slower by ~20%, so I agree that we should wait until jax supports dynamically shaped arrays

wdphy16 avatar Oct 29 '22 16:10 wdphy16

I would still like to merge IsingJax and have it as the default implementation, but Dian refutes my commandments.

PhilipVinc avatar Oct 29 '22 16:10 PhilipVinc

Maybe there is no need to maintain a separate implementation of Ising because we can move everything to Pauli strings

wdphy16 avatar Oct 29 '22 16:10 wdphy16

doing PauliStringJax once possible would be ideal sure

gcarleo avatar Oct 29 '22 16:10 gcarleo

Not you as well... 😭 Dian!

PhilipVinc avatar Oct 29 '22 16:10 PhilipVinc

I think we can efficiently convert any practical local operator to Pauli strings. Each term in the local operator is a 2^N * 2^N Hermitian matrix with 4^N independent real parameters. To convert it to a sum of some Pauli strings, we first list all 4^N Pauli strings of length N, which form a complete basis for the space of all 2^N * 2^N Hermitian matrices. Then we solve a linear system with 4^N real parameters to determine the coefficients of those Pauli strings. Although 4^N is exponentially large, for really 'local' operators with small N we can still efficiently do it.

If the operator is not Hermitian, there will be 4^N independent complex parameters, and in that case we allow the coefficients of Pauli strings to be complex.

I was just reading this. Products of Pauli matrices form an orthogonal basis of the space of 2^n x 2^n matrices, so there is no need to solve any linear system; you can just project. Here's a sketch of how to obtain the coefficients (from which you can directly obtain the PauliStrings) of any 2^n x 2^n matrix in the Pauli basis:

import jax
import jax.numpy as jnp
import numpy as np

I = jnp.eye(2)
X = jnp.array([[0, 1], [1, 0]])
Y = jnp.array([[0, -1j], [1j, 0]])
Z = jnp.array([[1, 0], [0, -1]])

pauli_basis = jnp.stack([I, X, Y, Z], axis=0)

@jax.jit
def tensor_product(mats1, mats2):
    # all combinations between two sets of matrices via tensor product
    if mats1.ndim == 2 and mats2.ndim == 2:
        return jnp.kron(mats1, mats2)
    elif mats1.ndim == 3:
        return jax.vmap(tensor_product, in_axes=(0, None))(mats1, mats2)
    elif mats2.ndim == 3:
        return jax.vmap(tensor_product, in_axes=(None, 0))(mats1, mats2)

def basis_till_n(n=1):
    # return basis till 2^n x 2^n size
    # output is dict with keys 2^n and values (# basis elements for n) x 2^n x 2^n
    if n == 1:
        return {2 : pauli_basis}
    else:
        lower_bases = basis_till_n(n=n-1)
        tp = tensor_product(pauli_basis, lower_bases[2**(n-1)])
        tp = tp.reshape(-1, *tp.shape[-2:])
        return {**lower_bases, (2**n):tp}
    
@jax.jit
def hilbert_schmidt(A, basis_element):
    # Hilbert-Schmidt product; the basis elements are Hermitian, so no basis_element.conj().T here
    return jnp.trace(basis_element @ A)

@jax.jit
def project_matrix(A, bases):
    # project onto each basis element according to the Hilbert-Schmidt product
    assert A.ndim == 2
    assert A.shape[0] == A.shape[1]
    size = A.shape[0]
    basis_set = bases[size]
    return jax.vmap(hilbert_schmidt, in_axes=(None, 0))(A, basis_set) / size

# let's run a test
bases = basis_till_n(n=3)
test_matrix = 1j*np.kron(I, Z)+2*np.kron(Z, Z)
out = project_matrix(test_matrix, bases)
nz = np.where(out)[0]
assert np.allclose(nz, np.array([3, 15])) # check the right bases are there
assert np.allclose(out[nz], np.array([1j, 2]))
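Assuming the same flattened index ordering as above (first tensor factor varies slowest), the nonzero coefficients map straight back to string labels, e.g. for n = 2:

```python
# Hypothetical label mapping, matching the flattened index ordering of the
# projection sketch above: index = 4 * (first factor) + (second factor)
labels = [a + b for a in "IXYZ" for b in "IXYZ"]
# labels[3] is the coefficient of I (x) Z, labels[15] that of Z (x) Z
```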

So having #1510 would be sufficient to also map LocalOperator to Jax.

jwnys avatar Jun 28 '23 20:06 jwnys