DeepSpeedExamples
DeepSpeed-Chat: prefetch of layers during reward model forward pass leads to error during sample generation
When running step 3 with ZERO stage 3 enabled for both the actor and critic models, I get the following error (line numbers may be offset due to debug statements I've added):
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 529, in <module>
main()
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 438, in main
out = trainer.generate_experience(prompts)
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 103, in generate_experience
seq = self._generate_sequence(prompts)
File "/path/site-packages/deepspeed/runtime/hybrid_engine.py", line 293, in generate
seq = self.actor_model.module.generate(prompts,
File "/path/site-packages/deepspeed/runtime/hybrid_engine.py", line 293, in generate
self.fuse_lora_weight()
File "/path/site-packages/deepspeed/runtime/hybrid_engine.py", line 128, in _fuse_lora
weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
RuntimeError: The size of tensor a (8192) must match the size of tensor b (2048) at non-singleton dimension 1
This happens because the weight.data shape does not match the shape of the tensor produced by the LoRA matmul operation.
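For illustration, here is a minimal standalone reproduction of this kind of mismatch (shapes borrowed from the error message and --actor_lora_dim 128; the pairing of weight and LoRA factors is purely illustrative, not the actual hybrid-engine state):

import torch

lora_scaling = 1.0
weight = torch.zeros(2048, 8192)             # an fc2-style (out=2048, in=8192) weight
lora_left_weight = torch.randn(128, 8192)    # LoRA factors sized for an fc1-style (8192, 2048) weight
lora_right_weight = torch.randn(2048, 128)

delta = lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())  # shape (8192, 2048)
weight.data += delta
# RuntimeError: The size of tensor a (8192) must match the size of tensor b (2048)
# at non-singleton dimension 1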
I am using a system with 4x 16GB V100 GPUs per node with DeepSpeed 0.9.1. I trained a 1.3b-param model in step 1 and 350m-param model in step 2.
My step 3 run command launches 4 processes on one node, binding one process per GPU:
cd training/step3_rlhf_finetuning
OUTPUT=${OUTPUTDIR}/step3-models/1.3b
mkdir -p $OUTPUT
ACTOR_MODEL_PATH=${OUTPUTDIR}/actor-models/1.3b
CRITIC_MODEL_PATH=${OUTPUTDIR}/reward-models/1.3b
ACTOR_ZERO_STAGE=3
CRITIC_ZERO_STAGE=3
jsrun -r 1 --tasks_per_rs 4 -c ALL_CPUS -g ALL_GPUS python3 main.py \
--per_device_train_batch_size 4 \
--per_device_mini_train_batch_size 4 \
--inference_tp_size 1 \
--max_answer_seq_len 256 \
--max_prompt_seq_len 256 \
--actor_model_name_or_path $ACTOR_MODEL_PATH \
--critic_model_name_or_path $CRITIC_MODEL_PATH \
--actor_zero_stage $ACTOR_ZERO_STAGE \
--critic_zero_stage $CRITIC_ZERO_STAGE \
--num_padding_at_beginning 1 \
--gradient_accumulation_steps 1 \
--deepspeed \
--actor_lora_dim 128 \
--enable_hybrid_engine \
--actor_gradient_checkpointing \
--critic_gradient_checkpointing \
--output_dir $OUTPUT
After some debugging, I found that the above error arises because the GatheredParameters context does not gather all layers. If I print the tensor shape for each parameter of each layer immediately after GatheredParameters like so:
https://github.com/microsoft/DeepSpeed/blob/050aee287d70157e29f242b3a629a4cb97b4e4e7/deepspeed/runtime/hybrid_engine.py#L238
with GatheredParameters(non_active_layers):
    if rank == 0:
        for layer_id in range(len(self.layer_params)):
            for p_id, p in enumerate(self.layer_params[layer_id]):
                print("after gather layer_id", layer_id, p_id, p.shape, flush=True)
    self._gather_latency = time.time() - self._t0
then I see the following output on the step just before the error:
nonactive all layers 931
after gather layer_id 0 0 torch.Size([0])
after gather layer_id 0 1 torch.Size([0])
after gather layer_id 0 2 torch.Size([0])
after gather layer_id 0 3 torch.Size([0])
after gather layer_id 0 4 torch.Size([0])
after gather layer_id 0 5 torch.Size([8192])
after gather layer_id 0 6 torch.Size([2048, 8192])
after gather layer_id 0 7 torch.Size([2048])
after gather layer_id 0 8 torch.Size([0])
after gather layer_id 0 9 torch.Size([0])
after gather layer_id 0 10 torch.Size([0])
after gather layer_id 0 11 torch.Size([0])
after gather layer_id 0 12 torch.Size([0])
after gather layer_id 0 13 torch.Size([0])
after gather layer_id 0 14 torch.Size([0])
after gather layer_id 0 15 torch.Size([0])
after gather layer_id 1 0 torch.Size([2048])
after gather layer_id 1 1 torch.Size([2048])
after gather layer_id 1 2 torch.Size([2048])
after gather layer_id 1 3 torch.Size([2048])
after gather layer_id 1 4 torch.Size([8192, 2048])
after gather layer_id 1 5 torch.Size([8192])
after gather layer_id 1 6 torch.Size([2048, 8192])
after gather layer_id 1 7 torch.Size([2048])
after gather layer_id 1 8 torch.Size([2048, 2048])
after gather layer_id 1 9 torch.Size([2048])
after gather layer_id 1 10 torch.Size([2048, 2048])
after gather layer_id 1 11 torch.Size([2048])
after gather layer_id 1 12 torch.Size([2048, 2048])
after gather layer_id 1 13 torch.Size([2048])
after gather layer_id 1 14 torch.Size([2048, 2048])
after gather layer_id 1 15 torch.Size([2048])
Note that the dimensions of the parameters in layer_id=0 are mostly all zero. On the steps that complete without an error, those parameters have non-zero shapes, as shown below. The count of non_active_layers is 962 below vs. 931 above.
nonactive all layers 962
after gather layer_id 0 0 torch.Size([2048])
after gather layer_id 0 1 torch.Size([2048])
after gather layer_id 0 2 torch.Size([2048])
after gather layer_id 0 3 torch.Size([2048])
after gather layer_id 0 4 torch.Size([8192, 2048])
after gather layer_id 0 5 torch.Size([8192])
after gather layer_id 0 6 torch.Size([2048, 8192])
after gather layer_id 0 7 torch.Size([2048])
after gather layer_id 0 8 torch.Size([2048, 2048])
after gather layer_id 0 9 torch.Size([2048])
after gather layer_id 0 10 torch.Size([2048, 2048])
after gather layer_id 0 11 torch.Size([2048])
after gather layer_id 0 12 torch.Size([2048, 2048])
after gather layer_id 0 13 torch.Size([2048])
after gather layer_id 0 14 torch.Size([2048, 2048])
after gather layer_id 0 15 torch.Size([2048])
I added the following lines for further detail:
https://github.com/microsoft/DeepSpeed/blob/050aee287d70157e29f242b3a629a4cb97b4e4e7/deepspeed/runtime/hybrid_engine.py#L234-L238
else:
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    rank = dist.get_rank(group=self.mp_group)
    non_active_layers = get_inactive_params(self.all_layers_params)
    if rank == 0:
        print("nonactive layers", len(non_active_layers))
        for lay_id, lay in enumerate(self.all_layers_params):
            print("all layers", lay_id, hasattr(lay, 'ds_id'), lay.ds_status == ZeroParamStatus.NOT_AVAILABLE, lay.ds_status)
    non_active_lora_params = get_inactive_params(self.all_lora_params)
    if rank == 0:
        print("nonactive lora layers", len(non_active_lora_params))
        for lay_id, lay in enumerate(self.all_lora_params):
            print("lora layers", lay_id, hasattr(lay, 'ds_id'), lay.ds_status == ZeroParamStatus.NOT_AVAILABLE, lay.ds_status)
    non_active_layers.extend(non_active_lora_params)
It seems that the 0-shape parameters are marked as "ds_status == ZeroParamStatus.INFLIGHT" before calling "GatheredParameters":
[2023-04-17 15:33:56,759] [INFO] [loss_scaler.py:181:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 32768, reducing to 16384
epoch: 0|step: 2|ppo_ep: 1|act_loss: nan|cri_loss: nan|unsuper_loss: 0.0
average reward score: 3.267578125
-------------------------------------------------------------------------------------
|E2E latency=17.17s |Gather latency=0.46s (2.70%) |Generate time=7.04s (41.02%) |Training time=6.82s (39.71%) |Others=3.31 (19.27%)|CurSamplesPerSec=0.93 |AvgSamplesPerSec=0.60
nonactive layers 651
all layers 0 True False ZeroParamStatus.INFLIGHT
all layers 1 True False ZeroParamStatus.INFLIGHT
all layers 2 True False ZeroParamStatus.INFLIGHT
all layers 3 True False ZeroParamStatus.INFLIGHT
all layers 4 True False ZeroParamStatus.INFLIGHT
all layers 5 True False ZeroParamStatus.INFLIGHT
all layers 6 True False ZeroParamStatus.INFLIGHT
all layers 7 True False ZeroParamStatus.INFLIGHT
all layers 8 True False ZeroParamStatus.INFLIGHT
all layers 9 True False ZeroParamStatus.INFLIGHT
all layers 10 True False ZeroParamStatus.INFLIGHT
all layers 11 True False ZeroParamStatus.INFLIGHT
all layers 12 True False ZeroParamStatus.INFLIGHT
all layers 13 True False ZeroParamStatus.INFLIGHT
all layers 14 True False ZeroParamStatus.INFLIGHT
all layers 15 True False ZeroParamStatus.INFLIGHT
all layers 16 True False ZeroParamStatus.INFLIGHT
all layers 17 True False ZeroParamStatus.INFLIGHT
all layers 18 True False ZeroParamStatus.INFLIGHT
all layers 19 True False ZeroParamStatus.INFLIGHT
all layers 20 True False ZeroParamStatus.INFLIGHT
all layers 21 True True ZeroParamStatus.NOT_AVAILABLE
all layers 22 True True ZeroParamStatus.NOT_AVAILABLE
all layers 23 True True ZeroParamStatus.NOT_AVAILABLE
all layers 24 True True ZeroParamStatus.NOT_AVAILABLE
all layers 25 True True ZeroParamStatus.NOT_AVAILABLE
all layers 26 True True ZeroParamStatus.NOT_AVAILABLE
all layers 27 True True ZeroParamStatus.NOT_AVAILABLE
all layers 28 True False ZeroParamStatus.INFLIGHT
all layers 29 True False ZeroParamStatus.INFLIGHT
all layers 30 True True ZeroParamStatus.NOT_AVAILABLE
all layers 31 True True ZeroParamStatus.NOT_AVAILABLE
<snip>
nonactive lora layers 280
lora layers 0 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 1 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 2 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 3 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 4 True False ZeroParamStatus.INFLIGHT
lora layers 5 True False ZeroParamStatus.INFLIGHT
lora layers 6 True False ZeroParamStatus.INFLIGHT
lora layers 7 True False ZeroParamStatus.INFLIGHT
lora layers 8 True False ZeroParamStatus.INFLIGHT
lora layers 9 True False ZeroParamStatus.INFLIGHT
lora layers 10 True False ZeroParamStatus.INFLIGHT
lora layers 11 True False ZeroParamStatus.INFLIGHT
lora layers 12 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 13 True True ZeroParamStatus.NOT_AVAILABLE
I think those parameters are marked as INFLIGHT because they have been prefetched. I added some more debugging lines to print the stack at the point where the status is set to INFLIGHT:
https://github.com/microsoft/DeepSpeed/blob/050aee287d70157e29f242b3a629a4cb97b4e4e7/deepspeed/runtime/zero/partition_parameters.py#L873-L885
def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = True) -> AllGatherCoalescedHandle:

    # fetches from nvme if the partition is not available and in nvme
    self._ensure_availability_of_partitioned_params(params)

    if self.world_size == 1:
        return _no_gather_coalesced(params)

    #for param in params:
    for p_id, param in enumerate(params):
        if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
            raise RuntimeError(param.ds_summary())
        param.ds_status = ZeroParamStatus.INFLIGHT
        if dist.get_rank() == 0:
            print(p_id, "INFLIGHT2")
            if p_id > 20:
                print(traceback.print_stack(file=sys.stdout))
I can see those layers are set to INFLIGHT here:
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 529, in <module>
main()
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 452, in main
actor_loss, critic_loss = trainer.train_rlhf(exp_data)
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 180, in train_rlhf
value = self.critic_model.forward_value(**batch,
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/model/reward_model.py", line 125, in forward_value
transformer_outputs = self.rwtranrsformer(
File "/path/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 759, in forward
decoder_outputs = self.decoder(
File "/path/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 665, in forward
layer_outputs = torch.utils.checkpoint.checkpoint(
File "/path/site-packages/torch/utils/checkpoint.py", line 235, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/path/site-packages/torch/utils/checkpoint.py", line 96, in forward
outputs = run_function(*args)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 661, in custom_forward
return module(*inputs, output_attentions, None)
File "/path/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 337, in forward
hidden_states = self.activation_fn(hidden_states)
File "/path/site-packages/torch/nn/modules/module.py", line 1137, in _call_impl
result = hook(self, input)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 366, in _pre_forward_module_hook
self.pre_sub_module_forward_function(module)
File "/path/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 478, in pre_sub_module_forward_function
param_coordinator.fetch_sub_module(sub_module)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/path/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 333, in fetch_sub_module
self.__all_gather_params(params_to_prefetch)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 381, in __all_gather_params
handle = partitioned_params[0].all_gather_coalesced(partitioned_params)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 878, in all_gather_coalesced
print(traceback.print_stack(file=sys.stdout))
It seems that the layers are being prefetched during the call to the critic model forward pass:
https://github.com/microsoft/DeepSpeedExamples/blob/2aa7a31b8fdcb34b8ccdc554021a1f5789752ab3/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py#L174
They are still in INFLIGHT status when trying to generate a sample. The get_inactive_params function then only includes params marked as NOT_AVAILABLE:
https://github.com/microsoft/DeepSpeed/blob/48297c4841374776a3c4d9319a9b57378f598d65/deepspeed/runtime/utils.py#L972-L975
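For reference, the filtering described here boils down to a status check like the following (a simplified sketch of the linked function):

from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

def get_inactive_params(param_list):
    # Only params whose ds_status is NOT_AVAILABLE are returned, so INFLIGHT
    # params are skipped and never handed to GatheredParameters.
    return [param for param in param_list
            if (hasattr(param, 'ds_id')
                and param.ds_status == ZeroParamStatus.NOT_AVAILABLE)]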
Later, GatheredParameters may only consider params whose state is NOT_AVAILABLE:
https://github.com/microsoft/DeepSpeed/blob/48297c4841374776a3c4d9319a9b57378f598d65/deepspeed/runtime/zero/partition_parameters.py#L1058
Assuming that diagnosis is correct, I'm not sure what the recommended fix would be. Should get_inactive_params include INFLIGHT params?
A second question that came up while looking at this... it seems like the if conditions here might always be true:
https://github.com/microsoft/DeepSpeed/blob/48297c4841374776a3c4d9319a9b57378f598d65/deepspeed/runtime/hybrid_engine.py#L123
https://github.com/microsoft/DeepSpeed/blob/48297c4841374776a3c4d9319a9b57378f598d65/deepspeed/runtime/hybrid_engine.py#L136
Should it be lora_param instead of lora_params? Maybe change this to:
if len(lora_param) == 3:
Assuming that diagnosis is correct, I'm not sure what the recommended fix would be. Should get_inactive_params include INFLIGHT params?
@adammoody, thanks for the detailed analysis of this bug. To answer your question: no, INFLIGHT params should not be gathered again. Param gathering is asynchronous for performance reasons, and INFLIGHT params are part of ongoing gather operations.
This problem occurs because in RLHF we are context-switching 5 models on a rank to share GPU memory. So the recommended solution here is to ensure that there are no INFLIGHT params when context-switching out a model. This will guarantee that the actor model params are all NOT_AVAILABLE before the next generate. empty_partition_cache() can be used for this purpose. I will share a PR asap.
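A sketch of the placement being described (assumed positions in train_rlhf; the actual PR may differ):

# Flush each engine's ZeRO-3 partition cache right after its optimizer step,
# before context-switching to the other model (sketch only).
self.actor_model.backward(actor_loss)
self.actor_model.step()
self.actor_model.empty_partition_cache()    # actor params back to NOT_AVAILABLE

value = self.critic_model.forward_value(**batch,
                                        return_value_only=True,
                                        use_cache=False)[:, :-1]
critic_loss = self.critic_loss_fn(value[:, start:], old_values[:, start:],
                                  returns, action_mask[:, start:])
self.critic_model.backward(critic_loss)
self.critic_model.step()
self.critic_model.empty_partition_cache()   # critic params back to NOT_AVAILABLE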
@adammoody, can you please try this PR? Thanks!
Thanks for the explanation and quick reply, @tjruwase . Unfortunately, I'm still hitting the same problem with this PR.
The problematic params seem to be from the first few layers of the actor_model, which have been prefetched due to a forward step of the critic_model. I thought maybe we could move those empty calls to the end to try to clear any INFLIGHT actor params that the critic started to prefetch:
self.actor_model.backward(actor_loss)
self.actor_model.step()
#self.actor_model.empty_partition_cache()

value = self.critic_model.forward_value(**batch,
                                        return_value_only=True,
                                        use_cache=False)[:, :-1]
critic_loss = self.critic_loss_fn(value[:, start:], old_values[:, start:],
                                  returns, action_mask[:, start:])
self.critic_model.backward(critic_loss)
self.critic_model.step()
#self.critic_model.empty_partition_cache()

self.actor_model.empty_partition_cache()
self.critic_model.empty_partition_cache()
However, with that change I get the following error on the self.actor_model.empty_partition_cache() call:
File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 358, in release_and_reset_all
raise RuntimeError(f"param {param.ds_summary()} still in flight")
RuntimeError: param {'id': 0, 'status': 'INFLIGHT', 'numel': 102957056, 'ds_numel': 102957056, 'shape': (50272, 2048), 'ds_shape': (50272, 2048), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': set()} still in flight
The problematic params seem to be from the first few layers of the actor_model, which have been prefetched due to a forward step of the critic_model. I thought maybe we could move those empty calls to the end to try to clear any INFLIGHT actor params that the critic started to prefetch:
This is confusing to me. Each model has independent prefetchers, so the critic_model should not affect the actor_model. Can you share a stack trace that you get with the PR?
@adammoody, by the way, I was not able to repro your error on my 4xV100-16GB setup. This makes it harder to resolve.
However, with that change I get the following error on the self.actor_model.empty_partition_cache() call:
File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 358, in release_and_reset_all
raise RuntimeError(f"param {param.ds_summary()} still in flight")
RuntimeError: param {'id': 0, 'status': 'INFLIGHT', 'numel': 102957056, 'ds_numel': 102957056, 'shape': (50272, 2048), 'ds_shape': (50272, 2048), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': set()} still in flight
By the way, this is a separate bug that needs to be addressed.
Should it be lora_param instead of lora_params? Maybe change this to:
if len(lora_param) == 3:
I think you have found a bug here. Do you mind opening a separate ticket for this?
Sure. I'll post that one to the main DeepSpeed repo.
As another clue, it seems like the following changes work around the problem. I defined a "wait on inflight" function in deepspeed/runtime/zero/partitioned_param_coordinator.py:
@instrument_w_nvtx
@torch.no_grad()
def wait_on_inflight_params(self, current_submodule):
    params = frozenset(iter_params(current_submodule, recurse=True))
    for param in params:
        if param in self.__inflight_param_registry:
            print(param.ds_summary())
            with get_accelerator().stream(self.__allgather_stream):
                while self.__ongoing_fetch_events and self.__ongoing_fetch_events[0].query():
                    self.__ongoing_fetch_events.popleft()
                if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events:
                    self.__ongoing_fetch_events.popleft().synchronize()
                self.__inflight_param_registry.pop(param).wait()
                event = get_accelerator().Event()
                event.record()
                self.__ongoing_fetch_events.append(event)
            assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
    get_accelerator().current_stream().wait_stream(self.__allgather_stream)
I then call that from partition_all_parameters in deepspeed/runtime/zero/parameter_offload.py, which is called from empty_partition_cache:
def partition_all_parameters(self):
    """Partitioning Parameters that were not partitioned usually if parameters
    of modules whose input parameters do not require grad computation do not
    trigger post call and will therefore will remain unpartitioned"""
    self.get_param_coordinator(training=self.module.training).wait_on_inflight_params(self.module)
    self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module)
And then I call empty_partition_cache on both models after training on both:
### process the new outputs
batch = {'input_ids': seq, "attention_mask": attention_mask}
actor_prob = self.actor_model(**batch, use_cache=False).logits
actor_log_prob = gather_log_probs(actor_prob[:, :-1, :],
                                  inputs['input_ids'][:, 1:])
actor_loss = self.actor_loss_fn(actor_log_prob[:, start:],
                                log_probs[:, start:], advantages,
                                action_mask[:, start:])
self.actor_model.backward(actor_loss)
self.actor_model.step()

value = self.critic_model.forward_value(**batch,
                                        return_value_only=True,
                                        use_cache=False)[:, :-1]
critic_loss = self.critic_loss_fn(value[:, start:], old_values[:, start:],
                                  returns, action_mask[:, start:])
self.critic_model.backward(critic_loss)
self.critic_model.step()

# call empty_partition_cache after stepping both actor and critic models
self.actor_model.empty_partition_cache()
self.critic_model.empty_partition_cache()
If I drop the recurse=True option to change:
params = frozenset(iter_params(current_submodule, recurse=True))
to:
params = frozenset(iter_params(current_submodule))
then I still get this error:
File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 400, in release_and_reset_all
raise RuntimeError(f"param {param.ds_summary()} still in flight")
RuntimeError: param {'id': 0, 'status': 'INFLIGHT', 'numel': 102957056, 'ds_numel': 102957056, 'shape': (50272, 2048), 'ds_shape': (50272, 2048), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': set()} still in flight
It seems that it misses the necessary parameters without the recursion.
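For context, a simplified stand-in for what the recurse flag changes (assuming iter_params ultimately walks the module's parameters the way torch.nn.Module.named_parameters does; this is a sketch, not the DeepSpeed implementation):

from torch import nn

def iter_params_sketch(module: nn.Module, recurse: bool = False):
    # recurse=False: only the module's own direct parameters;
    # recurse=True: also the parameters of every child sub-module, which is
    # where the still-INFLIGHT params live in this case.
    return (p for _, p in module.named_parameters(recurse=recurse))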
And if I keep recurse=True, but call empty_partition_cache after each train step individually rather than at the end of both steps, like so:
### process the new outputs
batch = {'input_ids': seq, "attention_mask": attention_mask}
actor_prob = self.actor_model(**batch, use_cache=False).logits
actor_log_prob = gather_log_probs(actor_prob[:, :-1, :],
                                  inputs['input_ids'][:, 1:])
actor_loss = self.actor_loss_fn(actor_log_prob[:, start:],
                                log_probs[:, start:], advantages,
                                action_mask[:, start:])
self.actor_model.backward(actor_loss)
self.actor_model.step()
self.actor_model.empty_partition_cache()

value = self.critic_model.forward_value(**batch,
                                        return_value_only=True,
                                        use_cache=False)[:, :-1]
critic_loss = self.critic_loss_fn(value[:, start:], old_values[:, start:],
                                  returns, action_mask[:, start:])
self.critic_model.backward(critic_loss)
self.critic_model.step()
self.critic_model.empty_partition_cache()

#self.actor_model.empty_partition_cache()
#self.critic_model.empty_partition_cache()
I still get the original error noted at the top of the issue:
weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
RuntimeError: The size of tensor a (8192) must match the size of tensor b (2048) at non-singleton dimension 1
In summary, it seems that I need to do all three of these:
- call empty_partition_cache after both steps
- wait on any inflight params
- call iter_params(..., recurse=True) when getting the parameter list
Thanks for sharing these details. I agree that empty_partition_cache needs a wait_on_inflight_params logic like you discovered. However, I would like to take a step back to understand a few things.
First, empty_partition_cache should guarantee that all params are NOT_AVAILABLE. But you mentioned that you were hitting the original problem with my PR, which calls empty_partition_cache after actor_model.step(). I don't understand how that can be possible. So, could you please share the stack trace you get after applying my PR?
The error message and stack trace when using the changes in the PR are the same as the original error report. Right, I haven't given up on figuring out this problem either. I have some more ideas to try to debug things. I'll keep posting updates.
@tjruwase , I still haven't cracked it, but here are some more clues...
The first problematic layer corresponds to the vocab embedding layer of the actor model. I did verify that layer actually belongs to the actor model, and that it is not shared with the critic model or any other model.
The stack trace for the prefetch of that layer is shown below. The line numbers will vary because I've added lots of debug statements.
File "/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 529, in <module>
main()
File "/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 452, in main
actor_loss, critic_loss = trainer.train_rlhf(exp_data)
File "/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 191, in train_rlhf
value = self.critic_model.forward_value(**batch,
File "/DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/model/reward_model.py", line 125, in forward_value
transformer_outputs = self.rwtranrsformer(
File "/path/python3.9/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/path/python3.9/site-packages/transformers/models/opt/modeling_opt.py", line 759, in forward
decoder_outputs = self.decoder(
File "/path/python3.9/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/path/python3.9/site-packages/transformers/models/opt/modeling_opt.py", line 665, in forward
layer_outputs = torch.utils.checkpoint.checkpoint(
File "/path/python3.9/site-packages/torch/utils/checkpoint.py", line 235, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/path/python3.9/site-packages/torch/utils/checkpoint.py", line 96, in forward
outputs = run_function(*args)
File "/path/python3.9/site-packages/transformers/models/opt/modeling_opt.py", line 661, in custom_forward
return module(*inputs, output_attentions, None)
File "/path/python3.9/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/path/python3.9/site-packages/transformers/models/opt/modeling_opt.py", line 337, in forward
hidden_states = self.activation_fn(hidden_states)
File "/path/python3.9/site-packages/torch/nn/modules/module.py", line 1137, in _call_impl
result = hook(self, input)
File "/path/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/path/python3.9/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 378, in _pre_forward_module_hook
self.pre_sub_module_forward_function(module)
File "/path/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/path/python3.9/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 492, in pre_sub_module_forward_function
param_coordinator.fetch_sub_module(sub_module)
File "/path/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/path/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/path/python3.9/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 368, in fetch_sub_module
print(traceback.print_stack(file=sys.stdout))
Note that this happens from a param_coordinator.fetch_sub_module(sub_module) call.
That trace was kicked off by a call to:
value = self.critic_model.forward_value()
The ds_summary output for the param being prefetched is:
-prefetch: {'id': 0, 'status': 'NOT_AVAILABLE', 'numel': 0, 'ds_numel': 102957056, 'shape': (0,), 'ds_shape': (50272, 2048), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': set()}
I did verify that this layer belongs to the actor model by matching its Python id(param) value, and similarly verified that it's not shared with the critic model.
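The check itself is just an object-identity comparison; a minimal sketch of the approach (hypothetical helper, not code from the repo):

def count_shared_params(model_a, model_b):
    # Intersect the Python object ids of all parameters; a non-empty
    # intersection would mean the two modules literally share parameter objects.
    ids_a = {id(p) for p in model_a.parameters()}
    ids_b = {id(p) for p in model_b.parameters()}
    return len(ids_a & ids_b)

# e.g. count_shared_params(self.actor_model.module, self.critic_model.module) == 0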
I can see that it fails for me on what I think is the third training step. This appears to be the first step where it has completed a trace and thus has enabled prefetching for the model.
def fetch_sub_module(self, current_submodule: Module) -> None:
    <snip>
    # kick off parameter prefetches for upcoming modules
    # don't prefetch if we dont have a completed model trace
    if self.is_complete_trace():
At that point self._param_ids == 674, while it shows 0 for the two previous steps.
Since there seems to be some "model mixing" in this case, one area that caught my eye is the global FWD_MODULE_STACK in deepspeed/runtime/zero/parameter_offload.py.
@torch.no_grad()
def pre_sub_module_forward_function(self, sub_module):
    see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)

    global FWD_MODULE_STACK
    FWD_MODULE_STACK.append(sub_module)

    if dist.get_rank() == 0:
        print("FWD_MODULE_STACK length", len(FWD_MODULE_STACK), "id(module)", id(sub_module))

    param_coordinator = self.get_param_coordinator(training=sub_module.training)
    param_coordinator.trace_prologue(sub_module)
    if param_coordinator.is_record_trace():
        param_coordinator.record_module(sub_module)
    param_coordinator.fetch_sub_module(sub_module)
You can see I've added a print above. The vocab embedding layer is the 8th or 9th element at the time the problem occurs. I'd have to double check which if you need an exact value.
That list is initialized with a base model.
def setup_zero_stage3_hooks(self):
    self.hierarchy = 0

    #reset step if in inference mode
    @instrument_w_nvtx
    def _end_of_forward_hook(module, *args):
        if not torch._C.is_grad_enabled():
            self.get_param_coordinator(training=False).reset_step()

    #likely one of them should be enough but just to be safe
    self._register_hooks_recursively(self.module)
    self.module.register_forward_hook(_end_of_forward_hook)

    # Add top module to stack trace
    global FWD_MODULE_STACK
    FWD_MODULE_STACK.append(self.module)

    if dist.get_rank() == 0:
        print("FWD_MODULE_STACK length", len(FWD_MODULE_STACK), "id(module)", id(self.module))
        print(str(self.module))
In this case, I can see that the four base models correspond to the first four elements of that list. I'm not sure which 4 or so modules are stored as the elements in between the base models and the vocab embed layer at the point where I see the problem.
I still haven't tracked down why invoking a function on the critic model could end up fetching params for the actor model, but I wondered if there might be some linkage here.
@adammoody, kudos on the intensive debugging. I think I know what might be wrong, but I need your help to confirm. I have updated my PR with some asserts to verify that empty_partition_cache() is behaving as expected. Can you please try the PR again?
I added those new changes by hand, so my source file line numbers will be different. I have been editing DeepSpeed files in place within my python environment, so it takes some effort to set up a clean environment at this point. Anyway, if you trust that, I hit the tensor dimension mismatch here:
https://github.com/microsoft/DeepSpeedExamples/blob/ce049bee82bd4594209beb2bc0676a44af2b5758/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py#L77
6: 0: Traceback (most recent call last):
6: 0: File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 529, in <module>
6: 0: main()
6: 0: File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 438, in main
6: 0: out = trainer.generate_experience(prompts)
6: 0: File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 106, in generate_experience
6: 0: seq = self._generate_sequence(prompts)
6: 0: File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 79, in _generate_sequence
6: 0: seq = self.actor_model.module.generate(prompts,
6: 0: File "/path/python3.9/site-packages/deepspeed/runtime/hybrid_engine.py", line 293, in generate
6: 0: self.fuse_lora_weight()
6: 0: File "/path/python3.9/site-packages/deepspeed/runtime/hybrid_engine.py", line 139, in fuse_lora_weight
6: 0: self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
6: 0: File "/path/python3.9/site-packages/deepspeed/runtime/hybrid_engine.py", line 128, in _fuse_lora
6: 0: weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
6: 0: RuntimeError: The size of tensor a (8192) must match the size of tensor b (2048) at non-singleton dimension 1
If I add a second call to assert_empty_partition_cache for the actor_model immediately after the check for the critic model:
### process the new outputs
batch = {'input_ids': seq, "attention_mask": attention_mask}
actor_prob = self.actor_model(**batch, use_cache=False).logits
actor_log_prob = gather_log_probs(actor_prob[:, :-1, :],
                                  inputs['input_ids'][:, 1:])
actor_loss = self.actor_loss_fn(actor_log_prob[:, start:],
                                log_probs[:, start:], advantages,
                                action_mask[:, start:])
self.actor_model.backward(actor_loss)
self.actor_model.step()
self.actor_model.empty_partition_cache()
assert_empty_partition_cache(self.actor_model, 'actor_model after rlhf step')

value = self.critic_model.forward_value(**batch,
                                        return_value_only=True,
                                        use_cache=False)[:, :-1]
critic_loss = self.critic_loss_fn(value[:, start:], old_values[:, start:],
                                  returns, action_mask[:, start:])
self.critic_model.backward(critic_loss)
self.critic_model.step()
self.critic_model.empty_partition_cache()
assert_empty_partition_cache(self.critic_model, 'critic_model after rlhf step')
--> assert_empty_partition_cache(self.actor_model, 'actor_model after rlhf critic step')
then the assertion triggers:
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 204, in train_rlhf
actor_loss, critic_loss = trainer.train_rlhf(exp_data)
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 204, in train_rlhf
assert_empty_partition_cache(self.actor_model, 'actor_model after rlhf critic step')
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 290, in assert_empty_partition_cache
assert_empty_partition_cache(self.actor_model, 'actor_model after rlhf critic step')
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 290, in assert_empty_partition_cache
assert len(avail_or_inflight_params) == 0, \
AssertionError: actor_model after rlhf critic step empty_partition_cache failed to evict all params: remaining = [0, 1, 2, 3, 387, 388, 4, 5, 389, 390, 6, 7, 391, 392, 8, 9, 393, 394, 10, 11, 12, 16, 17]
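For reference, the check boils down to something of this form (my sketch of the idea; the actual assert added by the PR may differ in detail):

from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

def assert_empty_partition_cache(model_engine, tag):
    # After empty_partition_cache(), every ZeRO-3 param of the engine's module
    # should be NOT_AVAILABLE; anything AVAILABLE or INFLIGHT is a violation.
    avail_or_inflight_params = [p.ds_id for p in model_engine.module.parameters()
                                if hasattr(p, 'ds_id') and p.ds_status != ZeroParamStatus.NOT_AVAILABLE]
    assert len(avail_or_inflight_params) == 0, \
        f"{tag} empty_partition_cache failed to evict all params: remaining = {avail_or_inflight_params}"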
Thanks for sharing these updates. Adding the second assert for the actor model cache is a really good idea. It is a mystery why it fails. This supports your suspicion of a leakage between the parameter partitioning of the actor and critic models.
Can you confirm that your critic model is 1.3b or 350m?
Also, can you try dropping --enable_hybrid_engine from your command line?
Yes, I'm actually using a 350m model for the critic. I had a cut-and-paste typo in the path name when I wrote out the checkpoint, so the path suggests it's a 1.3b param model, but it is really 350m.
I tried dropping the --enable_hybrid_engine option. I do still hit the assertion. I think it triggered one step earlier than before.
AssertionError: actor_model after rlhf critic step empty_partition_cache failed to evict all params: remaining = [0, 1, 2, 3, 387, 388, 4, 5, 389, 390, 6, 7, 391, 392, 8, 9, 393, 394, 10, 11, 12, 16, 17]
Here are some other workarounds that I found earlier but didn't list yet:
- dropping from stage 3 to stage 2 works:
  --actor_zero_stage 2 --critic_zero_stage 2
- hard-coding the prefetch buffer to be 0-size works (stage 3 with prefetch disabled):
  https://github.com/microsoft/DeepSpeed/blob/d92539509b1e9a6178cfdb921d5080e76f690bce/deepspeed/runtime/zero/parameter_offload.py#L242
  #self._prefetch_bucket_sz = int(prefetch_bucket_size)
  self._prefetch_bucket_sz = 0
Thanks for the update.
- Hitting the assertion is good since it stops at the earliest violation of the invariant of empty_partition_cache().
- It is good to know that --enable_hybrid_engine is not the cause.
- Yes, stage 2 should not cause this since it is stage 3 specific.
- Disabling prefetching (which can be done through ds_config; see the sketch below) should avoid this problem.
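A sketch of that config change (assuming the standard stage3_prefetch_bucket_size knob, which feeds _prefetch_bucket_sz; setting it to 0 disables prefetching):

# ZeRO stage 3 section of the DeepSpeed config with prefetching disabled
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "stage3_prefetch_bucket_size": 0,
    },
}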
@adammoody, FYI I think this DeepSpeed PR from my colleague @HeyangQin might be relevant here. Please give him a bit more time to get it ready.
@tjruwase , I think I found the cause.
I believe the problem is that all four models share the same ReLU module object. Each model registers a forward hook on that module in setup_zero_stage3_hooks(). When invoking the forward pass on the ReLU module from the critic model, the hook from the actor model is invoked, which leads to the prefetch of the actor layers.
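A standalone PyTorch sketch of the effect (illustrative only, not DeepSpeed code): a forward pre-hook registered on a shared module instance fires no matter which model's forward pass reaches it.

import torch
import torch.nn as nn

shared_act = nn.ReLU()                                   # one instance shared by both models
actor = nn.Sequential(nn.Linear(4, 4), shared_act)
critic = nn.Sequential(nn.Linear(4, 4), shared_act)

# pretend this is the hook the actor's ZeRO-3 engine installs on its sub-modules
shared_act.register_forward_pre_hook(lambda module, inputs: print("actor hook fired"))

critic(torch.randn(1, 4))                                # prints "actor hook fired"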
I found this by adding the following code in deepspeed/runtime/zero/parameter_offload.py to print the object addresses of all child modules of each model:
def print_children(module, indent):
    for name, m in module.named_children():
        spaces = " " * indent
        print(spaces, name, id(m))
        print_children(m, indent + 2)

<snip>

def setup_zero_stage3_hooks(self):
    self.hierarchy = 0

    #reset step if in inference mode
    @instrument_w_nvtx
    def _end_of_forward_hook(module, *args):
        if not torch._C.is_grad_enabled():
            self.get_param_coordinator(training=False).reset_step()

    #likely one of them should be enough but just to be safe
    self._register_hooks_recursively(self.module)
    self.module.register_forward_hook(_end_of_forward_hook)

    # Add top module to stack trace
    global FWD_MODULE_STACK
    FWD_MODULE_STACK.append(self.module)

    if dist.get_rank() == 0:
        print("FWD_MODULE_STACK SETUP length", len(FWD_MODULE_STACK), "id(module)", id(self.module), type(self.module))
        print(str(self.module))
        for p_id, param in enumerate(iter_params(self.module, recurse=True)):
            key = id(param) if hasattr(param, 'ds_id') else id(param.ds_param_alias)
            print(" ", p_id, id(param), type(param), key)
        print_children(self.module, 2)
With that, I get the following example output for the actor and reward models. You can see that the activation_fn module has the same address for all layers in all models.
FWD_MODULE_STACK SETUP length 1 id(module) 35188069303104 <class 'transformers.models.opt.modeling_opt.OPTForCausalLM'>
3: 0: model 35188109644992
3: 0: decoder 35188109647776
3: 0: embed_tokens 35188109647248
3: 0: embed_positions 35188109647440
3: 0: layers 35188109646144
3: 0: 0 35188109647392
3: 0: self_attn 35188109644800
3: 0: k_proj 35188109646480
3: 0: lora_dropout 35188109646576
3: 0: v_proj 35188109644320
3: 0: lora_dropout 35188092098880
3: 0: q_proj 35188109644848
3: 0: lora_dropout 35188092098832
3: 0: out_proj 35188109647296
3: 0: lora_dropout 35188092096912
3: 0: --> activation_fn 35186791027472 <--- same address for all layers, in each model
3: 0: self_attn_layer_norm 35188109645040
3: 0: fc1 35188109645616
3: 0: lora_dropout 35188092098736
3: 0: fc2 35188109644656
3: 0: lora_dropout 35188092100560
3: 0: final_layer_norm 35188108844336
3: 0: 1 35188108843472
3: 0: self_attn 35188108843328
3: 0: k_proj 35188109643984
3: 0: lora_dropout 35188107776880
3: 0: v_proj 35188108844432
3: 0: lora_dropout 35188107776592
3: 0: q_proj 35188108842800
3: 0: lora_dropout 35188107777408
3: 0: out_proj 35188108841264
3: 0: lora_dropout 35188107777456
3: 0: --> activation_fn 35186791027472 <--- same address for all layers, in each model
3: 0: self_attn_layer_norm 35188108841456
3: 0: fc1 35188108842512
3: 0: lora_dropout 35188107777696
3: 0: fc2 35188108841888
3: 0: lora_dropout 35188107776160
3: 0: final_layer_norm 35188104685504
<snip>
FWD_MODULE_STACK SETUP length 4 id(module) 35188109548272 <class 'utils.model.reward_model.RewardModel'>
3: 0: v_head 35188109667440
3: 0: rwtranrsformer 35188109548224
3: 0: decoder 35188109548320
3: 0: embed_tokens 35188109548416
3: 0: embed_positions 35188109548368
3: 0: project_out 35188109548512
3: 0: project_in 35188109548560
3: 0: layers 35188109548656
3: 0: 0 35188109548608
3: 0: self_attn 35188109548704
3: 0: k_proj 35188109548800
3: 0: v_proj 35188109548848
3: 0: q_proj 35188109548992
3: 0: out_proj 35188109549040
3: 0: --> activation_fn 35186791027472 <--- same address for all layers, in each model
3: 0: self_attn_layer_norm 35188109548752
3: 0: fc1 35188109549136
3: 0: fc2 35188109549184
3: 0: final_layer_norm 35188109549232
3: 0: 1 35188109549328
3: 0: self_attn 35188109549376
3: 0: k_proj 35188109549472
3: 0: v_proj 35188109549520
3: 0: q_proj 35188109664320
3: 0: out_proj 35188109664368
3: 0: --> activation_fn 35186791027472 <--- same address for all layers, in each model
3: 0: self_attn_layer_norm 35188109549424
3: 0: fc1 35188109547552
3: 0: fc2 35188109547792
3: 0: final_layer_norm 35188109547840
As a test, I then found that I could work around the problem by modifying the OPT model to instantiate a unique ReLU object for each layer in transformers/models/opt/modeling_opt.py:
class OPTDecoderLayer(nn.Module):
    def __init__(self, config: OPTConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.dropout = config.dropout
-->     #self.activation_fn = ACT2FN[config.activation_function]
-->     self.activation_fn = nn.ReLU()
@tjruwase, I think I found the cause.
I believe the problem is that all four models share the same ReLU module object. Each model registers a forward hook on that module in setup_zero_stage3_hooks(). When invoking the forward pass on the ReLU module from the critic model, the hook from the actor model is invoked, which leads to the prefetch of the actor layers.
Amazing debugging, @adammoody. Truly outstanding!
@stas00, FYI. It seems ReLU objects are shared across models in the same transformer process. Do you have context for this behavior?
yes, it's the same object, it's like a cache:
creation: https://github.com/huggingface/transformers/blob/b6865b9befad33f99adee0a6ef6361f72fcc8b42/src/transformers/activations.py#L206-L233
use: https://github.com/huggingface/transformers/blob/b6865b9befad33f99adee0a6ef6361f72fcc8b42/src/transformers/models/opt/modeling_opt.py#L288
The paradigm is shifting. Clearly there was no need to create a new object before, because deepspeed didn't support more than one model, and there is no issue with reusing the same object across multiple models outside of the deepspeed world.
We should probably file a feature request to create these on the fly rather than pre-create them, so that each instance will be unique.
There are quite a few changes that need to be made to support multiple deepspeed models paradigm.
Some possible workarounds:
- A quick hack would be to overload the transformers.activations.ACT2FN getter to clone the object when it's looked up (see the sketch below).
- Perhaps DeepSpeed should detect that it's installing a hook into an object it has already hooked and somehow clone it in place, or at the very least assert if that's the case.
raise RuntimeError(f"still have inflight params "f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}")
This error is still reported when I run step 3 with bloomz + ZeRO-3.
I encountered this bug after the previous bug (https://github.com/microsoft/DeepSpeed/issues/3528) was solved. @HeyangQin
I have the same issue. Have you resolved the problem?
Encountering the same error here. The issue persists even after updating DeepSpeed and PyTorch Lightning to the latest versions.