laplacian experiment
w/ @dpfau and @jsspencer
docs:
- `hbf` = "hessian bilinear form" = $(f, x, u, v) \mapsto (\partial f(x) v, \quad u^\mathsf{T} \partial^2 f(x) v)$
- `hqf2` = "hessian quadratic form 2D" = $(f, x, A) \mapsto (\partial f(x) A, \quad A^\mathsf{T} \partial^2 f(x) A)$
The generic rule might be quite efficient. In that case, we may only need rules for higher-order primitives (pjit, scan, etc.), and that's just for efficiency; the generic rule should work on those too. The rule for pjit is included here.
Here's the generic rule, slightly distilled:
```python
from functools import partial
import operator as op

import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

def generic_rule(prim, in_primals, in_jacs, in_laps, **params):
  f = partial(prim.bind, **params)
  out_primals, jac_term = jax.jvp(f, in_primals, in_laps)
  out_jacs, hess_term = hqf2(f, in_primals, in_jacs)
  out_lapvecs = tree_map(op.add, trace(hess_term), jac_term)
  return out_primals, out_jacs, out_lapvecs

trace = partial(tree_map, partial(jnp.trace, axis1=-1, axis2=-2))

def hbf(f, xs, vs1, vs2):
  return jax.jvp(lambda *xs: jax.jvp(f, xs, vs1)[1], xs, vs2)

def hqf2(f, xs, vs):
  return jax.vmap(jax.vmap(partial(hbf, f, xs), (-1, None), -1),
                  (None, -1), (None, -1))(vs, vs)
```
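A quick sanity check (not part of the rule itself, just a consistency test I'd run): seeding `hqf2` with identity columns should recover the Jacobian and Hessian of a small test function, and `trace` of the Hessian form gives the per-output Laplacian.

```python
f = lambda x: jnp.sin(x) * x[0]   # R^3 -> R^3, non-diagonal Hessian
x = jnp.arange(1.0, 4.0)
jac_seed = jnp.eye(3)             # one column per input direction

out_jacs, hess_form = hqf2(f, (x,), (jac_seed,))
assert jnp.allclose(out_jacs, jax.jacfwd(f)(x))    # shape (3, 3)
assert jnp.allclose(hess_form, jax.hessian(f)(x))  # shape (3, 3, 3)
laplacian = trace(hess_form)                       # shape (3,): Laplacian of each output
```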
Here's a derivation, inspired by https://arxiv.org/abs/2307.08214:
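The core identity, in brief (a condensed sketch): write each intermediate $x$ as a function of the original input $\theta$, and carry along its Jacobian $J_x = \partial x / \partial \theta$ and Laplacian $\ell_x = \Delta_\theta x$. For $y = f(x)$, the chain rule gives $J_y = \partial f(x)\, J_x$ and

$$\Delta_\theta y \;=\; \partial f(x)\,\ell_x \;+\; \operatorname{tr}\!\big(J_x^\mathsf{T}\,\partial^2 f(x)\,J_x\big),$$

which is exactly what `generic_rule` computes: the first term is `jac_term` (a JVP with `in_laps`) and the second is `trace(hess_term)` (from `hqf2` applied to `in_jacs`).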
It's fantastic to see jax developers working on this as well! Please allow me to (shamelessly) put my previous effort here (https://github.com/y1xiaoc/fwdlap), which follows a similar idea of writing a new tracer and might be of a little help. I was trying to learn the tracer framework while writing it, so the implementation is a bit nasty (with a lot of things borrowed from jvp and jet). The main issue I met was respecting custom jvp rules; the current implementation is quite hacky and may lead to errors in some corner cases, as written up in the issue.
@y1xiaoc thanks so much for reaching out! That looks really great! The readme is expertly written, and in the code I'm impressed you were able to figure out and use so much of JAX's internals. I bet we could learn a lot from your work.
I can see quite a lot of similarities (a good sign, I think!). I'm curious if you can already spot some differences, especially things you've figured out that I haven't yet.
Your readme mentions symbolic zeros. I made an attempt to handle them here (see `pure` and `lift`), but to be honest I haven't really tested them at all.
Your readme also mentions a special rule for bilinear operations. I actually wasn't sure yet whether the generic rule would be inefficient there. It seems like the sparsity of the Hessian should be taken into account, but perhaps we are missing some symmetry, like effectively computing `x * y + y * x` rather than `2 * x * y`?
I'm glad you reached out because it's encouraging to know there is interest in this kind of computation. Maybe even enough to justify upstreaming something into JAX! Would you be interested in collaborating on that? (Maybe also on jet.py, since you're clearly familiar with it!) I was planning to reach out to the paper authors too at some point, though I've only just started looking at this, and haven't looked at their code yet (since sometimes it's more fun to implement things yourself from scratch).
@y1xiaoc one more question: your readme makes this comment about `jet`:

> However, the implementation of jet is very inefficient, because it will always instantiate all the symbolic zeros.
I thought jet.py did have symbolic zeros support. Could you say more about the inefficiency? I'd like to fix it, but I've forgotten everything about that code :)
IIUC this is the paper authors' repo: https://github.com/YWolfeee/lapjax
We should probably make our public API functions preserve symbolic zero information... I see the zero instantiation problem.
@mattjj Thanks a lot! Yes, I would definitely be more than happy to collaborate on this! Please let me know what you think would be the best way to do it; currently, as a PhD student, I'm pretty flexible. And yes, I think this is really useful for people doing real-space variational quantum Monte Carlo. I've also heard from folks working on SDEs that they want a fast way to calculate the Laplacian.
As for the differences, one thing I see (as you mentioned) is the treatment of symbolic zeros. I was trying to reuse the internal `Zero` class defined in jvp in my implementation, which may allow `in_jacs` and `in_laps` to have different symbolic zero terms, and avoid this assert. The tradeoff is that I have to use the internal version of `jvp` instead of the public API.
Another thing I tried very hard to do is make the tracer respect `custom_jvp`. I was only able to do it in a hacky way where the outermost lu transformation is stripped off. I actually asked a question in the discussions before I figured out this hacky solution, but I was not able to find a more proper way without changing the JAX internals, maybe partially because I couldn't fully figure out the implementation of the custom JVP framework.
For the bilinear rules, I think you are right:

> but perhaps we are missing some symmetry, like effectively computing x * y + y * x rather than 2 * x * y?

That is exactly the case. For large matrices, doing `2 * x * y` provides a significant speedup.
For the jet question, the `zero_series` are indeed created properly, but the problem is that every time in `process_primitive` they will be instantiated. You actually left a TODO there. I guess for jet, writing each rule to respect symbolic zeros takes too much effort...
I also became aware of the paper authors' repo last month; I think they open-sourced it shortly after I did, but I was too busy to update my readme. I did a very naive benchmark on both codes, and the performance seems identical, except for some corner cases where the function takes an array but is purely element-wise (no matmul involved), for which the official code is faster. At the time I tested, the official code also seemed to have problems with complex numbers.
Hope this helps!
Over at @netket we'd love this for computing kinetic energies of quantum states, like @dpfau does. I've got a private branch where I use @y1xiaoc's code internally, and we're quite happy, but if this were supported upstream it would be amazing.
So it turns out even more people are interested in this! https://github.com/microsoft/folx
Amazing to see such interest in getting efficient Laplacian computations into JAX, and thanks @PhilipVinc for mentioning our implementation here! I'd like to elaborate a bit on the details of microsoft/folx, our custom interpreter for the forward Laplacian. While I can seemingly still learn a lot about the undocumented JAX internals from @y1xiaoc :), to get the most important speedup in the forward Laplacian as proposed in https://arxiv.org/abs/2307.08214 one actually has to carry sparse Jacobians forward. In our implementation we solve this by annotating functions such that our interpreter can follow the sparsity patterns. A good example to illustrate this is applying a function to a set of objects (which may be substantial parts of networks like PsiFormer, https://arxiv.org/abs/2211.13672):
```python
import folx
import jax
import jax.numpy as jnp
import fwdlap
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        for _ in range(10):
            x = nn.Dense(100)(x)
            x = nn.silu(x)
        return nn.Dense(1)(x).sum()

mlp = MLP()
x = jnp.ones((20, 100, 4))
params = mlp.init(jax.random.PRNGKey(0), x)

def fwd(x):
    return mlp.apply(params, x)

fwd_lapl = jax.jit(jax.vmap(folx.forward_laplacian(fwd, sparsity_threshold=4)))
%time jax.block_until_ready(fwd_lapl(x))    # Wall time: 5.59 s
%timeit jax.block_until_ready(fwd_lapl(x))  # 4.34 ms ± 23.4 µs per loop

@jax.jit
@jax.vmap
def fwdlapl_laplacian(x):
    og_shape = x.shape
    x_flat = x.reshape(-1)
    eye = jnp.eye(x.size, dtype=x.dtype)
    zero = fwdlap.Zero.from_value(x)

    def flat_fwd(x):
        return fwd(x.reshape(*og_shape))

    return fwdlap.lap(flat_fwd, (x_flat,), (eye,), (zero,))

%time jax.block_until_ready(fwdlapl_laplacian(x))    # Wall time: 2.34 s
%timeit jax.block_until_ready(fwdlapl_laplacian(x))  # 253 ms ± 4.87 ms per loop
```
@n-gao Thanks for the introduction to folx! Maintaining sparsity in a custom interpreter is fantastic and indeed very important! I think the paper authors' code was also trying to do this, but by overloading numpy functions. I was not able to achieve that in my code because I thought it was too hard lol.
A small note: if I understand correctly, for the sparsity to work, the batched structure must be maintained, which means the second dimension of x cannot be mixed. So a normal attention layer may not benefit from this? As a quick check, if I change the MLP to the following one, the speed difference is less pronounced (folx 20 ms vs mine 27 ms). And that difference is mainly a compilation issue: if I switch `tanh_p` and `logistic_p` to a joint jvp implementation, my code can actually be slightly faster (18 ms). But that change would slow down the "production" network in my current project.
```python
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        for _ in range(5):
            x = nn.Dense(100)(x.swapaxes(-1, -2)).swapaxes(-1, -2)
            x = nn.silu(x)
            x = nn.Dense(100)(x)
            x = nn.tanh(x)
        return nn.Dense(1)(x).sum()
```
BTW, I think there might be a bug in the sparsity implementation in folx? Changing the shape of x to `x = jnp.ones((23, 101, 5))` breaks the sparsity speedup, even for the old non-mixed MLP.
Hi @y1xiaoc thanks a lot for your interest :)
Regarding sparsity: the batched structure does not have to be maintained; `folx` does not care how you transpose or reshape your tensor, but it has a sparsity threshold as a hyperparameter. In your MLP, every element depends on 100 other elements after the first dense layer and is fully dependent thereafter. So there's no sparsity in the Jacobian and folx will simply materialize the full matrix. This is also the reason why normal attention does not benefit a lot from this, unfortunately. In the Forward Laplacian paper, the authors present an alternative attention mechanism that preserves sparse Jacobians.
Surprisingly, when I ran your MLP comparison against folx on a 1080 Ti with the same batch as in my first post, I got the following numbers (maybe due to some optimizations I did in matmuls and memory transfers? A100s with tensor cores enabled typically have different bottlenecks, as they also have HBM2 RAM vs GDDR5X):
folx: 156 ms ± 254 µs
fwdlap: 406 ms ± 4.74 ms
Regarding the behavior with the changed batch: did you also adjust the hyperparameter for the sparsity threshold? This unfortunately has to be set manually, as when it is best to materialize something is very shape-dependent. If you enter a fraction instead of an integer, it is interpreted as a fraction of the full input size (see the snippet after the timings below).
sparsity_threshold=4: 134 ms ± 332 µs
sparsity_threshold=5: 7.07 ms ± 30.1 µs
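For instance, reusing `fwd` from my first snippet (the exact value here is only illustrative):

```python
# an integer threshold is an absolute number of dependencies;
# a float such as 0.05 is read as a fraction of the full input size
fwd_lapl = jax.jit(jax.vmap(folx.forward_laplacian(fwd, sparsity_threshold=0.05)))
```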
Regarding compile times: these typically are not `jit` compile times but the numpy operations that I do to check the sparsity patterns. Some are implemented very inefficiently, as compile times are typically not significant in my workloads compared to runtime. I have them on the radar and would like to improve this in the future :)
Bugs are obviously still possible and it would be amazing if you could open a GitHub issue if you encounter any!
Hi all! Thank you for your interest in our work! I'm Ruichen Li, the first author of https://arxiv.org/abs/2307.08214. It's exciting to see our work benefiting the JAX community and generating fruitful discussions! Using JaxTracer to handle the Laplacian calculation is quite clear and fascinating -- I can learn a lot from it!
While our implementation relies on function overloading, we encountered many of the same challenges in developing the code. Here, I'd like to share some experiences.
The zero instantiation, i.e., sparsity in the Hessian matrix of an operator, is one of the main challenges when applying the generic rule to arbitrary functions. In our implementation, we address two kinds of sparsity in the Hessian: the zero matrix (linear operators) and the diagonal matrix (element-wise operators). As noted by @y1xiaoc, this kind of zero instantiation might not pose a fundamental problem when leveraging the powerful jit function and a JAX tracer. For other kinds of sparse Hessian matrices, e.g., the Hessian of a bilinear operator, we must write down the propagation rule case by case.
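Concretely, in the notation of the generic rule above (my paraphrase, not our actual code): for a linear operator $f$ we have $\partial^2 f = 0$, so the Hessian term vanishes and $\ell_y = \partial f(x)\,\ell_x$; for an element-wise $f$ the Hessian is diagonal, so the quadratic form collapses to

$$\ell_y \;=\; f'(x)\odot \ell_x \;+\; f''(x)\odot \operatorname{diag}\!\big(J_x J_x^\mathsf{T}\big),$$

i.e. an element-wise multiply with the row-wise squared Jacobian instead of a full `hqf2`.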
Besides the sparse matrix, we've explored accelerating the quadratic form via matrix decomposition, particularly when dealing with a low-rank Hessian matrix. In our implementation, there is only one case of this kind of operator (`logdet`). One can find how to leverage the symmetry and low-rank property for this operator in https://github.com/YWolfeee/lapjax/blob/main/_lapsrc/function_class.py, lines 451 to 473. Also, I notice that @n-gao's implementation addresses the propagation rule in a similar way, and his explanation is very clear! (https://github.com/microsoft/folx/blob/main/folx/custom_hessian.py, lines 14 to 59)
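For reference, the structure being exploited (in my notation; the linked code may organize it differently): with $y = \log\lvert\det A\rvert$, for a single direction $\partial_t y = \operatorname{tr}(A^{-1}\partial_t A)$ and $\partial_t^2 y = \operatorname{tr}(A^{-1}\partial_t^2 A) - \operatorname{tr}\big((A^{-1}\partial_t A)^2\big)$, so summing over input directions $k$ gives

$$\ell_y \;=\; \operatorname{tr}\big(A^{-1} L_A\big) \;-\; \sum_k \operatorname{tr}\big((A^{-1} J_A^{(k)})^2\big),$$

which only needs the products $A^{-1} J_A^{(k)}$ rather than an explicit Hessian.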
Another challenge arises from the sparsity in the Jacobian. In our implementation, we store the sparsity information within the input tuple and carefully process it to enhance the calculation speed. Despite our efforts, the user still has to take care of sparsity for some complex structures, e.g., the attention block in LapNet. Amazingly, @n-gao's implementation can automatically detect the sparsity. I'm curious whether this implementation can accelerate the attention part in LapNet, where the Jacobian matrix (tensor) would be the sum of two Kronecker factors, e.g., $\delta_{ij} + \delta_{kj}$, rather than a single Kronecker factor as in the envelope function.
The final problem is about JAX version dependency. We developed our code under jax==0.3.24. However, when upgrading to JAX 0.4.23 (the latest version), we observed a significant performance decrease (about 40%). It may be attributed to changes in `jax.jit`. As a suggestion, I recommend checking whether a lower version of JAX yields faster execution speeds for the Forward Laplacian.
Once again, I'd like to express my appreciation for all your interest. If you have other problems about the implementation, please do not hesitate to discuss them with me.
Thanks a lot @dashu233 for your insights on this; great work! :) Regarding the JAX version: this is rather curious; I actually haven't tested any older version with folx. Maybe @mattjj can shed some light on this? I guess it would be quite crucial for JAX to track performance regressions like this.
Regarding automatic sparsity detection, it is not 100% ideal as it operates on elementary operations, while some computations are faster if one merges a few operations, like flash attention. While this is supported in folx, I haven't implemented faster attention yet. But yes, folx will detect combined sparsities automatically. For instance, in this example, the sparsity patterns are given in the comments.
```python
import jax
import jax.numpy as jnp
import numpy as np

x = np.random.normal(size=(16, 4))
layers = [
    np.random.normal(size=(4, 32))
    for _ in range(3)
]

@jax.jit
def attention(x):                              # J_x is (16, 4, 1)
    q, k, v = [jnp.dot(x, w) for w in layers]  # J_q, J_k, J_v are (16, 32, 4)
    A = q @ k.T                                # J_A is (16, 16, 8) as it only depends on two elements
    A = jax.nn.softmax(A, axis=-1)             # J_A is now (16, 16, 64) as it depends on all elements due to the softmax
    return jnp.dot(A, v)                       # J_result is (16, 32, 64)
```
As sparsity is tracked at a per-element level, it also doesn't care about reshapes, indexing, or transposes and will still work correctly. The only caveat is that the sparsity is always defined by the largest dependency in the tensor, and it is never checked whether the sparsity could be reduced again, so it can only grow denser. Another disadvantage is that densifying matrices is very costly, requiring a segment sum.
So in LapNet the sparsity would be preserved until the softmax; at that point, it defaults to a dense implementation. Due to the cost of materializing sparse matrices, I found that the sparse attention in LapNet provides less speedup than I'd have hoped for. It is still quite a bit faster, but most of the speedup comes from the envelopes. Is this similar to your implementation?
Thanks @n-gao for the reply! It's fantastic that you can detect the sparsity in that way! I guess I should take some time to learn from your code :)
Regarding the LapNet speed-up rate: in our experiments, it is 1x to 2x faster than Psiformer with Forward Laplacian in systems with around 60 electrons. While this improvement is relatively modest compared to the speed-up achieved through the envelope, our implementation of sparse-derivative attention incorporates specific tricks to achieve this acceleration. The key insight is a two-step separation of the `softmax` operation, introducing more sparsity in the Jacobian matrix. As a suggestion, you may try the following code:
```python
def attention(x):                                  # J_x is (16, 4, 1)
    q, k, v = [jnp.dot(x, w) for w in layers]      # J_q, J_k, J_v are (16, 32, 4)
    A = q @ k.T                                    # J_A is (16, 16, 8) as it only depends on two elements
    # two-step softmax
    expA = jnp.exp(A)                                 # J_expA is also (16, 16, 8)
    sum_expA = jnp.sum(expA, axis=-1, keepdims=True)  # J_sum_expA is (16, 1, 64)
    expA_v = jnp.dot(expA, v)   # this is faster than softmax(A) @ v, as expA has a sparse Jacobian
    output = expA_v / sum_expA  # the output has a dense Jacobian
    return output
```
This code is just a speed test. In the open-sourced implementation of LapNet, we added some of the tricks used in `softmax` to avoid numerical stability problems.
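For completeness, one way such a stabilization could look (a sketch of the usual max-subtraction trick, not our actual LapNet code; `layers` as in the snippets above):

```python
import jax
import jax.numpy as jnp

def attention_stable(x):
    q, k, v = [jnp.dot(x, w) for w in layers]
    A = q @ k.T
    # subtract a per-row constant before exponentiating; stop_gradient keeps the
    # shift out of the Jacobian (so expA stays sparse), and the shift cancels in
    # the ratio expA_v / sum_expA anyway
    A_shift = jax.lax.stop_gradient(jnp.max(A, axis=-1, keepdims=True))
    expA = jnp.exp(A - A_shift)
    sum_expA = jnp.sum(expA, axis=-1, keepdims=True)
    return jnp.dot(expA, v) / sum_expA
```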
Hope this can help you!
> The final problem is about JAX version dependency. We developed our code under jax==0.3.24. However, when upgrading to JAX 0.4.23 (the latest version), we observed a significant performance decrease (about 40%). It may be attributed to the change of jax.jit. As a suggestion, I recommend checking if a lower version of JAX yields faster execution speeds for the Forward Laplacian.
@dashu233 can you provide an easily runnable script that exhibits this speed difference? I want to look into it. Is that on GPU?
@mattjj Thank you for your kindness! It is truly thrilling to receive guidance from a JAX expert!
Extracting a runnable script from our code has taken more time than I expected. Unfortunately, I couldn't reproduce that behavior on a simple MLP structure: the time cost of our implementation on an MLP is almost the same between jax==0.3.24 and jax==0.4.23. Instead, I found that including almost all the main parts of Psiformer, a widely used network architecture in my research area, is necessary to reproduce the performance discrepancy. I simplified the architecture for readability, and the performance discrepancy is now about 20%.
A runnable script to reproduce the performance discrepancy is attached. One should first install JAX with CUDA support and then install our LapJAX package (https://github.com/YWolfeee/lapjax):

```
git clone https://github.com/YWolfeee/lapjax
pip install lapjax
```
The expected output (time cost in seconds) is 3.10 $\pm$ 0.04 for jax==0.3.24 and 3.82 $\pm$ 0.02 for jax==0.4.23 on a Tesla V100-32G GPU.
```python
import time

import jax
import lapjax
import lapjax.numpy as ljnp
from lapjax import LapTuple

print(jax.__version__)

# some hyperparameters
batch_size = 32  # this batch size is suitable for a V100-32G
input_dim = 3
node = 84        # number of nodes in a graph
layer = 2
hid_dim = 128
det = 16         # in NN-VMC, the network output is used to construct determinants; this is the number of determinants

key = jax.random.PRNGKey(1234)
key, subkey = jax.random.split(key)
atom_num = 24
atom = jax.random.normal(subkey, (atom_num, input_dim))

# this is a network based on attention blocks
def psiformer(x, params, env_param):
    x_in = x  # shape = (node, input_dim)
    # attention part of the network
    for w in params[:-1]:
        x = ljnp.tanh(ljnp.matmul(x, w))
        attn_map = lapjax.nn.softmax(ljnp.matmul(x, x.transpose()))
        x = ljnp.matmul(attn_map, x)
    # the Jacobian of x is dense
    x = ljnp.matmul(x, params[-1])  # shape = (node, det * node)
    # the envelope function is a single-particle function, i.e.
    #   envelope = vmap(_some_special_function)(x_in)
    # so the Jacobian of the envelope function is sparse.
    # here we leverage the broadcast rule so that we don't have to call jax.vmap
    dis = ljnp.linalg.norm(x_in[:, None, :] - atom[None, ...], axis=-1)
    envelope = ljnp.sum(ljnp.exp(dis[..., None] * env_param), axis=1)  # shape = (node, det * node)
    # dense Jacobian multiplied with sparse Jacobian
    x = x * envelope
    # reshape to (det, node, node)
    x = x.reshape(x.shape[0], -1, x.shape[0])
    x = x.transpose((1, 0, 2))
    # in NN-VMC, we use determinants to ensure the anti-symmetry property
    output = lapjax.nn.logsumexp(ljnp.linalg.slogdet(x)[1])
    return output

# params and data initialization
hidden_dims = [input_dim] + [hid_dim] * layer + [node * det]
params = []
for i in range(layer + 1):
    key, subkey = jax.random.split(key)
    params.append(jax.random.normal(subkey, (hidden_dims[i], hidden_dims[i + 1])))
env_params = jax.random.normal(key, (atom_num, det * node))
input = jax.random.normal(key, (30, batch_size, node, input_dim))

def fwdlap(x: ljnp.array):
    x = LapTuple(x, is_input=True)
    # Replacing the line above with
    #
    #   x = LapTuple(x, is_input=True).set_dense(True)
    #
    # removes the sparsity support in the Jacobian. Then one finds that
    # jax==0.4 is faster than jax==0.3, but both are much slower than
    # with sparsity support.
    output_tuple: LapTuple = psiformer(x, params, env_params)
    v, g, l = output_tuple.to_tuple()
    return l

batch_fwd_lap = jax.jit(jax.vmap(fwdlap))
for i in range(10):
    batch_fwd_lap(input[i])

start_time = time.time()
for i in range(10, 30):
    batch_fwd_lap(input[i])
print(time.time() - start_time)
```
Brilliant progress in two years! The proposed change could eventually solve https://github.com/google/jax/discussions/9598.
@mattjj's derivation is clear and clever. FWIW, here is another derivation from an operator perspective, directly applying the chain rule without the Hessian: https://math.stackexchange.com/a/247926
@mattjj what's the status on this? Do you plan on merging it?
I'm not sure! It was just an O(1) day hack, and the other libraries linked here are a lot better. Still I do wonder if we could put some machinery inside JAX so it's easy to extend its autodiff in this kind of way, so that libraries have an easy time and don't have to touch JAX internals so much...
Do other folks on this thread have thoughts? What would be useful to upstream into JAX, if anything?
From my side, I am okay with the set of libraries available elsewhere (obviously including my own). While a simple implementation like fwdlap might be convenient, the whole sparsity-tracking pipeline etc. doesn't seem like a good fit for core JAX (at least not the way I implemented it).
While developing folx, I encountered many issues where I had to resort to non-public/undocumented APIs. Improvements on this front would be greatly appreciated. One example would be using eager execution during tracing. This is heavily used within folx to implement JVPs of index operations on the index mask, to access compile-time-fixed tensors (for indexing), or simply as a convenience to "run numpy functions with vmap" during tracing. I also had to look through JAX's source code and do a lot of trial and error to detect evaluations of custom jvps or jitted functions. Another point would be getting source-code details about traced functions, where I had to use JAX's internal `jax._src.source_info_util`.
I guess libraries like https://github.com/patrick-kidger/quax (which wasn't public when I wrote my implementation) are also helpful but do not alleviate all the pain points from the previous paragraph.
A batched vmap/scan implementation like the one at https://github.com/microsoft/folx/blob/main/folx/vmap.py might be a useful function for different applications. While this is pure convenience, it would allow trade-offs between parallelism and memory. My implementation might not be ideal, as it requires inlining the computation up to two times (though I wonder whether that can be avoided).
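For illustration, a minimal sketch of that chunked pattern (not the actual folx implementation, which also handles pytrees and remainders):

```python
import jax
import jax.numpy as jnp

def batched_vmap(f, chunk_size):
    """vmap within chunks, lax.map across chunks: trades parallelism for peak memory."""
    def wrapped(xs):  # assumes xs.shape[0] is divisible by chunk_size
        n_chunks = xs.shape[0] // chunk_size
        chunks = xs.reshape(n_chunks, chunk_size, *xs.shape[1:])
        out = jax.lax.map(jax.vmap(f), chunks)
        return out.reshape(n_chunks * chunk_size, *out.shape[2:])
    return wrapped

# e.g. batched_vmap(jnp.sin, 128)(jnp.ones((1024, 3))) matches
# jax.vmap(jnp.sin)(jnp.ones((1024, 3))) but only vmaps 128 rows at a time
```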
From my point of view, aka that of a maintainer of a library that could benefit from such fast Laplacian calculations, but with no desire/bandwidth to work directly on them...
I don't feel comfortable depending on third-party packages that heavily rely on JAX internals, because I've seen firsthand that those do change and lead to crashes, errors, and problems over which we have no control.
If such third-party packages could be built on top of public or semi-public (jax.extend?) APIs, I would feel more comfortable.
In my opinion, it would be better to develop the fwdlap algorithm and Jacobian sparsity support separately. The fwdlap algorithm is quite general and could benefit a broader community, while achieving the optimal performance for Jacobian sparsity requires significant effort and is model-dependent. Therefore, leaving the Jacobian sparsity support to the users or third-party developers seems more reasonable. A powerful custom_fwdlap may help.
By the way, I'm planning to attend ICLR this year and may give a poster presentation on fwdlap at the AI4DiffEquation workshop. This paper proposes a method to generalize the forward Laplacian to various second-order operators without extensive code modifications. The generalization also enables using fwdlap to compute the Hessian: one can perform fwdlap and then compute its gradient with respect to the coefficient matrix to derive the Hessian matrix. @mattjj I hope this can motivate you to consider integrating fwdlap into core JAX :)
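The underlying identity is simply that differentiating the coefficient-weighted second-order operator with respect to the coefficient matrix recovers the Hessian. A toy check (my sketch, with `jax.hessian` standing in for a generalized forward-Laplacian pass):

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x[0]) * x[1] + x[2] ** 2
x = jnp.arange(1.0, 4.0)

def weighted_second_order(C):
    # <C, H> in the Frobenius sense; the plain Laplacian is the case C = I.
    # jax.hessian is only a stand-in here for a generalized forward-Laplacian pass.
    return jnp.vdot(C, jax.hessian(f)(x))

hess_from_grad = jax.grad(weighted_second_order)(jnp.eye(3))
assert jnp.allclose(hess_from_grad, jax.hessian(f)(x))
```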
For me, the most straightforward thing would be a (semi-)public API for symbolic zeros. I think this can be beneficial to a lot of applications that require tweaking gradient evaluations. We have that in custom_jvp already, but the standard `jvp` function does not seem to support it.
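For reference, on the custom-derivative side the knob already exists; a minimal sketch of a rule branching on symbolic zeros (assuming I recall the `defjvp(..., symbolic_zeros=True)` API correctly):

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.custom_derivatives import SymbolicZero

@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

@partial(f.defjvp, symbolic_zeros=True)
def f_jvp(primals, tangents):
    x, y = primals
    tx, ty = tangents  # may be SymbolicZero objects rather than zero-filled arrays
    out = f(x, y)
    dout = jnp.zeros_like(out)
    if not isinstance(tx, SymbolicZero):
        dout = dout + jnp.cos(x) * y * tx
    if not isinstance(ty, SymbolicZero):
        dout = dout + jnp.sin(x) * ty
    return out, dout

# only x is differentiated, so ty arrives as a SymbolicZero and is never materialized
jax.grad(lambda x: f(x, jnp.ones(())).sum())(jnp.ones(3))
```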
In general, I totally agree with @PhilipVinc's point on semi-public APIs. I also had that in mind when writing `fwdlap`, and tried not to use things in `_src`. But there are still some internals that I have to play with, mainly the following (apart from the inevitable `core` for tracers):

- `interpreters.ad`: as mentioned previously, I basically have to use `ad.jvp` to bypass the public interface and support symbolic zeros.
- `interpreters.partial_eval`: I need this to make the partial-eval version (like `linearize` for `jvp`) to be used in loops, and also for tracing in custom jvp and pjit. I think an organized interface for partial eval would be highly desirable, in particular given that we have `jax.make_jaxpr`, to which `partial_eval` would be a natural extension.
- Some helper functions to deal with the tracer API boundary, e.g. flattening the inputs, etc. Some of these may be used together with `linear_util`, which is already in the extend API.
I would say that what I was trying to do might inevitably touch the core part of JAX (I'm writing a tracer, after all). But having those APIs as (semi-)public would make life much easier! Ultimately, I would hope for a public interface to define custom tracers. It would be a waste if such a beautiful framework were not used by the community.