
'Pruning in Keras' Example: Unable to fine-tune pruned model on GPU (TF 2.9.2)

Open swapnilsayansaha opened this issue 3 years ago • 5 comments


Describe the bug Most likely, sparse operators such as PruneLowMagnitude cannot be loaded and run on a GPU.

System information

TensorFlow version (installed from source or binary): Binary, 2.9.2, CUDA: 11.6

GPU: Nvidia RTX 3090 24 GB

OS: Ubuntu 20.04

TensorFlow Model Optimization version (installed from source or binary): Binary, 0.7.3

Python version: 3.8

Describe the expected behavior and the current behavior The issue is also described here: https://github.com/tensorflow/tensorflow/issues/58499

The example code for pruning in this guide should work out of the box: https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras

However, fine-tuning the pruned model does not work on the GPU. As a workaround, I forced the fine-tuning of the pruned model onto the CPU:

with tf.device('/cpu:0'):
    model_for_pruning.fit(train_images, train_labels,
                          batch_size=batch_size, epochs=epochs,
                          validation_split=validation_split,
                          callbacks=callbacks)

The unpruned model trains fine on the GPU, so this is not a problem with the CUDA drivers. Please do not suggest setting up a new conda/venv environment.

The following error occurs when the with tf.device('/cpu:0'): wrapper is removed:

UnknownError                              Traceback (most recent call last)
Cell In [5], line 8
      1 logdir = tempfile.mkdtemp()
      3 callbacks = [
      4   tfmot.sparsity.keras.UpdatePruningStep(),
      5   tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
      6 ]
----> 8 model_for_pruning.fit(train_images, train_labels,
      9                   batch_size=batch_size, epochs=epochs, validation_split=validation_split,
     10                   callbacks=callbacks)

File ~/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py:67, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     65 except Exception as e:  # pylint: disable=broad-except
     66   filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67   raise e.with_traceback(filtered_tb) from None
     68 finally:
     69   del filtered_tb

File ~/swapnil_debug_2/lib/python3.8/site-packages/tensorflow/python/eager/execute.py:54, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     52 try:
     53   ctx.ensure_initialized()
---> 54   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     55                                       inputs, attrs, num_outputs)
     56 except core._NotOkStatusException as e:
     57   if name is not None:

UnknownError: Graph execution error:

Detected at node 'sequential/prune_low_magnitude_conv2d/FloorMod' defined at (most recent call last):
    File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/traitlets/config/application.py", line 982, in launch_instance
      app.start()
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 712, in start
      self.io_loop.start()
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
      self._run_once()
    File "/usr/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
      handle._run()
    File "/usr/lib/python3.8/asyncio/events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 730, in execute_request
      reply_content = await reply_content
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 383, in do_execute
      res = shell.run_cell(
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2940, in run_cell
      result = self._run_cell(
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2995, in _run_cell
      return runner(coro)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3194, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3373, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_2426111/471826281.py", line 8, in <module>
      model_for_pruning.fit(train_images, train_labels,
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 1409, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 1051, in train_function
      return step_function(self, iterator)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 1040, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 1030, in run_step
      outputs = model.train_step(data)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 889, in train_step
      y_pred = self(x, training=True)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/training.py", line 490, in __call__
      return super().__call__(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/sequential.py", line 374, in call
      return super(Sequential, self).call(inputs, training=training, mask=mask)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/functional.py", line 458, in call
      return self._run_internal_graph(
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/functional.py", line 596, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 280, in call
      update_mask = utils.smart_cond(training, add_update, no_op)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 50, in smart_cond
      if isinstance(pred, variables.Variable):
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 54, in smart_cond
      pred, true_fn=true_fn, false_fn=false_fn, name=name)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 268, in add_update
      with tf.control_dependencies(
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 310, in conditional_mask_update
      return tf.distribute.get_replica_context().merge_call(
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 307, in mask_update_distributed
      return tf.cond(maybe_update_masks(), update_distributed, no_update)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 260, in maybe_update_masks
      if self._sparsity_m_by_n:
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py", line 264, in maybe_update_masks
      return self._pruning_schedule(self._step_fn())[0]
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py", line 246, in __call__
      sparsity)
    File "/home/nesl/swapnil_debug_2/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule.py", line 61, in _should_prune_in_step
      is_pruning_turn = tf.math.equal(
Node: 'sequential/prune_low_magnitude_conv2d/FloorMod'
JIT compilation failed.
	 [[{{node sequential/prune_low_magnitude_conv2d/FloorMod}}]] [Op:__inference_train_function_34086]

Code to reproduce the issue https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras


Additional context The problem isn't too serious: I can train the unpruned model on the GPU (for example, for 200 epochs), save its weights, load them, add the code needed to prune the model, and then fine-tune it on the CPU (for example, for 10 epochs). However, it's worth looking into why the fine-tuning cannot happen on the GPU.

swapnilsayansaha avatar Nov 10 '22 16:11 swapnilsayansaha

Hi, this looks like an issue with the FloorMod op on GPU rather than with the pruning API. It is odd, since a similar bug was fixed a year ago: https://github.com/tensorflow/tensorflow/issues/46887

Could you double-check your TensorFlow version? If the bug still exists in a recent TensorFlow version, we may need to reopen the above issue.
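
For context, the failing FloorMod node sits in the pruning schedule's check of whether the current training step is a pruning step (the _should_prune_in_step frame in the traceback). Below is a plain-Python sketch of that kind of check, assuming the usual begin_step / end_step / frequency semantics with end_step == -1 meaning "no end". The function is illustrative only and is not the library's code:

```python
def should_prune_in_step(step: int, begin_step: int, end_step: int, frequency: int) -> bool:
    """Illustrative stand-in for a pruning-schedule check (not the tfmot source)."""
    # The step must lie inside [begin_step, end_step]; -1 is assumed to mean "no end".
    in_pruning_range = step >= begin_step and (end_step == -1 or step <= end_step)
    # Python's % on integers is a floor modulo, like tf.math.floormod.
    is_pruning_turn = (step - begin_step) % frequency == 0
    return in_pruning_range and is_pruning_turn


if __name__ == "__main__":
    for step in (0, 50, 100, 200):
        print(step, should_prune_in_step(step, begin_step=0, end_step=100, frequency=100))
```

The modulo is the only arithmetic in this check that maps to FloorMod, which is consistent with the error pointing at that single node rather than at the pruning masks themselves.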

rino20 avatar Nov 11 '22 15:11 rino20

The TF version is 2.9.2 (GPU).

swapnilsayansaha avatar Nov 11 '22 16:11 swapnilsayansaha

Similar bug with FloorMod on Windows 10 with TF 2.10.0.

back2yes avatar Jul 31 '23 14:07 back2yes

Having the same problem on an RTX 3090 with TensorFlow 2.10. I can't even run PQAT because of this issue with pruning on the GPU.

puelon avatar Sep 01 '23 13:09 puelon

I had the same problem. I ran the following and got an error saying that libdevice.10.bc was not found.

import tensorflow as tf

# jit_compile=True forces XLA compilation of this function.
@tf.function(jit_compile=True)
def floormod(a, b):
  return tf.math.floormod(a, b)

floormod(tf.constant(1.), tf.constant(1.))
tensorflow.python.framework.errors_impl.InternalError: libdevice not found at ./libdevice.10.bc [Op:__inference_floormod_49]

I added the following to the top of the program and it worked.

import os
os.environ["XLA_FLAGS"]='--xla_gpu_cuda_data_dir=/path/to/cuda'

I hope this helps.
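
If the right value for /path/to/cuda is not obvious, a small helper can look for it. The sketch below assumes that XLA expects libdevice at nvvm/libdevice/libdevice.10.bc under the CUDA data directory, and that the candidate locations (CUDA_HOME, CONDA_PREFIX, /usr/local/cuda, /usr/lib/cuda) are plausible on your machine. The function names are hypothetical, and the helper should run before TensorFlow is imported:

```python
import os
from pathlib import Path
from typing import Iterable, Optional

# Assumed layout: <cuda_data_dir>/nvvm/libdevice/libdevice.10.bc
LIBDEVICE_RELATIVE = Path("nvvm") / "libdevice" / "libdevice.10.bc"


def find_cuda_data_dir(candidates: Iterable[str]) -> Optional[str]:
    """Return the first candidate directory that contains libdevice, else None."""
    for candidate in candidates:
        if candidate and (Path(candidate) / LIBDEVICE_RELATIVE).is_file():
            return str(candidate)
    return None


def default_candidates() -> list:
    """Guess a few common CUDA locations (assumptions, adjust for your system)."""
    return [
        os.environ.get("CUDA_HOME", ""),
        os.environ.get("CONDA_PREFIX", ""),
        "/usr/local/cuda",
        "/usr/lib/cuda",
    ]


def configure_xla_flags(candidates: Optional[Iterable[str]] = None) -> Optional[str]:
    """Append --xla_gpu_cuda_data_dir to XLA_FLAGS if a CUDA directory is found."""
    search = default_candidates() if candidates is None else candidates
    cuda_dir = find_cuda_data_dir(search)
    if cuda_dir is not None:
        existing = os.environ.get("XLA_FLAGS", "")
        os.environ["XLA_FLAGS"] = f"{existing} --xla_gpu_cuda_data_dir={cuda_dir}".strip()
    return cuda_dir


if __name__ == "__main__":
    print("CUDA data dir:", configure_xla_flags())
```

Calling configure_xla_flags() at the very top of the script, before import tensorflow, has the same effect as the two lines above once a matching directory is found. If it returns None, none of the guessed locations contained libdevice and the path has to be set by hand.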

gyojir avatar Sep 07 '23 20:09 gyojir