LAPACK results inconsistent across different operating systems and devices
Description
I have a deterministic program that uses JAX and is heavy on linear algebra operations.
I ran this code on the CPUs of three different machines: two macOS systems (one running Sequoia on an M1 Pro, the other running Sonoma on an M2) and one Linux system.
All three systems produce different results for the same program, but each system produces its own result deterministically.
Minimal Reproducible example
```python
import jax
import optax
import flax.linen as nn
import jax.numpy as jnp

# Enable 64-bit dtypes (float64 / complex128) in JAX.
jax.config.update("jax_enable_x64", True)

variables = jnp.array([0.1, -3 * jnp.pi / 2])


class RNN(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, input, hidden_state):
        # One GRU step followed by a linear read-out layer.
        gru_cell = nn.GRUCell(features=self.hidden_size)
        new_hidden_state, _ = gru_cell(hidden_state, input)
        output = nn.Dense(features=self.output_size)(new_hidden_state)
        return output, new_hidden_state


def _optimize(
    loss_fn,
    init_params,
    max_iter,
    learning_rate,
):
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(init_params)

    @jax.jit
    def step(params, state):
        # One Adam update using the gradient of the loss.
        grads = jax.grad(loss_fn)(params)
        updates, new_state = optimizer.update(grads, state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_state

    params = init_params
    for iter_idx in range(max_iter):
        params, opt_state = step(params, opt_state)
    return params, iter_idx + 1


def fun(gamma, delta):
    # op is the Pauli-Y matrix; delta / 2 is added elementwise to all four entries.
    op = jnp.array([[0, -1j], [1j, 0]])
    angle = (gamma * op) + delta / 2
    # (expm(iA) + expm(-iA)) / 2 is the matrix cosine of A.
    return (jax.scipy.linalg.expm(1j * angle) + jax.scipy.linalg.expm(-1j * angle)) / 2


def loss(params):
    rnn = RNN(hidden_size=10, output_size=2)
    input = variables
    hidden_state = jnp.zeros((10,))
    output, _ = rnn.apply({'params': params}, input, hidden_state)
    params_out = output
    # Real part of the trace of cos(gamma * op + delta / 2).
    return jnp.real(jnp.trace(fun(params_out[0], params_out[1])))


if __name__ == "__main__":
    rng = jax.random.PRNGKey(0)
    rnn = RNN(hidden_size=10, output_size=2)
    input = variables
    hidden_state = jnp.zeros((10,))
    params = rnn.init(rng, input, hidden_state)['params']
    max_iter = 100
    learning_rate = 0.01
    convergence_threshold = 1e-6  # not used below
    optimized_params, num_iterations = _optimize(
        loss,
        params,
        max_iter,
        learning_rate,
    )
    final_loss = loss(optimized_params)
    print("Final Loss:", final_loss)
```
On a macOS system this outputs:
-1.9979573829398634
and on a Linux system:
-1.9979573808129485
The results differ in the last 8 digits, and I suspect that in a much larger, more complicated system this difference could become quite large.
In machine learning applications, such differences in numerical output can lead to convergence at different minima when the convergence is not straightforward.
Checklist
- [x] I've included a minimal example to reproduce the issue
- [ ] I'd be willing to make a PR to solve this issue
This looks a lot like you are running your calculations in single precision, which only gives 7-8 significant decimal digits.
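The two values reported above can be checked against this claim with a few lines of standard-library Python. The relative difference is about 1e-9. That is roughly five million times double-precision machine epsilon (about 2.2e-16), yet below single-precision machine epsilon (about 1.2e-7). This is compatible with a float32 quantity somewhere in the pipeline, although an amplified double-precision difference could also produce it. The `rel_tol` passed to `math.isclose` is an illustrative choice, not a tolerance taken from the thread.

```python
import math
import sys

# Final losses reported in this issue (macOS vs Linux).
macos = -1.9979573829398634
linux = -1.9979573808129485

abs_diff = abs(macos - linux)
rel_diff = abs_diff / abs(macos)

eps64 = sys.float_info.epsilon  # machine epsilon of IEEE double (2**-52)
eps32 = 2.0 ** -23              # machine epsilon of IEEE single

print(f"absolute difference: {abs_diff:.3e}")
print(f"relative difference: {rel_diff:.3e}")
print(f"in units of eps64:   {rel_diff / eps64:.3e}")
print(f"in units of eps32:   {rel_diff / eps32:.3e}")

# Cross-platform results are better compared with a tolerance than with ==.
# rel_tol=1e-8 is a hypothetical choice for illustration.
print(math.isclose(macos, linux, rel_tol=1e-8))
```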
> `jax.config.update("jax_enable_x64", True)`

I think this line makes it use double precision, doesn't it?
I do not know JAX, but https://github.com/jax-ml/jax/issues/22688#issuecomment-2253298306 leads me to believe that this setting only ensures that 64-bit values do not get truncated to 32 bits automatically. As far as I know, NumPy/SciPy work in 64-bit precision, so this may be a red herring, but perhaps some of your variables have the wrong dtype?
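To illustrate what a single wrong dtype does, suppose a parameter is stored in float32 and then used in an otherwise float64 computation. Its value then carries a relative rounding error of up to about 6e-8 before any arithmetic happens. The sketch below emulates float32 storage with the standard `struct` module; this is an assumption made to keep the example dependency-free, and in JAX one would instead inspect `.dtype` on the arrays. As far as we can tell from the Flax API, `nn.Dense` and `nn.GRUCell` default to `param_dtype=jnp.float32`, so the reproducer's parameters may be float32 even with `jax_enable_x64` set. That is worth checking rather than assuming.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (IEEE double) to the nearest IEEE single and back."""
    return struct.unpack("f", struct.pack("f", x))[0]


# Illustrative parameter value (0.1 is also the first entry of `variables` above).
param64 = 0.1
param32 = to_float32(param64)

# The same double-precision computation, fed the full value and the float32-stored one.
result64 = math.cos(3.0 * param64)
result32 = math.cos(3.0 * param32)

print(f"parameter rounding error: {abs(param32 - param64):.3e}")
print(f"result difference:        {abs(result32 - result64):.3e}")
```

The float64 arithmetic afterwards is exact to double precision, but it cannot recover digits that were lost when the parameter was stored.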
A different but related question is whether you are certain that your three installations all use Reference-LAPACK as the linear algebra backend, rather than Apple's Accelerate or OpenBLAS.
> A different but related question is whether you are certain that your three installations all use Reference-LAPACK as the linear algebra backend, rather than Apple's Accelerate or OpenBLAS.

No, actually I think one uses Apple's Accelerate and the other OpenBLAS.
> I do not know JAX, but https://github.com/jax-ml/jax/issues/22688#issuecomment-2253298306 leads me to believe that this setting only ensures that 64-bit values do not get truncated to 32 bits automatically. As far as I know, NumPy/SciPy work in 64-bit precision, so this may be a red herring, but perhaps some of your variables have the wrong dtype?

I can neither confirm nor deny this; I may need to review all my datatypes once more. However, since the code is exactly the same on every system, I was led to suspect that the issue may be due to the different LAPACK implementations on the different systems, which is what brought me here.
Either way, if some of my variables are indeed the wrong type, shouldn't they still lead to the same (wrong) numerical result on every system?
> No, actually I think one uses Apple's Accelerate and the other OpenBLAS.

Then I think it is technically not accurate to call this a bug in Reference-LAPACK; it is rather an effect of using optimized reimplementations of the functions defined by the reference implementation (in particular, of the underlying BLAS).
The Fortran source of the reference implementation is already subject to the optimization capabilities (and whims) of whichever compiler brand and version is used on each system. Beyond that, Apple M1/M2 hardware includes a poorly documented coprocessor for matrix operations. OpenBLAS mostly uses hand-coded assembly built on advanced instructions (such as fused multiply-add) that combine individual operations without the intermediate rounding and storing of a "naive" implementation. Either can lead to subtle differences in accuracy. This is where I think the section on permissible deviations in the LAPACK FAQ applies, although there it is only discussed in the narrow context of test suite errors: https://www.netlib.org/lapack/faq.html#_how_do_i_interpret_lapack_testing_failures
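The two mechanisms mentioned here can be shown in plain double-precision Python. Floating-point addition is not associative, so a BLAS that reorders a sum (for blocking or vectorization) may legitimately produce different last digits. A fused multiply-add rounds `a*b + c` once instead of twice. Python's `math` module only gained `fma` in 3.13, so this sketch emulates it with exact `fractions.Fraction` arithmetic followed by a single rounding. The operand values are illustrative choices.

```python
from fractions import Fraction

# 1) Non-associativity: the same three numbers summed in two orders.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, left == right)


def fma_emulated(a: float, b: float, c: float) -> float:
    """Compute a*b + c exactly, then round once to double (what an FMA unit does)."""
    return float(Fraction(a) * Fraction(b) + Fraction(c))


# 2) Fused vs. unfused multiply-add on exactly representable operands.
a = 1.0 + 2.0 ** -30
b = 1.0 - 2.0 ** -30
c = -1.0
unfused = a * b + c            # a*b = 1 - 2**-60 is rounded to 1.0 first, giving 0.0
fused = fma_emulated(a, b, c)  # single rounding keeps the exact value -2**-60
print(unfused, fused)
```

Both results are "correct" IEEE arithmetic. They differ only in how many intermediate roundings happen, which is exactly the freedom optimized BLAS implementations exercise.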
> I can neither confirm nor deny this; I may need to review all my datatypes once more. However, since the code is exactly the same on every system, I was led to suspect that the issue may be due to the different LAPACK implementations on the different systems, which is what brought me here.

If your code is correct, this could be a bug in either OpenBLAS or Accelerate, if either one uses single precision where it should not; but the Reference-LAPACK project has no control over those implementations. It might be helpful to see results with Reference-LAPACK on each system, but on the other hand that might just produce another three slightly different numbers :)
> Either way, if some of my variables are indeed the wrong type, shouldn't they still lead to the same (wrong) numerical result on every system?

I think that if a crucial variable in your code is inadvertently defined in single precision, the digits beyond the guaranteed 7th or 8th decimal are essentially undefined (even if they happen to match expectations).
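This also bears on the question of whether a wrong dtype should at least give the same wrong number everywhere. That only happens if every platform performs the same operations in the same order. The sketch below emulates float32 with `struct` (an assumption, for a dependency-free example). It accumulates the same three single-precision numbers in two orders and gets two different results, much as two BLAS libraries with different blocking or vectorization might.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sum_float32(values):
    """Accumulate sequentially, rounding the running sum to float32 after each step."""
    acc = 0.0
    for v in values:
        acc = to_float32(acc + to_float32(v))
    return acc


order_a = [1.0, 1e-8, -1.0]  # 1e-8 is below half an ulp of 1.0 in float32, so it is lost
order_b = [1.0, -1.0, 1e-8]  # cancellation happens first, so 1e-8 survives

print(sum_float32(order_a), sum_float32(order_b))
```

Each order is deterministic on its own, which matches the observation that every machine reproduces its own result. However, nothing forces two implementations to pick the same order.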
Okay, thank you for your insights. Since I am currently not sure this is a problem in Reference-LAPACK, I will close this issue.