
Jax slow for contractions and jit

frmetz opened this issue Sep 28 '20 · 3 comments

I am trying to reproduce the MNIST classification task with an MPS ansatz and encountered very slow compilation times when using jit. And even without jitting, the MPS contraction seems to be slower for jax than for the numpy or pytorch backends. I am using jax 0.2.0, tensornetwork 0.4.1, but I also tried master (https://github.com/google/TensorNetwork/commit/a27f269ce6b1738bbf2ffd4df7df8422646fb01a). See code below for a simple example:

import time

import numpy as np
import jax.numpy as jnp
from jax import jit, grad, vmap

import tensornetwork as tn
tn.set_default_backend("jax") # change to numpy, pytorch, tensorflow

@jit
def contract(params, inputs):
	""" Contracts the input and parameter MPS """
	params_mps = [tn.Node(params[0])] + [tn.Node(t) for t in params[1]] + [tn.Node(params[2])]
	input_mps = [tn.Node(t) for t in inputs]
	sites = len(input_mps)

	params_mps[0][1] ^ params_mps[1][1]  # connect the left boundary bond to the first bulk tensor
	[params_mps[k][2] ^ params_mps[k+1][1] for k in range(1, sites-1)]  # connect the remaining bond edges
	[params_mps[k][0] ^ input_mps[k][0] for k in range(sites)]  # connect the physical edges to the inputs

	contraction = tn.contractors.greedy(params_mps + input_mps, ignore_edge_order = True) # contraction could be parallelized

	return contraction.tensor

@jit
def predict(params, inputs):
	""" Computes mean of tensor network contraction over batch of input """
	result = vmap(contract, in_axes=(None, 0))(params, inputs)
	return jnp.mean(result)

sites = 784
d_phys = 2
d_bond = 3

inputs = np.random.uniform(size=(50, sites, d_phys))

left_boundary = np.random.normal(size=(d_phys, d_bond))
right_boundary = np.random.normal(size=(d_phys, d_bond))
center = np.random.normal(size=(sites-2, d_phys, d_bond, d_bond))
params = [left_boundary, center, right_boundary]

start_time = time.time()
contract(params, inputs[0])
print("Done in {:0.5f} sec".format(time.time() - start_time))

start_time = time.time()
contract(params, inputs[1])
print("Done in {:0.5f} sec".format(time.time() - start_time))

start_time = time.time()
jit(grad(predict))(params, inputs)
print("Done in {:0.5f} sec".format(time.time() - start_time))

This results in (first call, second call):

  • JAX (without jit): 0.79403 s, 0.68542 s
  • JAX (with jit): 22.99815 s, 0.00026 s
  • NumPy: 0.34218 s, 0.32759 s
  • PyTorch: 0.68840 s, 0.34381 s
  • TensorFlow: 2.75973 s, 0.83503 s

The times for calculating the gradient are:

  • JAX (without jit): 30.26929 s
  • JAX (with jit): 602.86008 s

Is there something I am doing wrong? Are there ways to speed things up?

(I know that the contraction scheme could be parallelized as mentioned in the paper. However, later I want to use the code for entangled input states for which the pairwise contraction scheme would be suboptimal again.)

frmetz · Sep 28 '20

Thanks @frmetz for the issue! When timing JAX functions, you should always call .block_until_ready() on the result of the function prior to measuring runtimes. JAX asynchronously dispatches instructions, and block_until_ready() forces the computation to finish before executing the next python statement. Otherwise, you are only measuring dispatch times.
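
For example, with the contract, params and inputs from your snippet above (and the jax backend set), the timing pattern would look something like this:

import time

# block_until_ready() forces the asynchronously dispatched computation to finish
# before the clock is read, so we measure the actual compute time, not just dispatch.
start_time = time.time()
contract(params, inputs[0]).block_until_ready()
print("first call (includes tracing/compilation): {:0.5f} sec".format(time.time() - start_time))

start_time = time.time()
contract(params, inputs[1]).block_until_ready()
print("subsequent call: {:0.5f} sec".format(time.time() - start_time))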

That said, the timings for contract I get for numpy and JAX are: numpy: ~350ms

JAX:

  • first run of jitted contract: ~20s
  • subsequent runs of jitted contract: ~700 µs

The large tracing overhead of the first run is likely due to a Python for loop deeper down in the code that is unrolled when calling jit. Usually one would use lax.while_loop, lax.scan or lax.fori_loop directly in such cases to reduce the compilation time, though here that is tricky.
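
For reference, here is a rough sketch in plain JAX (no TensorNetwork) of how the site loop could be written with lax.scan so that jit does not unroll it. It assumes the bulk tensors are stacked into a single array of shape (sites-2, d_phys, d_bond, d_bond), matching your params layout, and contracts a single product-state input:

import numpy as np
import jax.numpy as jnp
from jax import jit, lax

sites, d_phys, d_bond = 784, 2, 3

left = np.random.normal(size=(d_phys, d_bond))                     # left boundary tensor
bulk = np.random.normal(size=(sites - 2, d_phys, d_bond, d_bond))  # stacked bulk tensors
right = np.random.normal(size=(d_phys, d_bond))                    # right boundary tensor
x = np.random.uniform(size=(sites, d_phys))                        # one product-state input

@jit
def contract_scan(left, bulk, right, x):
    # absorb the left boundary into a (d_bond,) vector
    v = jnp.einsum('pb,p->b', left, x[0])

    def step(v, site):
        tensor, xi = site                        # tensor: (d_phys, d_bond, d_bond), xi: (d_phys,)
        m = jnp.einsum('plr,p->lr', tensor, xi)  # project the physical leg onto the input
        return v @ m, None                       # pass the boundary vector to the next site

    v, _ = lax.scan(step, v, (bulk, x[1:-1]))
    # close with the right boundary
    return jnp.einsum('b,pb,p->', v, right, x[-1])

contract_scan(left, bulk, right, x)  # first call compiles; later calls reuse the compiled loop

Because lax.scan compiles the loop body only once, the tracing time stays roughly flat in the number of sites instead of growing with it.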

What you should do in this case is identify the computationally heavy parts of your code, wrap them in functions and jit them at the beginning of your code. At the first invocation of each function you'll have to suffer the tracing overhead, but subsequent calls will be substantially faster (modulo potential retracing if certain properties of the args change).
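
To illustrate the retracing caveat with a toy function (nothing to do with the MPS code): jit specializes on the shapes and dtypes of its arguments, so calling the same jitted function with a new input shape triggers a fresh trace and compilation.

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
    return (x ** 2).sum()

f(jnp.ones(10))   # first call: traced and compiled
f(jnp.zeros(10))  # same shape and dtype: reuses the compiled version
f(jnp.ones(20))   # new shape: traced and compiled again

Hope this helps!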

mganahl · Sep 28 '20

Here is also a slightly changed version of your code where JAX and numpy perform about equally (at least for n_samples=1):

import time

import numpy as np
from numpy.random import randn  # randn(...) below creates random test tensors
import jax.numpy as jnp
from jax import jit

import tensornetwork as tn

sites = 784
d_phys = 2
d_bond = 3
n_samples = 1
tensors = [randn(1, d_phys, d_bond)] + [randn(d_bond, d_phys, d_bond) for _ in range(sites-2)] + [randn(d_bond, d_phys, 1)]
samples = [randn(n_samples, d_phys) for _ in range(sites)]
def contract_2(mps_tensors, sample_tensors):
    # contract the first MPS tensor with the first sample; res keeps the open bond
    res = tn.ncon([mps_tensors[0], sample_tensors[0]],[[-2,1,-3],[-1,1]])
    for n, tensor in enumerate(mps_tensors[1:], 1):  # pair each remaining tensor with its own sample
        res = tn.ncon([res, tensor, sample_tensors[n]],[[-1,-2,1],[1,2,-3],[-1,2]])
    return res

contract_2_jit = jit(contract_2)

tn.set_default_backend('numpy')
print('numpy runtimes')
%timeit contract_2(tensors, samples)

tn.set_default_backend('jax')
inputs = [jnp.array(t) for t in tensors], [jnp.array(t) for t in samples]

t1 = time.time()
contract_2_jit(*inputs).block_until_ready()
print('tracing contract_2_jit: ',time.time() - t1)

print('jitted contract_2_jit')
%timeit contract_2_jit(*inputs).block_until_ready()


print('unjitted contract_2')
%timeit contract_2(*inputs).block_until_ready()

mganahl · Sep 28 '20

Thanks @mganahl! This was very helpful. Your code example actually accidentally solved another issue I had :)

Just to be sure I am understanding correctly: the for loops that likely increase the compilation time in my example come from loops inside the TensorNetwork functions I call, right? So there is nothing I can really do to further speed up this part of the computation? Thanks again!

frmetz · Sep 30 '20