Jax slow for contractions and jit
I am trying to reproduce the MNIST classification task with an MPS ansatz and encountered very slow compilation times when using jit. And even without jitting, the MPS contraction seems to be slower for jax than for the numpy or pytorch backends. I am using jax 0.2.0, tensornetwork 0.4.1, but I also tried master (https://github.com/google/TensorNetwork/commit/a27f269ce6b1738bbf2ffd4df7df8422646fb01a). See code below for a simple example:
import time
import numpy as np
import jax.numpy as jnp
from jax import jit, grad, vmap
import tensornetwork as tn

tn.set_default_backend("jax")  # change to numpy, pytorch, tensorflow

@jit
def contract(params, inputs):
    """ Contracts the input and parameter MPS """
    params_mps = [tn.Node(params[0])] + [tn.Node(t) for t in params[1]] + [tn.Node(params[2])]
    input_mps = [tn.Node(t) for t in inputs]
    sites = len(input_mps)
    params_mps[0][1] ^ params_mps[1][1]
    [params_mps[k][2] ^ params_mps[k+1][1] for k in range(1, sites-1)]
    [params_mps[k][0] ^ input_mps[k][0] for k in range(sites)]
    contraction = tn.contractors.greedy(params_mps + input_mps, ignore_edge_order=True)  # contraction could be parallelized
    return contraction.tensor

@jit
def predict(params, inputs):
    """ Computes mean of tensor network contraction over batch of input """
    result = vmap(contract, in_axes=(None, 0))(params, inputs)
    return jnp.mean(result)

sites = 784
d_phys = 2
d_bond = 3
inputs = np.random.uniform(size=(50, sites, d_phys))
left_boundary = np.random.normal(size=(d_phys, d_bond))
right_boundary = np.random.normal(size=(d_phys, d_bond))
center = np.random.normal(size=(sites-2, d_phys, d_bond, d_bond))
params = [left_boundary, center, right_boundary]

start_time = time.time()
contract(params, inputs[0])
print("Done in {:0.5f} sec".format(time.time() - start_time))

start_time = time.time()
contract(params, inputs[1])
print("Done in {:0.5f} sec".format(time.time() - start_time))

start_time = time.time()
jit(grad(predict))(params, inputs)
print("Done in {:0.5f} sec".format(time.time() - start_time))
This results in (first and second call to contract):
Jax (without jit): 0.79403s, 0.68542s
Jax (with jit): 22.99815s, 0.00026s
Numpy: 0.34218s, 0.32759s
Pytorch: 0.68840s, 0.34381s
Tensorflow: 2.75973s, 0.83503s

The times for calculating the gradient are:
Jax (without jit): 30.26929s
Jax (with jit): 602.86008s
Is there something I am doing wrong? Are there ways to speed things up?
(I know that the contraction scheme could be parallelized as mentioned in the paper. However, later I want to use the code for entangled input states for which the pairwise contraction scheme would be suboptimal again.)
Thanks @frmetz for the issue! When timing JAX functions, you should always call .block_until_ready() on the result of the function prior to measuring runtimes. JAX asynchronously dispatches instructions, and block_until_ready() forces the computation to finish before executing the next python statement. Otherwise, you are only measuring dispatch times.
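As a minimal sketch of the timing pattern described above: the function `matmul_chain` and the array size are hypothetical stand-ins, not taken from the issue. The warm-up call keeps tracing and compilation out of the measurement, and the second timestamp is only taken after `block_until_ready()` returns.

```python
import time
import jax.numpy as jnp
from jax import jit

@jit
def matmul_chain(x):
    # Arbitrary stand-in workload (hypothetical, not from the issue).
    for _ in range(10):
        x = jnp.tanh(x @ x)
    return x

x = jnp.ones((200, 200))

# Warm-up call: pays for tracing and compilation, so keep it out of the timing.
matmul_chain(x).block_until_ready()

start = time.perf_counter()
out = matmul_chain(x)            # may return before the computation has finished
dispatch_time = time.perf_counter() - start

out.block_until_ready()          # wait until the result is actually available
total_time = time.perf_counter() - start

print("dispatch only: {:0.6f} sec".format(dispatch_time))
print("with block_until_ready: {:0.6f} sec".format(total_time))
```

How far the two numbers differ depends on the device and the workload; on a small CPU example they can be close.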
That said, the timings for contract I get for numpy and JAX are:
numpy: ~350ms
JAX:
- first run of jitted contract: ~20s
- subsequent runs of jitted contract: ~700µs

The large tracing overhead of the first run is likely due to a Python for loop deeper down in the code that is being unrolled upon calling jit. Usually one would directly use lax.while_loop, lax.scan or lax.fori_loop in this case to reduce the compilation time, though here this is tricky.
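To illustrate the difference between an unrolled Python loop and lax.scan, here is a rough sketch written in plain jax.numpy rather than through TensorNetwork. It rests on assumptions that do not hold for the code in the issue as posted: all site tensors have the same shape and are stacked into one array, the boundaries are plain vectors, and the chain is shortened to 64 sites with positive entries to keep the example fast and numerically tame. The names `contract_loop` and `contract_scan` are made up for this example.

```python
import numpy as np
import jax.numpy as jnp
from jax import jit, lax

sites, d_phys, d_bond = 64, 2, 3  # shorter chain than in the issue (assumption)
rng = np.random.default_rng(0)
# Positive entries, scaled so the running product stays of order one.
bulk = jnp.asarray(rng.uniform(size=(sites, d_phys, d_bond, d_bond)) * (2 / 3))
inputs = jnp.asarray(rng.uniform(size=(sites, d_phys)))
left = jnp.ones(d_bond)   # hypothetical boundary vectors
right = jnp.ones(d_bond)

def contract_loop(bulk, inputs, left, right):
    # Python loop: under jit this is unrolled into one traced step per site.
    vec = left
    for k in range(bulk.shape[0]):
        mat = jnp.einsum("p,pab->ab", inputs[k], bulk[k])
        vec = vec @ mat
    return vec @ right

def contract_scan(bulk, inputs, left, right):
    # lax.scan: the step function is traced once and applied along axis 0.
    def step(vec, site):
        tensor, x = site
        mat = jnp.einsum("p,pab->ab", x, tensor)
        return vec @ mat, None
    vec, _ = lax.scan(step, left, (bulk, inputs))
    return vec @ right

loop_result = jit(contract_loop)(bulk, inputs, left, right)
scan_result = jit(contract_scan)(bulk, inputs, left, right)
print(float(loop_result), float(scan_result))
```

Both functions compute the same number; the scan version only changes how the loop is presented to the compiler. Whether the same restructuring is possible inside a general contractor such as the greedy one is a separate question, which is the tricky part mentioned above.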
What you should do in this case is to identify the computationally heavy parts of your code, wrap them into functions and jit those at the beginning of your code. At the first invocation of each function you'll have to suffer the tracing overhead, but subsequent calls will be substantially faster (modulo potential retracing if certain properties of the args, such as shapes or dtypes, change).
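A small sketch of the retracing behaviour mentioned above, assuming a hypothetical function `heavy_step`. It relies on the fact that Python side effects inside a jitted function only run while JAX traces it, so the list `trace_log` records one entry per trace.

```python
import jax.numpy as jnp
from jax import jit

trace_log = []  # filled only while JAX traces the function

def heavy_step(x):
    # Python side effect: runs during tracing, not on cached calls.
    trace_log.append(x.shape)
    return jnp.sum(jnp.tanh(x) ** 2)

heavy_step_jit = jit(heavy_step)  # jit once, near the top of the script

heavy_step_jit(jnp.ones((4, 3)))   # first call: trace and compile
heavy_step_jit(jnp.zeros((4, 3)))  # same shape and dtype: reuses the compiled code
traces_after_same_shape = len(trace_log)

heavy_step_jit(jnp.ones((5, 3)))   # new shape: triggers a retrace
traces_after_new_shape = len(trace_log)

print(traces_after_same_shape, traces_after_new_shape)
```

In a training loop this means keeping batch shapes fixed where possible, so the tracing cost is paid once.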
Hope this helps!
Here is also a slightly changed version of your code where JAX and numpy perform about equal (at least for n_samples=1):
import time
import numpy as np
import jax.numpy as jnp
from jax import jit
import tensornetwork as tn

randn = np.random.randn  # shorthand used below

sites = 784
d_phys = 2
d_bond = 3
n_samples = 1
tensors = [randn(1, d_phys, d_bond)] + [randn(d_bond, d_phys, d_bond) for _ in range(sites-2)] + [randn(d_bond, d_phys, 1)]
samples = [randn(n_samples, d_phys) for _ in range(sites)]

def contract_2(mps_tensors, sample_tensors):
    # Contract the first MPS tensor with its sample; label -1 carries the sample (batch) dimension.
    res = tn.ncon([mps_tensors[0], sample_tensors[0]], [[-2, 1, -3], [-1, 1]])
    # Absorb the remaining sites one by one, pairing site n with sample n.
    for n, tensor in enumerate(mps_tensors[1:], start=1):
        res = tn.ncon([res, tensor, sample_tensors[n]], [[-1, -2, 1], [1, 2, -3], [-1, 2]])
    return res

contract_2_jit = jit(contract_2)

# The %timeit lines are IPython magics; run this in IPython or a notebook.
tn.set_default_backend('numpy')
print('numpy runtimes')
%timeit contract_2(tensors, samples)

tn.set_default_backend('jax')
inputs = [jnp.array(t) for t in tensors], [jnp.array(t) for t in samples]
t1 = time.time()
contract_2_jit(*inputs).block_until_ready()
print('tracing contract_2_jit: ', time.time() - t1)

print('jitted contract_2_jit')
%timeit contract_2_jit(*inputs).block_until_ready()

print('unjitted contract_2')
%timeit contract_2(*inputs).block_until_ready()
Thanks @mganahl! This was very helpful. Your code example actually accidentally solved another issue I had :)
Just to be sure that I am understanding it correctly: The for loops that are likely to increase the compilation times in my example are due to loops inside the TensorNetwork functions I call, right? So there is nothing really I can do to further speed up this part of the computation? Thanks again!