
Gradients with `odeint` slow on GPU

Open · spenrich opened this issue 5 years ago • 5 comments

The following MWE trains a simple neural ODE model with gradient descent to fit data sampled along a single trajectory of a 2-D dynamical system (a Van der Pol oscillator). Each iteration of the training loop runs much more slowly on my GPU than on my CPU (roughly estimated with tqdm at 17 iterations/sec on GPU vs. upwards of 800 iterations/sec on CPU).

Any first impressions about what might be going on? I can look into doing better profiling if need be.

Versions: jax 0.2.6, jaxlib 0.1.57+cuda102, CUDA 10.2

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x

# Uncomment this line to force using the CPU
# jax.config.update('jax_platform_name', 'cpu')

# Some utilities for dealing with PyTrees of parameters
def tree_axpy(a, x_tree, y_tree):
    """Compute `y = a*x` for two PyTrees `(x, y)` and a scalar `a`."""
    ax = jax.tree_util.tree_map(lambda x: a * x, x_tree)
    axpy = jax.tree_util.tree_multimap(lambda x, y: x + y, ax, y_tree)
    return axpy

def tree_normsq(x_tree):
    """Compute sum of squared norms across a PyTree."""
    normsq = jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(y**2), x_tree, 0.)
    return normsq

# Define true ODE, our approximator, and the loss function
def f(x, t):
    """Compute state derivative of a Van der Pol oscillator."""
    mu = 1.
    dx = jnp.hstack([
        mu*(x[0] - x[0]**3/3 - x[1]),
        x[0]/mu
    ])
    return dx

def f_est(x, t, params):
    """Estimate state derivative with a two-layer neural network."""
    W = params['W']
    b = params['b']
    y = W[0]@x + b[0]
    y = W[1]@jnp.tanh(y) + b[1]
    return y

def loss(params, x, t, reg_coeff):
    """Compute the sum of squared losses along a queried trajectory."""
    x_hat = odeint(f_est, x[0], t, params)
    error = jnp.sum((x - x_hat)**2)
    loss_value = error + reg_coeff*tree_normsq(params)
    return loss_value

# Generate data along a trajectory of the true system
x0 = jnp.array([1., 0.])
t0, tf = (0., 5.)
dt = 0.1
num_steps = int((tf - t0) / dt) + 1
t = jnp.linspace(t0, tf, num_steps)
x = odeint(f, x0, t)

# Initialize neural network parameters
n = 2
hdim = 32  # size of hidden layer
key = jax.random.PRNGKey(0)
key_W0, key_W1, key_b0, key_b1 = jax.random.split(key, 4)  # one key per parameter
params = {
    'W': [
        0.1*jax.random.normal(key_W0, (hdim, n)),
        0.1*jax.random.normal(key_W1, (n, hdim)),
    ],
    'b': [
        0.1*jax.random.normal(key_b0, (hdim,)),
        0.1*jax.random.normal(key_b1, (n,)),
    ]
}

# Training
loss_buffer = []
step_size = 1e-4
reg_coeff = 1e-6
value_and_grad = jax.jit(jax.value_and_grad(loss))
for _ in tqdm(range(5000)):
    value, grad = value_and_grad(params, x, t, reg_coeff)
    loss_buffer.append(value)
    params = tree_axpy(-step_size, grad, params)  # gradient descent step
print('Regularized fit loss:', loss_buffer[-1])

# Plotting (optional)
try:
    import matplotlib.pyplot as plt

    x_est = odeint(f_est, x0, t, params)
    
    fig, axes = plt.subplots(1, 2, figsize=(15,5))
    axes[0].plot(x[:,0], x[:,1], '--x')
    axes[0].plot(x_est[:,0], x_est[:,1], '-')
    axes[1].plot(loss_buffer)
    axes[1].set_yscale('log')
    plt.show()
except ImportError:
    print('Package `matplotlib` not found! Skipping plots.')
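
For reference, here is a rough sketch of how the per-step time could be measured more directly (reusing the names defined above; the trial count is arbitrary), blocking on the outputs so JAX's asynchronous dispatch doesn't hide device-side work:

import time

# Warm-up call so compilation time is excluded (already compiled by the
# training loop above, but harmless to repeat).
value, grad = value_and_grad(params, x, t, reg_coeff)
value.block_until_ready()

num_trials = 100
start = time.perf_counter()
for _ in range(num_trials):
    value, grad = value_and_grad(params, x, t, reg_coeff)
# Block on the outputs before stopping the clock.
value.block_until_ready()
jax.tree_util.tree_map(lambda g: g.block_until_ready(), grad)
elapsed = time.perf_counter() - start
print('Mean time per step: {:.4f} s'.format(elapsed / num_trials))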

spenrich commented on Nov 24, 2020

The short answer is that, unfortunately, at this time XLA GPU is not great at code generation for tight loops like those in odeint. The body of the while_loop is compiled into one or more GPU kernels, which incurs significant launch overhead because control flow goes back to the CPU in each iteration.
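
To give a sense of scale, here is a rough, self-contained sketch (the step count and the loop body are purely illustrative) that isolates the per-iteration cost by running a lax.while_loop whose body does almost no arithmetic; on GPU the per-iteration dispatch overhead tends to dominate rather than the math:

import time
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def tight_loop(x):
    # 1000 iterations of a trivially cheap body.
    def cond(carry):
        i, _ = carry
        return i < 1000
    def body(carry):
        i, x = carry
        return i + 1, x + 0.1 * jnp.tanh(x)
    _, x = lax.while_loop(cond, body, (0, x))
    return x

x = jnp.ones(2)
tight_loop(x).block_until_ready()  # compile once
start = time.perf_counter()
tight_loop(x).block_until_ready()
print('1000 tiny iterations:', time.perf_counter() - start, 'seconds')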

shoyer commented on Nov 24, 2020

@shoyer is anyone working on improving while_loop?

awav commented on Dec 8, 2020

Yes, there are several ongoing streams of work to improve while_loop.

shoyer commented on Dec 8, 2020

@shoyer, apologies for the off-topic question. Could you share links to the PRs or branches for the ongoing work, if it's publicly available?

awav commented on Dec 9, 2020

Sorry, I don't have any details that I can share at this time.

shoyer commented on Dec 9, 2020