JAX-accelerated Ising Hamiltonian
This is based on #968. Thanks a lot @inailuig!
The design will be discussed in #1342
this is great, but I am a bit worried about generality here. It's clear that we should try to implement LocalOperator (hard) or PauliString (easier) in JAX and then specialize them
It's straightforward to support Hamiltonians like a ZZ + b (XX + YY) + c X, but we would need to write a lot more if we want to support a more general one like a XX + b YY + c ZZ + d X + e Y + f Z. For more than two-body interactions, I don't think we can implement it generally in JAX. Also, we need to manually handle a lot of special cases like h = 0 for best speed.
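For the fixed-form case, the idea can be sketched as follows (a hypothetical illustration, not NetKet's actual API; the function name `ising_conns` is made up): for H = J Σ Z_i Z_j + h Σ X_i, the number of connected configurations is always N + 1, so all shapes are static and the whole computation is jit-able.

```python
import jax
import jax.numpy as jnp

def ising_conns(sigma, J=1.0, h=1.0):
    # sigma: N spins in {-1, +1}
    # H = J * sum_i Z_i Z_{i+1} + h * sum_i X_i  (periodic chain)
    N = sigma.shape[0]
    # diagonal matrix element from the ZZ term
    diag = J * jnp.sum(sigma * jnp.roll(sigma, 1))
    # the X term connects sigma to the N states with one spin flipped,
    # each with matrix element h -- so there are always N + 1 connected states
    flips = sigma[None, :] * (1 - 2 * jnp.eye(N, dtype=sigma.dtype))
    conns = jnp.concatenate([sigma[None, :], flips], axis=0)  # (N+1, N)
    mels = jnp.concatenate([diag[None], jnp.full((N,), h)])   # (N+1,)
    return conns, mels

sigma = jnp.array([1, 1, -1, 1])
conns, mels = jax.jit(ising_conns)(sigma)
```

Because the output shapes depend only on N, this never triggers shape-driven recompilation; the difficulty described above appears exactly when the number of connected states depends on the values of sigma, not just its shape.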
I would like to understand this better though; I believe PauliString is relatively straightforward, but maybe I am missing something. @PhilipVinc ?
Codecov Report
Merging #1335 (0cbf37c) into master (5103f6b) will increase coverage by 0.05%. The diff coverage is 93.84%.
```diff
@@            Coverage Diff             @@
##           master    #1335      +/-   ##
==========================================
+ Coverage   83.39%   83.45%   +0.05%
==========================================
  Files         210      210
  Lines       12986    13041      +55
  Branches     2008     2014       +6
==========================================
+ Hits        10830    10883      +53
- Misses       1659     1660       +1
- Partials      497      498       +1
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| netket/operator/_discrete_operator.py | 67.36% <0.00%> (ø) | |
| netket/operator/_ising.py | 86.02% <95.16%> (+6.76%) | :arrow_up: |
| netket/operator/_local_operator_helpers.py | 94.28% <100.00%> (+0.05%) | :arrow_up: |
I had started porting PauliStrings to jax some time ago. It uses the existing numba code to precompute some masks in the constructor, but the operator itself is in jax. See e.g. this commit https://github.com/inailuig/netket/commit/c28645de8cd8a0eb326089f442a6756e36ef8382
Thanks, now I see how to do PauliStrings in general. I'll try to put it together with what I've written (like adding the tests), and I'll make a separate PR for it
I would like to understand this better though; I believe PauliString is relatively straightforward, but maybe I am missing something. @PhilipVinc ?
We cannot really implement LocalOperators efficiently in Jax right now.
Or rather, we can, but then we lose the ability to use get_conn_flattened, which cannot be implemented in Jax right now.
While get_conn_flattened is not used in netket now, I have it in some personal code where it considerably reduces the computational cost with complex Hamiltonians. The bugs in Jax have been fixed, so I could contribute it to NetKet... Mainly, someone should propose a way to minimise recompilation if we use the flattened version (because the shape would change all the time, so we'd need to pad a bit, but how much? With what rule? Where is it specified in the API?)
I am against writing LocalOperators in jax until we can support the above, because it really makes a difference in complex calculations. (Or we should have a way to have both.) Once jax supports dynamic shapes, we can implement local operators and keep get_conn_flattened.
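One conceivable padding rule for the question raised above (purely a sketch of the idea under discussion, not an API proposal; `pad_conns` is a made-up name): round the number of connected elements up to the next power of two, so a jitted consumer recompiles only when the bucket size changes, and padded entries carry zero matrix elements.

```python
import numpy as np

def pad_conns(conns, mels):
    # conns: (n_conn, N) connected configurations, mels: (n_conn,) matrix elements
    n_conn = conns.shape[0]
    # hypothetical rule: round up to the next power of two, so a jitted
    # consumer recompiles only when the bucket size changes
    bucket = 1 << max(n_conn - 1, 0).bit_length()
    pad = bucket - n_conn
    # repeat the last configuration with a zero matrix element: it then
    # contributes nothing to sums like sum_k mels[k] * psi(conns[k])
    conns_p = np.concatenate([conns, np.repeat(conns[-1:], pad, axis=0)])
    mels_p = np.concatenate([mels, np.zeros(pad, dtype=mels.dtype)])
    return conns_p, mels_p
```

Any such rule trades wasted work (evaluating the network on padded rows) against recompilation frequency; where to fix that trade-off in the API is exactly the open question.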
Wouldn't it make sense to write a fully performant PauliString and then a generic wrapper converting LocalOperator into PauliString?
Wouldn't it make sense to write a fully performant PauliString and then a generic wrapper converting LocalOperator into PauliString?
What do you mean?
You can convert an arbitrary LocalOperator Hamiltonian into a PauliString; you just need to map the operators onto Pauli operators and the local space onto qubits.
Not efficiently, you can't.
LocalOperator stores the matrix representation of XX, which is a 4x4 matrix, or of ZXYZ, which is 16x16. Forgetting that the matrix is exponentially large, how do you take the matrix and identify the corresponding PauliString? I think it's an (NP?) hard problem.
To do what you say we'd need to have some tools to keep everything symbolically and do symbolic manipulation. But we don't have those yet.
This is interesting; I think it can be done efficiently using some matrix decompositions, but let's discuss this with a blackboard. I am almost positive this is not NP-hard, unless you look for the optimal decomposition maybe, but that might not be necessary.
Still, I don't see why we need this...
I think the actually good thing would be to have a decent symbolic CAS within netket (we already have two half-sketched ones, both written by @jwnys, for PauliStrings and FermionicOperators2nd). Let's unify them, make them work with Fock spaces as well, and then give a way to convert a SymbolicOperator to LocalOperator or PauliString (maybe even warning, if the operator is K-local with K=6 and non-diagonal, that it's going to be very inefficient...).
I agree a symbolic tool is needed; I am just arguing that maybe we don't need to write two hard-coded fully performant Hamiltonians. We just need to write one and convert all we can into that format. Of course the issue is whether this can be done efficiently, but I believe it should be possible.
I think we can efficiently convert any practical local operator to Pauli strings. Each term in the local operator is a 2^N * 2^N Hermitian matrix with 4^N independent real parameters. To convert it to a sum of some Pauli strings, we first list all 4^N Pauli strings of length N, which form a complete basis for the space of all 2^N * 2^N Hermitian matrices. Then we solve a linear system with 4^N real parameters to determine the coefficients of those Pauli strings. Although 4^N is exponentially large, for really 'local' operators with small N we can still do it efficiently.
If the operator is not Hermitian, there will be 4^N independent complex parameters, and in that case we allow the coefficients of Pauli strings to be complex.
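The linear-system step described above can be sketched with plain numpy (illustration only; the name `pauli_decompose` is made up). Each of the 4^N Pauli strings is vectorized into a column, and a least-squares solve recovers the coefficients:

```python
import itertools
import numpy as np

# single-site Pauli matrices
paulis = {
    "I": np.eye(2, dtype=complex),
    "X": np.array([[0, 1], [1, 0]], dtype=complex),
    "Y": np.array([[0, -1j], [1j, 0]]),
    "Z": np.array([[1, 0], [0, -1]], dtype=complex),
}

def pauli_decompose(A):
    # express A (2^N x 2^N) as sum_k c_k P_k over all 4^N Pauli strings P_k
    N = int(np.log2(A.shape[0]))
    strings, columns = [], []
    for labels in itertools.product("IXYZ", repeat=N):
        P = np.eye(1)
        for l in labels:
            P = np.kron(P, paulis[l])
        strings.append("".join(labels))
        columns.append(P.reshape(-1))
    # solve sum_k c_k vec(P_k) = vec(A) in the least-squares sense
    coeffs, *_ = np.linalg.lstsq(np.array(columns).T, A.reshape(-1), rcond=None)
    return dict(zip(strings, coeffs))
```

For Hermitian A the coefficients come out real up to rounding; since the Pauli strings are orthogonal under the Hilbert-Schmidt inner product, the same coefficients could also be obtained by projection instead of a solve.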
I did some benchmarks, but I cannot see any significant time or memory difference between HeisenbergJax and PauliStringsJax. Maybe we should really move everything to Pauli strings? (But I'm not sure how Pauli strings work with more than 2 local states.)
In actual use cases I think the time for evaluating the neural network is always much larger than the time for get_conn_padded (as long as the memory allocation is working normally); the reason I wanted the JAX version is mainly the memory bottleneck.
I think we can efficiently convert any practical local operator to Pauli strings. Each term in the local operator is a 2^N * 2^N Hermitian matrix with 4^N independent real parameters. To convert it to a sum of some Pauli strings, we first list all 4^N Pauli strings of length N, which form a complete basis for the space of all 2^N * 2^N Hermitian matrices. Then we solve a linear system with 4^N parameters to determine the coefficients of those Pauli strings. Although 4^N is exponentially large, for really 'local' operators with small N we can still efficiently do it.
Indeed this seems a possible way to go. I wonder if we can further optimize this procedure to remove some strings, but otherwise this strategy is already usable!
I did some benchmarks but I cannot see any significant time or memory difference between HeisenbergJax and PauliStringsJax. Maybe we should really move everything to Pauli strings?
Maybe let's do some more tests, but if we can remove these specialized Hamiltonians it would just be much more elegant.
Indeed this seems a possible way to go. I wonder if we can further optimize this procedure to remove some strings, but otherwise this strategy is already usable!
We can optimize the strings (like combining like terms) after solving all the linear systems, which I think Jannes already implemented
@wdphy16, of course I'm in favour of this, but I want it to be implemented correctly.
To better understand the implications of the changes you are trying to make, could you please make a separate PR with only the changes to Ising in it, essentially a cleanup of the original PR by @inailuig?
Also, could you please document how the API would work out for general operators? Would we end up having a single Ising or BoseHubbard class in Jax, or two of them (jax and numba)? How does the user pick and switch from one to the other?
Is this flexible enough?
Moreover, I think it would be better to create a new abstract mixin class DiscreteJaxOperator that all 'jax' operators should inherit from as well, and it would be useful to see a proposal for what this interface would entail.
Also, all Jax-aware operators should be declared as pytrees so that they can be passed as jax arguments.
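As a rough sketch of what "declared as pytrees" would mean in practice (the class and field names here are hypothetical, not NetKet's actual API): array-valued data become leaves, static data goes into the aux payload, and instances can then be passed straight through jit boundaries.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class IsingJaxSketch:
    # hypothetical sketch, not NetKet's actual class
    def __init__(self, edges, J, h):
        self.edges = edges  # static aux data: tuple of (i, j) pairs
        self.J = J          # leaf: traced by jax
        self.h = h          # leaf: traced by jax

    def tree_flatten(self):
        return (self.J, self.h), self.edges

    @classmethod
    def tree_unflatten(cls, edges, leaves):
        return cls(edges, *leaves)

@jax.jit
def diagonal_element(op, sigma):
    # the operator instance itself crosses the jit boundary as a pytree
    i, j = jnp.array(op.edges).T
    return op.J * jnp.sum(sigma[i] * sigma[j])

op = IsingJaxSketch(edges=((0, 1), (1, 2)), J=1.0, h=0.5)
value = diagonal_element(op, jnp.array([1, -1, 1]))
```

Keeping the edge list in the aux data (rather than as a leaf) means changing the graph triggers a recompile, while changing coupling values does not, which seems like the natural split for operators.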
I was just curious: what is the outcome then? Is there a speed-up or a slow-down when using this Jax-aware implementation of the Ising hamiltonian?
Ising yes, all the others not.
By "all the others" you mean Heisenberg is slower?
yes. The exchange term in the Heisenberg,

```
[0, 0, 0, 0],
[0, 0, 2, 0],
[0, 2, 0, 0],
[0, 0, 0, 0],
```

has a dynamical number of connected elements depending on the input configuration. In Numba we manage to prune some of those. In jax we have to go with the worst-case scenario, which leads to quite a bit of wasted computation.
When Jax gives us dynamical shapes (they say by the end of the year) we can get back to Heisenberg...
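A small numpy illustration of why the count is dynamical (the function name is made up): the exchange term only connects a pair of sites when the two spins are antiparallel, so the number of connected states depends on the input configuration, not just its shape.

```python
import numpy as np

def exchange_conns(sigma, edges):
    # the exchange term connects (i, j) to a swapped configuration only when
    # sigma_i != sigma_j; parallel pairs contribute no off-diagonal element
    conns = []
    for i, j in edges:
        if sigma[i] != sigma[j]:
            swapped = sigma.copy()
            swapped[i], swapped[j] = sigma[j], sigma[i]
            conns.append(swapped)
    return conns

edges = [(0, 1), (1, 2), (2, 3)]
# a Neel-like state connects on every bond; a polarized state on none
n_neel = len(exchange_conns(np.array([1, -1, 1, -1]), edges))       # 3
n_polarized = len(exchange_conns(np.array([1, 1, 1, 1]), edges))    # 0
```

A jax version cannot branch on these values at trace time, so it must always emit all len(edges) candidate rows, which is the wasted computation mentioned above.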
In many of my use cases the jax version of Heisenberg is slower by ~20%, so I agree that we should wait until jax supports dynamically shaped arrays.
I would still like to merge IsingJax and have it as the default implementation, but Dian refutes my commandments.
Maybe there is no need to maintain a separate implementation of Ising because we can move everything to Pauli strings
Doing PauliStringJax, once possible, would be ideal, sure.
Not you as well... 😭 Dian!
I think we can efficiently convert any practical local operator to Pauli strings. Each term in the local operator is a 2^N * 2^N Hermitian matrix with 4^N independent real parameters. To convert it to a sum of some Pauli strings, we first list all 4^N Pauli strings of length N, which form a complete basis for the space of all 2^N * 2^N Hermitian matrices. Then we solve a linear system with 4^N real parameters to determine the coefficients of those Pauli strings. Although 4^N is exponentially large, for really 'local' operators with small N we can still efficiently do it. If the operator is not Hermitian, there will be 4^N independent complex parameters, and in that case we allow the coefficients of Pauli strings to be complex.
I was just reading this. The products of Pauli matrices form an orthogonal basis for the space of 2^n x 2^n matrices, so there is no need to solve any linear system; you can just project. Here's a sketch of how to obtain the coefficients (from which you can directly obtain the PauliStrings) in the Pauli basis of any matrix of size 2^n x 2^n:
```python
import jax
import jax.numpy as jnp
import numpy as np

I = jnp.eye(2)
X = jnp.array([[0, 1], [1, 0]])
Y = jnp.array([[0, -1j], [1j, 0]])
Z = jnp.array([[1, 0], [0, -1]])
pauli_basis = jnp.stack([I, X, Y, Z], axis=0)


@jax.jit
def tensor_product(mats1, mats2):
    # all combinations between two sets via tensor product
    if mats1.ndim == 2 and mats2.ndim == 2:
        return jnp.kron(mats1, mats2)
    elif mats1.ndim == 3:
        return jax.vmap(tensor_product, in_axes=(0, None))(mats1, mats2)
    elif mats2.ndim == 3:
        return jax.vmap(tensor_product, in_axes=(None, 0))(mats1, mats2)


def basis_till_n(n=1):
    # return the bases up to size 2^n x 2^n
    # output is a dict with keys 2^n and values (# basis elements for n) x 2^n x 2^n
    if n == 1:
        return {2: pauli_basis}
    else:
        lower_bases = basis_till_n(n=n - 1)
        tp = tensor_product(pauli_basis, lower_bases[2 ** (n - 1)])
        tp = tp.reshape(-1, *tp.shape[-2:])
        return {**lower_bases, (2**n): tp}


@jax.jit
def hilbert_schmidt(A, basis_element):
    # Hilbert-Schmidt product; the basis elements are Hermitian
    # (so no basis_element.conj().T here)
    return jnp.trace(basis_element @ A)


@jax.jit
def project_matrix(A, bases):
    # project onto each basis element according to the Hilbert-Schmidt product
    assert A.ndim == 2
    assert A.shape[0] == A.shape[1]
    size = A.shape[0]
    basis_set = bases[size]
    return jax.vmap(hilbert_schmidt, in_axes=(None, 0))(A, basis_set) / size


# let's run a test
bases = basis_till_n(n=3)
test_matrix = 1j * np.kron(I, Z) + 2 * np.kron(Z, Z)
out = project_matrix(test_matrix, bases)
nz = np.where(out)[0]
assert np.allclose(nz, np.array([3, 15]))  # check the right bases are there
assert np.allclose(out[nz], np.array([1j, 2]))
```
So having #1510 would be sufficient to also map LocalOperator to Jax.