[Bug] ReversibleSelect causes error when training model
Description
ReversibleSelect seems to mess up JAX's backtracing/JIT compilation. In the code provided below, we define a simple (non-reversible) model which (1) splits the inputs; (2) does something to one input; (3) swaps the two inputs; (4) merges the inputs. Step 3 uses ReversibleSelect. The resulting model can be initialized and called, but it causes errors when attempting to train it with TrainTask. Curiously enough:
- If one replaces ReversibleSelect by Select, there is no error.
- If one replaces ReversibleSelect by a pure function which manually swaps the inputs, there is no error. By "pure function" we mean something like tl.Fn("Swap", lambda a, b: (b, a), n_out=2). (Both of these non-failing variants are sketched right after the reproduction code below.)
- If one interchanges steps 2 and 3, i.e., swaps the two inputs and then does something to one of them, there is no error.
Finally, in case this is important, the code was run on a machine without a GPU or TPU.
Environment information
OS: Linux 4.9.0-13-amd64 #1 SMP Debian 4.9.228-1 (2020-07-05) x86_64 GNU/Linux
$ pip freeze | grep trax
trax==1.3.6
$ pip freeze | grep tensor
mesh-tensorflow==0.1.16
tensor2tensor==1.15.7
tensorboard==2.3.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.0
tensorflow-addons==0.11.2
tensorflow-data-validation==0.22.2
tensorflow-datasets==2.0.0
tensorflow-enterprise-addons @ file:///opt/conda/conda-bld/dlenv-tf-2-1-cpu_1598328292311/work/tensorflow_enterprise_addons-0.0.0-py3-none-any.whl
tensorflow-estimator==2.3.0
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-io==0.11.0
tensorflow-metadata==0.22.2
tensorflow-model-analysis==0.22.2
tensorflow-probability==0.7.0
tensorflow-serving-api==2.1.0
tensorflow-text==2.3.0
tensorflow-transform==0.22.0
$ pip freeze | grep jax
jax==0.1.75
jaxlib==0.1.52
$ python -V
Python 3.7.8
For bugs: reproduction and error logs
To reproduce, run the following code:
import trax
import trax.data as td
import trax.layers as tl
import trax.supervised.training as tt
from trax.fastmath import numpy as jnp
trax.fastmath.use_backend('jax')
split_stack = tl.Fn("Split Stack", lambda x: jnp.split(x, 2), n_out=2)
merge_stack = tl.Fn("Merge Stack", lambda x1, x2: jnp.concatenate([x1, x2]), n_out=1)
inputs_size = 20
def input_stream0(_=None):
    while True:
        yield (jnp.zeros((inputs_size,)), jnp.zeros((inputs_size,)))
# If one replaces ReversibleSelect by Select, or by a manual swap of inputs, everything works!
model = tl.Serial(split_stack, tl.Dense(inputs_size//2), tl.ReversibleSelect([1, 0]), merge_stack)
in_stream = lambda: td.Serial(input_stream0, td.AddLossWeights())()
train_task = tt.TrainTask(
    labeled_data = in_stream(),
    loss_layer = tl.L2Loss(),
    optimizer = trax.optimizers.Adam(0.01))
training_loop = tt.Loop(model, train_task)
training_loop.run(1)
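For reference, here is a minimal sketch of the two non-failing variants described above (model_select, model_swap and swap are illustrative names, not from the original report; split_stack, merge_stack and inputs_size are reused from the snippet). The error log below comes from the ReversibleSelect version.

# Variant 1: plain Select instead of ReversibleSelect -- reportedly trains without error.
model_select = tl.Serial(split_stack, tl.Dense(inputs_size // 2), tl.Select([1, 0]), merge_stack)

# Variant 2: a pure function that swaps the two stack items -- also reportedly fine.
swap = tl.Fn("Swap", lambda a, b: (b, a), n_out=2)
model_swap = tl.Serial(split_stack, tl.Dense(inputs_size // 2), swap, merge_stack)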
Error logs:
Below is the full error log. I don't think it's important, but in case it is, there is also a warning about no GPU/TPU, some information about tensorflow when trax is first loaded, and a warning about the missing output_dir parameter in Loop.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-1-2a3851881ca5> in <module>
24
25 training_loop = tt.Loop(model, train_task)
---> 26 training_loop.run(1)
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in run(self, n_steps)
337 task_index = self._which_task(self._step)
338 task_changed = task_index != prev_task_index
--> 339 loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
340
341 # optimizer_metrics and loss are replicated on self.n_devices, a few
/opt/conda/lib/python3.7/site-packages/trax/supervised/training.py in _run_one_step(self, task_index, task_changed)
450 trainer.accelerated_loss_layer.replicate_weights(model.weights)
451 trainer.accelerated_loss_layer.replicate_state(model.state)
--> 452 return trainer.one_step(batch, rng, step=step, learning_rate=learning_rate)
453
454 def _log_training_progress(self, task, total_loss, n_steps, elapsed_time,
/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in one_step(self, batch, rng, step, learning_rate)
129 # NOTE: stats is a replicated dictionary of key to jnp arrays.
130 (new_weights, new_slots), new_state, stats = self._accelerated_update_fn(
--> 131 (weights, self._slots), step, self._opt_params, batch, state, rng)
132
133 if logging.vlog_is_on(1) and ((step & step - 1) == 0):
/opt/conda/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
169 flat_fun, out_tree = flatten_fun(f, in_tree)
170 out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
--> 171 name=flat_fun.__name__, donated_invars=donated_invars)
172 return tree_unflatten(out_tree(), out)
173
/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
1132
1133 def bind(self, fun, *args, **params):
-> 1134 return call_bind(self, fun, *args, **params)
1135
1136 def process(self, trace, fun, tracers, params):
/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1121 if top_trace is None:
1122 with new_sublevel():
-> 1123 outs = primitive.impl(fun, *args, **params)
1124 else:
1125 tracers = map(top_trace.full_raise, args)
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
525 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
526 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 527 *unsafe_map(arg_spec, args))
528 try:
529 return compiled_fun(*args)
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
222 fun.populate_stores(stores)
223 else:
--> 224 ans = call(fun, *args)
225 cache[key] = (ans, fun.stores)
226 return ans
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
596 pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
597 jaxpr, pvals, consts = pe.trace_to_jaxpr(
--> 598 fun, pvals, instantiate=False, stage_out=True, bottom=True)
599 map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
600 jaxpr = apply_outfeed_rewriter(jaxpr)
/opt/conda/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
421 with core.new_master(trace_type, bottom=bottom) as master:
422 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 423 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
424 assert not env
425 del master
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
148 gen = None
149
--> 150 ans = self.f(*args, **dict(self.params, **kwargs))
151 del args
152 while stack:
/opt/conda/lib/python3.7/site-packages/trax/optimizers/trainer.py in single_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng)
173 weights, slots = weights_and_slots
174 (loss, state), gradients = forward_and_backward_fn(
--> 175 batch, weights, state, rng)
176 weights, slots, stats = optimizer.tree_update(
177 step, gradients, weights, slots, opt_params)
/opt/conda/lib/python3.7/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
491 dtype = dtypes.result_type(ans)
492 tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
--> 493 g = vjp_py(np.ones((), dtype=dtype))
494 g = g[0] if isinstance(argnums, int) else g
495 if not has_aux:
/opt/conda/lib/python3.7/site-packages/jax/api.py in _vjp_pullback_wrapper(cotangent_dtypes, io_tree, fun, py_args)
1458 "match type of corresponding primal output ({})")
1459 raise TypeError(msg.format(_dtype(a), dtype))
-> 1460 ans = fun(*args)
1461 return tree_unflatten(out_tree, ans)
1462
/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in unbound_vjp(pvals, jaxpr, consts, *cts)
115 cts = tuple(map(ignore_consts, cts, pvals))
116 dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
--> 117 arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
118 return map(instantiate_zeros, arg_cts)
119
/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in backward_pass(jaxpr, consts, primals_in, cotangents_in)
204 else:
205 cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
--> 206 **eqn.params)
207 cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
208 # FIXME: Some invars correspond to primals!
/opt/conda/lib/python3.7/site-packages/jax/interpreters/ad.py in _custom_lin_transpose(cts_out, num_res, bwd, avals_out, *invals)
609 res, _ = split_list(invals, [num_res])
610 cts_out = map(instantiate_zeros_aval, avals_out, cts_out)
--> 611 cts_in = bwd.call_wrapped(*res, *cts_out)
612 cts_in_flat, _ = tree_flatten(cts_in) # already checked tree structure
613 return [None] * num_res + cts_in_flat
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(***failed resolving arguments***)
152 while stack:
153 gen, out_store = stack.pop()
--> 154 ans = gen.send(ans)
155 if out_store is not None:
156 ans, side = ans
/opt/conda/lib/python3.7/site-packages/jax/custom_derivatives.py in _flatten_bwd(in_tree, out_trees, *args)
510 "number of arguments to the primal function, but got VJP output "
511 "structure {} for primal input structure {}.")
--> 512 raise TypeError(msg.format(in_tree2, in_tree)) from None
513 yield cts_in
514
TypeError: Custom VJP rule must produce an output with the same container (pytree) structure as the args tuple of the primal function, and in particular must produce a tuple of length equal to the number of arguments to the primal function, but got VJP output structure PyTreeDef(tuple, [PyTreeDef(None, []),PyTreeDef(None, []),PyTreeDef(tuple, [*,*]),PyTreeDef(tuple, [])]) for primal input structure PyTreeDef(tuple, [PyTreeDef(tuple, []),*,PyTreeDef(tuple, [*,*]),PyTreeDef(tuple, [])]).
Update
It appears that any model which contains a ReversibleLayer (or a subclass of it) will produce such an error. The following code produces the same errors. The precise error depends on whether:
- ...my class DoSomething is declared a subclass of tl.Layer or of tl.ReversibleLayer;
- ...my class PureReversible is declared a subclass of tl.Layer or of tl.ReversibleLayer;
- ...the whole model is assembled with tl.Serial or with tl.ReversibleSerial.

If all three use the non-reversible options, everything works. If any of the three uses a reversible option, it produces an error similar to the above.
import trax
import trax.data as td
import trax.layers as tl
import trax.supervised.training as tt
from trax.fastmath import numpy as jnp
class DoSomething(tl.ReversibleLayer):
    def __init__(self):
        super().__init__(n_in=2, n_out=2)
        self.l = tl.Dense(10)
        self._sublayers = (self.l,)

    def forward(self, x):
        x1, x2 = x
        return x1 + self.l(x2), x2

    def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
        y1, y2 = output
        return y1 - self.l(y2), y2

    def init_weights_and_state(self, sig):
        w, s = self.l.init(sig[0])
        self.weights = (w,)
        self.state = (s,)

class PureReversible(tl.ReversibleLayer):
    def __init__(self, forw, backw, n_in, n_out):
        self.forw = forw
        self.backw = backw
        super().__init__(n_in=n_in, n_out=n_out)

    def forward(self, x):
        return self.forw(x)

    def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
        return self.backw(output)
split = tl.Fn("Split Stack", lambda x: jnp.split(x, 2), n_out=2)
merge = tl.Fn("Merge Stack", lambda x1, x2: jnp.concatenate([x1, x2]), n_out=1)
rev_split = PureReversible(split, merge, 1, 2)
rev_merge = PureReversible(merge, split, 2, 1)
def input_stream0(_=None):
    while True:
        yield (jnp.zeros((20,)), jnp.zeros((20,)))
model = tl.ReversibleSerial(rev_split, DoSomething(), rev_merge)
train_task = tt.TrainTask(
    labeled_data = td.Serial(input_stream0, td.AddLossWeights())(),
    loss_layer = tl.L2Loss(),
    optimizer = trax.optimizers.Adam(0.01))
training_loop = tt.Loop(model, train_task)
training_loop.run(1)
If there is any known workaround that does not involve completely dropping reversible layers, we would really appreciate a temporary solution, as we depend on the memory savings afforded by reversible networks. Thank you!
Did you try to install Trax from the master branch?
No, I got it from pip. However, in the end I managed to make it work by passing use_memory_efficient_trainer=True as an argument to Loop. (I'm not sure if this issue should be closed, so I'm leaving it as is, but I'm satisfied with the solution I found.)
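For reference, a minimal sketch of that workaround applied to the original reproduction (assuming the same model and train_task definitions as in the first snippet; the keyword is passed straight to Loop as described in the comment above):

# Workaround reported above: enable the memory-efficient trainer in Loop.
training_loop = tt.Loop(model, train_task, use_memory_efficient_trainer=True)
training_loop.run(1)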