
[BUG] in `runtime/zero/partitioned_param_coordinator.py` -- Parameter not in `self.__inflight_param_registry` but has ds_status = `ZeroParamStatus.INFLIGHT`

Open ckgresla opened this issue 10 months ago • 2 comments

Describe the bug The assertion at L316 of partitioned_param_coordinator.py is raised for parameters that are marked ZeroParamStatus.INFLIGHT but are not members of the PartitionedParameterCoordinator's self.__inflight_param_registry. The assertion fires and causes the training job to fail on the first pass over the validation split of the script, i.e. when the model, set to .eval() mode, forwards a batch inside a with torch.no_grad() context.

I tried running my fine-tuning script with this assertion commented out, and it seemed to go without a hitch -- is there a reason why params here might be set to INFLIGHT? Happy to provide more context as needed.

Additionally, holding all else constant, I experimented with setting overlap_comm and pin_memory to false in the related ds_config; both variants raised the same assertion error.

Expected behavior Parameters that are in flight should be awaited (transitioning them to AVAILABLE) before being used.
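The expected state machine can be sketched with a toy model (a minimal sketch; the class and method names here are illustrative stand-ins, not DeepSpeed's actual implementation): every parameter marked INFLIGHT should have an entry in the inflight registry so the fetch path can wait on it before asserting availability. The bug scenario is a param that is INFLIGHT with no registry entry, so there is nothing to wait on and the assertion fires.

```python
from dataclasses import dataclass, field
from enum import Enum


class ZeroParamStatus(Enum):
    NOT_AVAILABLE = 0
    INFLIGHT = 1
    AVAILABLE = 2


@dataclass
class Param:
    ds_id: int
    ds_status: ZeroParamStatus = ZeroParamStatus.NOT_AVAILABLE


@dataclass
class ToyCoordinator:
    """Toy stand-in for the coordinator's inflight bookkeeping."""
    inflight: dict = field(default_factory=dict)  # ds_id -> Param

    def launch_fetch(self, param: Param) -> None:
        # An all-gather is "launched": the param becomes INFLIGHT and is
        # recorded so a later fetch can wait on it.
        param.ds_status = ZeroParamStatus.INFLIGHT
        self.inflight[param.ds_id] = param

    def fetch(self, param: Param) -> None:
        # Wait on any recorded in-flight gather for this param...
        if param.ds_id in self.inflight:
            self.inflight.pop(param.ds_id)
            param.ds_status = ZeroParamStatus.AVAILABLE
        # ...after which the availability assertion should always hold.
        # An INFLIGHT param with no registry entry trips it, as reported.
        assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_status


coord = ToyCoordinator()
p = Param(ds_id=0)
coord.launch_fetch(p)
coord.fetch(p)  # passes: the INFLIGHT param was in the registry

orphan = Param(ds_id=1, ds_status=ZeroParamStatus.INFLIGHT)  # bug scenario
try:
    coord.fetch(orphan)
    reproduced = False
except AssertionError:
    reproduced = True
```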

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/torch']
torch version .................... 2.1.0+cu121
deepspeed install path ........... ['/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.14.0, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.2
deepspeed wheel compiled w. ...... torch 2.1, cuda 12.1
shared memory (/dev/shm) size .... 125.77 GB

System info:

  • OS: Ubuntu 22.04.4 LTS
  • 2x RTX A6000s (~48GB of VRAM per card)
  • Python 3.10 (per the conda env paths above)

Launcher context Using mpirun to launch the job on a single node.

Docker context N/A

Additional context The ds_config used:

{
    "activation_checkpointing": {
        "partition_activations": true,
        "cpu_checkpointing": true
    },
    "fp16": {
        "enabled": true,
        "autocast": false,
        "initial_scale_power": 8
    },
    "gradient_accumulation_steps": 4,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.0003,
            "betas": [0.8, 0.999],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    },
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "stage3_gather_16bit_weights_on_model_save": true
    }
}
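The overlap_comm/pin_memory experiment mentioned above can be expressed as a small helper (a sketch; the abbreviated `base` dict below reproduces only the zero_optimization section of the config, and `variant` is a hypothetical helper name):

```python
import copy

# the zero_optimization section of the config above, abbreviated to the
# fields relevant to the experiment
base = {
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "stage3_gather_16bit_weights_on_model_save": True,
    }
}


def variant(cfg: dict, overlap_comm: bool, pin_memory: bool) -> dict:
    """Return a deep copy of the config with the two suspect flags toggled."""
    out = copy.deepcopy(cfg)
    out["zero_optimization"]["overlap_comm"] = overlap_comm
    out["zero_optimization"]["offload_optimizer"]["pin_memory"] = pin_memory
    return out


# the variant tried in the report; it raised the same assertion
no_overlap = variant(base, overlap_comm=False, pin_memory=False)
```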

Stacktrace

(Both ranks raised the same assertion; their identical tracebacks were printed interleaved, shown de-interleaved below.)

Traceback (most recent call last):
  File "/opt/miniconda3/envs/ckg/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/miniconda3/envs/ckg/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/mpi4py/__main__.py", line 7, in <module>
    main()
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/mpi4py/run.py", line 230, in main
    run_command_line(args)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/mpi4py/run.py", line 47, in run_command_line
    run_path(sys.argv[0], run_name='__main__')
  File "/opt/miniconda3/envs/ckg/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/opt/miniconda3/envs/ckg/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/opt/miniconda3/envs/ckg/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "ds_fine_tune.py", line 690, in <module>
    main(args)
  File "ds_fine_tune.py", line 513, in main
    logits = forward_batch(engine, batch, eval_mode=True)
  File "ds_fine_tune.py", line 190, in forward_batch
    logits = engine(input_ids, attention_mask)["logits"]
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1852, in forward
    loss = self.module(*inputs, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1054, in forward
    outputs = self.model(
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 893, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1557, in _call_impl
    args_result = hook(self, args)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 278, in _pre_forward_module_hook
    self.pre_sub_module_forward_function(module)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 452, in pre_sub_module_forward_function
    param_coordinator.fetch_sub_module(sub_module, forward=True)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/miniconda3/envs/ckg/lib/python3.10/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 320, in fetch_sub_module
    assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
AssertionError: {'id': 0, 'status': 'INFLIGHT', 'numel': 131072000, 'ds_numel': 131072000, 'shape': (32000, 4096), 'ds_shape': (32000, 4096), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': {2}, 'ds_tensor.shape': torch.Size([65536000])}

Additional prints at the time the stacktrace is raised:

param in the inflight registry? := False
abouta assert, the param.ds_status is:  ZeroParamStatus.INFLIGHT
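A self-contained sketch of instrumentation that would produce the two lines above (hypothetical; in the real code the registry is the name-mangled self.__inflight_param_registry and the status comes from param.ds_status, neither of which is reproduced here):

```python
from enum import Enum


class ZeroParamStatus(Enum):
    NOT_AVAILABLE = 0
    INFLIGHT = 1
    AVAILABLE = 2


def debug_param_state(param_status, inflight_registry, param_id) -> str:
    """Format the two diagnostic lines printed just before the assertion."""
    lines = [
        f"param in the inflight registry? := {param_id in inflight_registry}",
        f"abouta assert, the param.ds_status is:  {param_status}",
    ]
    return "\n".join(lines)


# the failing state from the report: INFLIGHT, yet absent from the registry
msg = debug_param_state(ZeroParamStatus.INFLIGHT, {}, 0)
print(msg)
```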

ckgresla avatar Apr 08 '24 19:04 ckgresla

Still running into this issue. I would be happy to assist in resolving/debugging it -- any chance we could get some assistance from the team to help look into it?

ckgresla avatar Jul 25 '24 16:07 ckgresla