
RuntimeError when running baselines/imagenet/sngp.py

Open batzner opened this issue 4 years ago • 5 comments

Dear uncertainty-baselines authors,

I am trying to run the SNGP training on ImageNet using uncertainty-baselines/baselines/imagenet/sngp.py.

It fails during the first training step with the following error:

 RuntimeError: `merge_call` called while defining a new graph or a tf.function.
This can often happen if the function `fn` passed to `strategy.run()` 
contains a nested `@tf.function`, and the nested `@tf.function` contains 
a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients),
or if the function `fn` uses a control flow statement which contains a synchronization 
point in the body. Such behaviors are not yet supported. Instead, please avoid 
nested `tf.function`s or control flow statements that may potentially cross a
synchronization boundary, for example, wrap the `fn` passed to `strategy.run` 
or the entire `strategy.run` inside a `tf.function` or move the control flow out of `fn`

This is the stack trace:

RuntimeError: in user code:

    .../lib/uncertainty-baselines/baselines/imagenet/sngp_tmp.py:290 step_fn  *
        model.layers[-1].reset_covariance_matrix()
    ../edward2/edward2/tensorflow/layers/random_feature.py:219 reset_covariance_matrix  *
        self._gp_cov_layer.reset_precision_matrix()
    ../edward2/edward2/tensorflow/layers/random_feature.py:363 reset_precision_matrix  *
        precision_matrix_reset_op = self.precision_matrix.assign(
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:685 assign  **
        return values_util.on_write_assign(self, value, use_locking=use_locking,
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/values_util.py:33 on_write_assign
        return var._update(  # pylint: disable=protected-access
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:827 _update
        return self._update_replica(update_fn, value, **kwargs)
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:897 _update_replica
        return _on_write_update_replica(self, update_fn, value, **kwargs)
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:71 _on_write_update_replica
        return ds_context.get_replica_context().merge_call(
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2715 merge_call
        return self._merge_call(merge_fn, args, kwargs)
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_run.py:432 _merge_call
        raise RuntimeError(

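It seems that the self.precision_matrix.assign call in edward2/edward2/tensorflow/layers/random_feature.py causes this error. The call runs in replica context (inside the strategy.run call of a tf.function), and assigning to the mirrored variable there triggers a merge_call, as the last frames of the trace show.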
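A minimal sketch of the workaround the error text itself suggests: keep unconditional variable updates inside `strategy.run`, and move the reset out of `fn` into cross-replica context, here once per epoch in the Python loop. It assumes the reset in step_fn sits under a tensor-dependent condition (for example an epoch-boundary check) that AutoGraph turns into control flow; the traceback alone does not confirm this. `precision`, `train_step`, `reset_precision`, `DIM` and the update rule are hypothetical stand-ins for the SNGP layer, not edward2 or uncertainty-baselines code. In the real script the outer-loop call would presumably be `model.layers[-1].reset_covariance_matrix()`.

```python
import tensorflow as tf

DIM = 4  # hypothetical GP feature dimension

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Hypothetical stand-in for the GP precision matrix: a mirrored,
    # non-trainable variable with ONLY_FIRST_REPLICA aggregation, matching
    # the add_weight(...) arguments quoted from random_feature.py.
    precision = tf.Variable(
        tf.eye(DIM),
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)


@tf.function
def train_step(features):
    def step_fn(features):
        # Unconditional update in replica context: the merge_call behind this
        # assign happens at the top level of the tf.function, which is allowed.
        precision.assign_add(tf.matmul(features, features, transpose_a=True))
    strategy.run(step_fn, args=(features,))


@tf.function
def reset_precision():
    # Cross-replica context (outside strategy.run): the assign updates every
    # mirrored copy directly, without going through merge_call.
    precision.assign(tf.eye(DIM))


features = tf.ones([2, DIM])
for epoch in range(2):
    # Reset at the start of each epoch from the outer loop, not from step_fn.
    reset_precision()
    for _ in range(3):
        train_step(features)

# Diagonal entry: 1 from the identity plus 3 steps * 2 = 7.0
print(precision.read_value().numpy()[0, 0])
```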

What can I do to fix this?

batzner, Apr 15 '21 11:04

@jereliu

Hi @batzner! Were you able to resolve this?

dustinvtran, Apr 28 '21 18:04

Unfortunately, I was not able to resolve this. In GitHub issues on other repositories, people who hit the same error resolved it by passing synchronization=tf.VariableSynchronization.ON_READ to the self.add_weight(...) call (quoted below).

However, I am not sure whether on-read synchronization is the correct behavior here.

self.precision_matrix = (
    self.add_weight(
        name='gp_precision_matrix',
        shape=(gp_feature_dim, gp_feature_dim),
        dtype=self.dtype,
        initializer=tf.keras.initializers.Identity(self.ridge_penalty),
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA))

(from https://github.com/google/edward2/blob/807bd74d93c607a5a4030c4ef7debecf89f8b6ab/edward2/tensorflow/layers/random_feature.py#L337)
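Whether ON_READ is correct depends on what it changes. A minimal sketch, assuming two logical CPU devices stand in for two GPUs and a hypothetical scalar variable `on_read` stands in for the precision matrix: with ON_READ, an assign inside `strategy.run` writes only the local replica's copy (so no `merge_call` is needed), and a cross-replica read with ONLY_FIRST_REPLICA returns replica 0's copy. With the current ON_WRITE (mirrored) variable, by contrast, replica 0's update is copied to every replica.

```python
import tensorflow as tf

# Split the host CPU into two logical devices so MirroredStrategy runs two
# replicas without GPUs. This must happen before TF initializes its devices.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu,
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()])
devices = [d.name for d in tf.config.list_logical_devices("CPU")]
strategy = tf.distribute.MirroredStrategy(devices)

with strategy.scope():
    # Hypothetical toy variable using the synchronization suggested above.
    on_read = tf.Variable(
        0.0,
        trainable=False,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)


@tf.function
def step():
    def fn():
        # Each replica writes a different value (replica id + 1), standing in
        # for a precision update computed from that replica's own batch.
        ctx = tf.distribute.get_replica_context()
        replica_id = tf.cast(ctx.replica_id_in_sync_group, tf.float32)
        on_read.assign(replica_id + 1.0)  # local write, no merge_call
    strategy.run(fn)


step()
# Per-replica copies: each replica kept the value it wrote.
per_replica = [v.numpy() for v in strategy.experimental_local_results(on_read)]
print(per_replica)
# Cross-replica read: ONLY_FIRST_REPLICA returns replica 0's copy.
print(on_read.read_value().numpy())
```

If each replica updates the precision matrix from its own batch, the ON_READ copies would drift apart, so per-replica code (such as evaluation inside strategy.run) would presumably see a different covariance on each replica, while cross-replica reads would see only replica 0's.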

batzner, Apr 28 '21 19:04

Hey! Sorry for the delay.

What command did you use to run the code, and which GPU, CUDA version, and TF version were you using? Also, can you reproduce this at the current HEAD?

znado, May 11 '21 00:05

I get the exact same error with a fresh install (TF nightly from 9/7/21) and CUDA 11.2.

gpleiss, Sep 07 '21 14:09

I also get this error with TF 2.7.

psiden, Jan 11 '22 12:01