Questions about control flow and indexed assignment

Open duhd1993 opened this issue 4 years ago • 8 comments

Hi, I want to implement a finite element solver with jax. I made a quick demo, which runs correctly without JIT but is extremely slow. There are two problems:

  1. How do I deal with control flow? For instance, I have if np.isclose(np.abs(np.linalg.det(jacobian)), 0.0): raise ValueError(). Since this will run for every element, I cannot just leave it in Python. lax.cond doesn't seem suitable for this purpose. Another example would be calling different material model functions based on element info stored in an array.

  2. How do I do indexed assignment? This is a relatively minor issue; I did manage to do it with np.where and jax.ops.index_add. But these methods create new arrays instead of updating in place, and they require some helper arrays. lax.dynamic_update_slice can only update a contiguous sub-array, and it also requires creating a sub-array first. Is there something that can do an indexed update/add? I'm not very familiar with TensorFlow, but it seems you can do something like my_var[4:8].assign(tf.zeros(4)). That's not exactly what I want (my_var[4:8].add(tf.zeros(4)) would be better), but at least it reads nicely. Or should I just not worry about the memory and performance cost of creating a new array?

Thank you!

duhd1993 avatar Oct 31 '19 14:10 duhd1993

For indexed assignment, check out the jax.ops package and e.g. jax.ops.index_update. (If you search the README for "indexed assignment" you'll see a reference to this package.) Those operations don't have in-place assignment semantics, i.e. they're functionally pure, but you can expect that under jit they will have the same performance as in-place assignment (provided automatically by the compiler!).
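
For example (a minimal sketch, with made-up array shapes):

import jax
import jax.numpy as np

x = np.zeros(10)
y = jax.ops.index_update(x, jax.ops.index[4:8], np.ones(4))  # like my_var[4:8] = 1
z = jax.ops.index_add(x, jax.ops.index[4:8], np.ones(4))     # like my_var[4:8] += 1
# x itself is unchanged; y and z are new arrays, but under jit the
# compiler is free to perform the update in place.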

For control flow, the short answer is you should think in terms of np.where. For more detail, could you provide a concrete, toy example?
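
For instance, instead of raising on a near-zero determinant you can compute both values and select elementwise (a toy sketch, with a made-up tolerance):

import jax.numpy as np

def safe_det(det, eps=1e-12):
    # both values are computed; np.where just selects between them
    return np.where(np.abs(det) < eps, eps, det)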

mattjj avatar Oct 31 '19 22:10 mattjj

Hi, I want to implement a finite element solver with jax

I just wanted to say that I'm very excited about this sort of thing!

In case it helps, we wrote a simple/naive finite element solver in autograd (which is conceptually quite similar to JAX) in https://github.com/google-research/neural-structural-optimization

shoyer avatar Nov 01 '19 00:11 shoyer

Thank you for the reply. I put together a toy example: colab. It isn't very small because it needs to address the issues of a general finite element problem. It still requires some work; currently it runs slower than the original numpy version.

Control flow: conditionals (cannot be jitted)

# element type selects the number of dofs, nodes, and Gauss points
if iet == 1:
    ndof, nnodes, ngpts = (8, 4, 4)
elif iet == 2:
    ndof, nnodes, ngpts = (12, 6, 7)

# per-element sanity check on the Jacobian determinant
if np.isclose(np.abs(det), 0.0):
    msg = "Jacobian close to zero. Check the shape of your elements!"
    raise ValueError(msg)

Loop (managed with lax.scan, this seems a lot slower on GPU)

def assem(KG, el):
    # gather the local stiffness matrix and metadata for element `el`
    kloc, ndof, iet = retriever(elements, mats, nodes, el)
    dme = DME[el, :ndof]
    idx = np.ix_(dme, dme)
    # zero out contributions mapped to constrained dofs (index -1)
    kloc = np.where(np.logical_or(idx[0] == -1, idx[1] == -1), 0., kloc)
    KG = jax.ops.index_add(KG, idx, kloc)
    return KG, None

KG, _ = lax.scan(assem, KG, np.arange(nels))

@shoyer The paper looks very interesting. The difference is that I'm optimizing with respect to a material property instead of the parameters/geometry as in your case. I'm happy to see autograd works for you; I had thought it was only suitable for very small-scale problems. I saw an issue there discussing sparse matrices and solvers. Did you implement that yourself? Would you consider a PR?

duhd1993 avatar Nov 01 '19 02:11 duhd1993

lax.scan invokes a while loop in XLA, which as you've noticed can be very slow on GPUs if the operations in each loop are very small and they don't get fused together into a single CUDA kernel.

You will do far better if you can figure out how to write this without scan, either by explicitly vectorizing the operation or by making use of JAX's vmap.
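
As a rough sketch of what I mean (the stiffness computation here is just a toy stand-in, not your actual element routine):

import jax.numpy as np
from jax import vmap

def element_stiffness(coords):
    # toy stand-in for a real local stiffness computation
    x = coords.ravel()
    return np.outer(x, x)

# coordinates for all elements at once, shape (nels, nnodes, 2);
# one batched call replaces the scan over elements, which XLA can
# fuse into a few large kernels instead of many tiny ones
all_coords = np.ones((100, 4, 2))
klocs = vmap(element_stiffness)(all_coords)  # shape (100, 8, 8)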

For autograd, I wrote a small wrapper for Cholmod's sparse matrix solve, which I think is about as well as you can do with direct methods on a single CPU. Autograd adds only a small amount of overhead (I suspect JAX could be a little better, but not much) -- the majority of the runtime is spent in the sparse solve. I commented on the autograd issue with a link to my specific code: https://github.com/HIPS/autograd/issues/433

shoyer avatar Nov 01 '19 06:11 shoyer

Yeah, I managed to get rid of scan using vmap. The catch is that jax.ops.index_add only accepts one update and one index set at a time, so I have to create a full tensor for each update and sum them up. This is a lot faster but blows up memory easily, O(N^3). Looking forward to sparse arrays/solvers in JAX.

duhd1993 avatar Nov 01 '19 22:11 duhd1993

This is a lot faster but blows up memory easily O(N^3)

You should be able to accumulate array values in an efficient fashion by stacking them in coordinate format ("COO" for scipy.sparse). Then you can use something like my Autograd example to actually do the sparse solve.
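
Roughly like this (illustrative only; the toy klocs/dmes stand in for a batched per-element computation with a fixed ndof):

import jax.numpy as np
from jax import vmap

def element_entries(kloc, dme):
    # global row/column index for every entry of the local matrix
    rows = np.repeat(dme, dme.shape[0])
    cols = np.tile(dme, dme.shape[0])
    return rows, cols, kloc.ravel()

# toy batched inputs: 100 elements, 8x8 local matrices, 8 dof indices each
klocs = np.ones((100, 8, 8))
dmes = np.tile(np.arange(8), (100, 1))
rows, cols, vals = vmap(element_entries)(klocs, dmes)

# flatten and hand (rows, cols, vals) to scipy.sparse.coo_matrix outside
# of jit for the actual sparse solve
rows, cols, vals = rows.ravel(), cols.ravel(), vals.ravel()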

shoyer avatar Nov 01 '19 23:11 shoyer

Have you tried using lax.cond instead of if-else for your conditional? It is jittable.
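
Something along these lines, for example (the tolerance is made up, and the exact lax.cond signature depends on your JAX version; this sketch uses the form lax.cond(pred, true_fun, false_fun, operand)):

import jax.numpy as np
from jax import lax

def checked_inv_det(det):
    # a Python `raise` can't depend on a traced value under jit, so
    # propagate NaN for a (near-)singular Jacobian instead
    return lax.cond(np.abs(det) < 1e-12,
                    lambda d: np.full_like(d, np.nan),
                    lambda d: 1.0 / d,
                    det)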

joaogui1 avatar Nov 03 '19 18:11 joaogui1

@joaogui1 Thanks. I haven't; it would look very ugly. For the first example it would need nested lax.cond calls, because I have many candidate values. Some sort of lax.switch or dict-based dispatch would be what I need here.
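
To illustrate, what I have in mind is something like this (just a sketch; material_a/material_b are made-up toy models, and it assumes a lax.switch-style primitive taking a 0-based integer index and a list of branches that all return the same shape):

import jax.numpy as np
from jax import lax

def material_a(strain):
    return 2.0 * strain          # toy linear model

def material_b(strain):
    return 3.0 * strain ** 2     # toy nonlinear model

iet = 1                          # element/material type, 0-based
strain = np.array([0.1, 0.2])
stress = lax.switch(iet, [material_a, material_b], strain)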

For the second example, something like tf.debugging.Assert would be more appropriate than doing it in Python. For now I've just commented those checks out, since there is only one element type and no exceptions in my example or in shoyer's paper.

duhd1993 avatar Nov 03 '19 18:11 duhd1993