
The plot from muP coord_check is not horizontal, which may indicate a bug in the muP implementation?

Open BaoYu0721 opened this issue 1 year ago • 11 comments

Bug Description & To Reproduce The source code is from the current main branch. I followed the instructions in the until this step: image I then encountered an error like this:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /mnt/cache/baoyu/gpt-neox/ in <module>                            │
│                                                                              │
│   24 │   neox_args.configure_distributed_args()                              │
│   25 │   neox_args.build_tokenizer()  # tokenizer needs to be build in train │
│   26 │   neox_args.initialize_tensorboard_writer()  # is initialized if tens │
│ ❱ 27 │   pretrain(neox_args=neox_args)                                       │
│   28                                                                         │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/ in pretrain               │
│                                                                              │
│   208 │   timers("train/valid/test data iterators").stop()                   │
│   209 │                                                                      │
│   210 │   if neox_args.use_mup and neox_args.coord_check:                    │
│ ❱ 211 │   │   mup_coord_check(neox_args, timers, lr_scheduler, train_data_it │
│   212 │                                                                      │
│   213 │   # Print setup timing.                                              │
│   214 │   print_rank_0("done with setups ...")                               │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/ in mup_coord_check        │
│                                                                              │
│   151 │   │   models[hidden_size] = lazy_model(hidden_size)                  │
│   152 │                                                                      │
│   153 │   neox_args.use_mup = True                                           │
│ ❱ 154 │   df_up = get_coord_data(                                            │
│   155 │   │   neox_args, timers, lr_scheduler, models, train_data_iterator,  │
│   156 │   )                                                                  │
│   157 │   neox_args.use_mup = False                                          │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/ in get_coord_data   │
│                                                                              │
│   204 │   elif optimizer is None:                                            │
│   205 │   │   raise ValueError("optimizer should be sgd|adam|adamw or a cust │
│   206 │                                                                      │
│ ❱ 207 │   data = _get_coord_data(                                            │
│   208 │   │   neox_args, timers, lr_scheduler, models, dataloader, optcls, * │
│   209 │   )                                                                  │
│   210 │   data["optimizer"] = optimizer                                      │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/ in _get_coord_data   │
│                                                                              │
│    66 │   │   │   │   │   )                                                  │
│    67 │   │   │   │                                                          │
│    68 │   │   │   │   # train for a step                                     │
│ ❱  69 │   │   │   │   loss_dict, skipped_iter = train_step(                  │
│    70 │   │   │   │   │   neox_args=neox_args,                               │
│    71 │   │   │   │   │   timers=timers,                                     │
│    72 │   │   │   │   │   data_iterator=dataloader,                          │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/ in train_step             │
│                                                                              │
│   692 │                                                                      │
│   693 │   # Pipeline parallelism schedules forward/backward/step             │
│   694 │   if neox_args.is_pipe_parallel:                                     │
│ ❱ 695 │   │   reduced_loss = train_step_pipe(                                │
│   696 │   │   │   neox_args=neox_args, timers=timers, model=model, data_iter │
│   697 │   │   )                                                              │
│   698 │   else:                                                              │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/ in train_step_pipe        │
│                                                                              │
│   742 │   """Single training step with DeepSpeed's pipeline parallel engine. │
│   743 │                                                                      │
│   744 │   assert neox_args.deepspeed                                         │
│ ❱ 745 │   loss = model.train_batch(data_iter=data_iterator)                  │
│   746 │   loss_dict = {"lm_loss": loss}                                      │
│   747 │   # Don't break Megatron's timers because we changed code paths.     │
│   748 │   for t in [                                                         │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/ in train_batch                                            │
│                                                                              │
│    333 │   │   sched = schedule.TrainSchedule(micro_batches=self.micro_batch │
│    334 │   │   │   │   │   │   │   │   │      stages=self.num_stages,        │
│    335 │   │   │   │   │   │   │   │   │      stage_id=self.stage_id)        │
│ ❱  336 │   │   self._exec_schedule(sched)                                    │
│    337 │   │   self.agg_train_loss = self._aggregate_total_loss()            │
│    338 │   │                                                                 │
│    339 │   │   self.timers('train_batch').stop()                             │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/ in _exec_schedule                                        │
│                                                                              │
│   1304 │   │   │   │                                                         │
│   1305 │   │   │   │   # Equivalent to: self._exec_forward_pass(buffer_id=0) │
│   1306 │   │   │   │   self._exec_instr = MethodType(self._INSTRUCTION_MAP[t │
│ ❱ 1307 │   │   │   │   self._exec_instr(**cmd.kwargs)                        │
│   1308                                                                       │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/ in _exec_forward_pass                                     │
│                                                                              │
│    624 │   │   # tensor changes across batches                               │
│    625 │   │   self._zero_grads(inputs)                                      │
│    626 │   │                                                                 │
│ ❱  627 │   │   outputs = super().forward(inputs)                             │
│    628 │   │                                                                 │
│    629 │   │   # Reset activation checkpointing buffers.                     │
│    630 │   │   # Need to call this between evaluation iterations             │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/utils/nv │
│ in wrapped_fn                                                       │
│                                                                              │
│   12 │                                                                       │
│   13 │   def wrapped_fn(*args, **kwargs):                                    │
│   14 │   │   get_accelerator().range_push(func.__qualname__)                 │
│ ❱ 15 │   │   ret_val = func(*args, **kwargs)                                 │
│   16 │   │   get_accelerator().range_pop()                                   │
│   17 │   │   return ret_val                                                  │
│   18                                                                         │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ in forward                                                    │
│                                                                              │
│   1728 │   │   if self.fp16_auto_cast():                                     │
│   1729 │   │   │   inputs = self._cast_inputs_half(inputs)                   │
│   1730 │   │                                                                 │
│ ❱ 1731 │   │   loss = self.module(*inputs, **kwargs)                         │
│   1732 │   │                                                                 │
│   1733 │   │   if self.zero_optimization_partition_weights():                │
│   1734 │   │   │   # Disable automated discovery of external parameters      │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/torch/nn/modules/m │
│ in _call_impl                                                  │
│                                                                              │
│   1209 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks)   │
│   1210 │   │   │   input = bw_hook.setup_input_hook(input)                   │
│   1211 │   │                                                                 │
│ ❱ 1212 │   │   result = forward_call(*input, **kwargs)                       │
│   1213 │   │   if _global_forward_hooks or self._forward_hooks:              │
│   1214 │   │   │   for hook in (*_global_forward_hooks.values(), *self._forw │
│   1215 │   │   │   │   hook_result = hook(self, input, result)               │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/ in forward                                                │
│                                                                              │
│   347 │   │   │   │   if self._is_checkpointable(funcs):                     │
│   348 │   │   │   │   │   x = self.activation_checkpoint_func(exec_range_fun │
│   349 │   │   │   │   else:                                                  │
│ ❱ 350 │   │   │   │   │   x = exec_range_func(start_idx, end_idx)(*x)        │
│   351 │   │   return x                                                       │
│   352 │                                                                      │
│   353 │   def _partition_layers(self, method='uniform'):                     │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/ in exec_func                                              │
│                                                                              │
│   324 │   │   │   │   │   │   else:                                          │
│   325 │   │   │   │   │   │   │   ds_utils.set_random_seed(new_seed)         │
│   326 │   │   │   │   │                                                      │
│ ❱ 327 │   │   │   │   │   inputs = layer(inputs)                             │
│   328 │   │   │   │   return inputs                                          │
│   329 │   │   │                                                              │
│   330 │   │   │   return exec_func                                           │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/torch/nn/modules/m │
│ in _call_impl                                                  │
│                                                                              │
│   1212 │   │   result = forward_call(*input, **kwargs)                       │
│   1213 │   │   if _global_forward_hooks or self._forward_hooks:              │
│   1214 │   │   │   for hook in (*_global_forward_hooks.values(), *self._forw │
│ ❱ 1215 │   │   │   │   hook_result = hook(self, input, result)               │
│   1216 │   │   │   │   if hook_result is not None:                           │
│   1217 │   │   │   │   │   result = hook_result                              │
│   1218                                                                       │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/mup/ │
│ :161 in f                                                                    │
│                                                                              │
│   158 │   │   │   │   for i, out in enumerate(output):                       │
│   159 │   │   │   │   │   _ret = copy(ret)                                   │
│   160 │   │   │   │   │   _ret['module'] += f':out[{i}]'                     │
│ ❱ 161 │   │   │   │   │   get_stat(_ret, out, output_fdict)                  │
│   162 │   │   │   elif isinstance(output, dict):                             │
│   163 │   │   │   │   for name, out in output.items():                       │
│   164 │   │   │   │   │   _ret = copy(ret)                                   │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/mup/ │
│ :145 in get_stat                                                             │
│                                                                              │
│   142 │   │   │   elif isinstance(x, torch.Tensor):                          │
│   143 │   │   │   │   _d = copy(d)                                           │
│   144 │   │   │   │   for fname, f in fdict.items():                         │
│ ❱ 145 │   │   │   │   │   _d[fname] = f(x).item()                            │
│   146 │   │   │   │   records.append(_d)                                     │
│   147 │   │   │   else:                                                      │
│   148 │   │   │   │   raise NotImplemented(f'Unexpected output type: {type(x │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/mup/ │
│ :44 in <lambda>                                                              │
│                                                                              │
│    41                                                                        │
│    42 #: dict of provided functions for use in coord check                   │
│    43 FDICT = {                                                              │
│ ❱  44 │   'l1': lambda x: torch.abs(x).mean(),                               │
│    45 │   'l2': lambda x: (x**2).mean()**0.5,                                │
│    46 │   'mean': lambda x: x.mean(),                                        │
│    47 │   'std': lambda x: x.std(),                                          │
RuntimeError: mean(): could not infer output dtype. Input dtype must be either a
floating point or complex dtype. Got: Bool

This is caused by passing a Bool tensor (maybe the attention mask) into mup's get_stat, which the mup library cannot handle. A similar error occurs when None is passed to get_stat.

As a temporary workaround, I modified the source code in the file in mup like this: image
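The screenshot of the patch is not preserved in this text. The following is only a minimal sketch of the kind of guard described, not the actual change: `safe_stat` is a hypothetical helper (it is not part of mup) that skips `None` and any tensor whose dtype is neither floating point nor complex, which are the two inputs reported to break `get_stat`.

```python
import torch


def safe_stat(x, fn):
    """Apply a coord-check statistic only where it is well defined.

    Hypothetical helper for illustration; not mup's API.
    Returns None for inputs that statistics like mean() cannot handle.
    """
    if x is None:
        return None  # e.g. a module that returns None in its output tuple
    if not isinstance(x, torch.Tensor):
        return None
    if not (x.is_floating_point() or x.is_complex()):
        return None  # e.g. a Bool attention mask
    return fn(x).item()


# Same definition as the 'l1' entry shown in the traceback above.
l1 = lambda x: torch.abs(x).mean()

float_result = safe_stat(torch.tensor([1.0, -3.0]), l1)
bool_result = safe_stat(torch.tensor([True, False]), l1)
none_result = safe_stat(None, l1)
```

With a guard like this, Bool masks and `None` entries are simply left out of the recorded statistics instead of raising.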

This time coord_check ran successfully. It outputs many jpgs, one for each GPU. The jpgs from different GPUs look very similar, so I show just one jpg for each parameterization.

Standard Parameterization: coord_check_sp 1

muP Parameterization: coord_check_up 1

The result looks weird: SP is more horizontal than muP, which is not expected.

Expected behavior The expected behavior should look like the plots in the above link, in which muP is very horizontal while SP blows up.

Proposed solution I checked the code related to mup but don't have a proposal yet. I will keep checking it. Maybe contributors in the issue( can give some comments? @nsarka @Quentin-Anthony @StellaAthena Thanks a lot!

Environment (please complete the following information):

  • GPUs: 8 V100 GPU on one node
  • Configs:
    python configs/2-7B.yml configs/local_setup.yml
    I added the mup-related configs to configs/local_setup.yml, keeping them completely identical to the instructions in

BaoYu0721 avatar May 28 '23 13:05 BaoYu0721

Thanks for raising this issue. It looks like you’re correct and we broke the implementation at some point.

One thing we really need to start doing (but haven’t been able to do due to manpower limitations) is build out a robust testing suite that verifies new major changes don’t break old features :S

StellaAthena avatar May 28 '23 14:05 StellaAthena

Thanks for your reply! I checked out some other commits, such as the v2.0 release tag and the earlier commit where deepspeed_main was merged into main (2b84f9af10eebdb82dfb956adc2cb54ba2f62344), and found that the plots are similar to the description above. Maybe the bug was introduced even earlier?

BaoYu0721 avatar May 29 '23 09:05 BaoYu0721

I was looking into the muP implementation in gpt-neox to contrast it with the Megatron-LM setup and accidentally found this issue :)

I am wondering whether the LR schedule could be the cause of the problem. By design it overrides the LR values per group (and hence overwrites the muP changes), which is why muP scaling was introduced in AnnealingLR() as a rescaling by group["width_mult"] in step() here. But I couldn't find this key being added either inside the mup optimisers or in the gpt-neox codebase, so I am not sure that the width_mult rescaling is applied at all.

Also, the width_mult rescaling applies only to Adam-like optimisers and matrix-like params (as here); for SGD the rescaling uses different multipliers, which should also be taken into account.

However, I also couldn't find whether the AnnealingLR schedule is actually applied during training, so it may well be that my comment isn't relevant to the observed behaviour.
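To make the concern concrete, here is a minimal sketch of a schedule step that overwrites the per-group LR and rescales it by `width_mult`. Only the `width_mult` key comes from the discussion above; `apply_scheduled_lr`, the group names, and the values are hypothetical, and the division is the Adam-style rule for matrix-like params, not the SGD one.

```python
def apply_scheduled_lr(param_groups, base_lr):
    """Overwrite each group's lr with the scheduled value, rescaled for muP.

    Hypothetical sketch: groups without a "width_mult" key fall back to 1.0,
    so they silently receive the unscaled learning rate.
    """
    for group in param_groups:
        width_mult = group.get("width_mult", 1.0)
        group["lr"] = base_lr / width_mult
    return param_groups


groups = [
    {"name": "hidden_matrix", "width_mult": 4.0, "lr": 0.0},  # matrix-like
    {"name": "bias", "lr": 0.0},                              # no width_mult key
]
apply_scheduled_lr(groups, base_lr=1e-3)

for g in groups:
    print(g["name"], g["lr"])
```

The second group shows the failure mode in question: if the key is never added to the param groups, the schedule runs without error but applies no muP scaling at all.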

ofivite avatar Aug 02 '23 10:08 ofivite

I think you are right @ofivite. I don't think the implementation was ever correct, since the learning rate wasn't set up correctly from the beginning. After long and thorough debugging I managed to pass at least the 2 basic sanity checks for mup:

  1. at same width mup doesn't do anything (since all shapes are the same it should coincide with SP)
  2. all the rest being fixed, you only get better by going wider

Issues I have found are:

  • lr not scaled; I added the width_mult key to the group dicts (as stated above)
  • MuAdam being used instead of MuAdamW (maybe irrelevant, need to test more)
  • I have fixed all initializations to be normal with a fixed std. I'm not sure the method is correctly implemented, or whether it should be used at all for generic initialization functions, especially those that already scale the std by fan_in, since as far as I can see that same std is then scaled again
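The double-scaling concern in the last bullet, and sanity check 1, can be illustrated with plain arithmetic. This is a sketch under stated assumptions: `fan_in_std` and `rescale_for_width` are hypothetical helpers, and the `width_mult ** -0.5` rule is chosen only for illustration; it is not taken from the gpt-neox or mup source.

```python
import math


def fan_in_std(fan_in):
    """An init that already scales its std with fan_in (illustrative)."""
    return 1.0 / math.sqrt(fan_in)


def rescale_for_width(std, width_mult):
    """Hypothetical second rescaling applied on top of the init's std."""
    return std * width_mult ** -0.5


base_width = 256

# Sanity check 1: at the base width, width_mult is 1 and nothing changes.
same_width_mult = base_width / base_width
std_same = rescale_for_width(fan_in_std(base_width), same_width_mult)

# At a larger width, an init that already depends on fan_in is shrunk twice.
width = 1024
width_mult = width / base_width
std_once = fan_in_std(width)
std_twice = rescale_for_width(std_once, width_mult)

print(std_same, std_once, std_twice)
```

Using a normal init with a fixed std, as described in the bullet, avoids the issue because the width dependence then enters in exactly one place.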

marcobellagente93 avatar Aug 28 '23 13:08 marcobellagente93

Found another bug: neox_args.use_mup is set to false before initializing the models, which also sets their use_mup attribute to False, so the multipliers are always ignored.

marcobellagente93 avatar Aug 29 '23 07:08 marcobellagente93

And finally there seems to be a bug in the re-initialization of the output layer. After skipping that completely (it should be anyway in the flavour of Table 8) I'm getting these very nice and smooth horizontal lines:

coord_check_up 0

marcobellagente93 avatar Aug 29 '23 09:08 marcobellagente93

@marcobellagente93 Oh yes, now it's indeed nicely flat curves, great ! :)

ofivite avatar Aug 31 '23 15:08 ofivite

I'll make a PR as soon as I can

marcobellagente93 avatar Aug 31 '23 15:08 marcobellagente93

Edit: email formatting did not work properly.

Hi, I’m interested to see the PR as well. When I originally made my PR, the curves looked as expected—flat for mup and blown up for SP.

> The use_mup parameter is set to false so the weights never get initialized

Are you referring to it being false by default? When calculating the base shapes using mup, two models have to be instantiated, one with use_mup set to false and the other set to true. If I remember correctly, I set use_mup to false by default everywhere and enabled it only when mup was set to true in the config file. Then, when calculating base shapes it’s forced to false.

> The learning rate is not scaled by width_mult

I saw some code being linked to in this thread that’s from the original mup repo from Microsoft. I had to bring in and modify a lot of it. I believe width_mult for the learning rate may have been set there.

nsarka avatar Aug 31 '23 16:08 nsarka

What I mean is that the main training loop with mup enabled does the following:

  1. set neox_args.use_mup to false
  2. initialize model
  3. set neox_args.use_mup back to true

but because of step 1, all modules are initialized with self.use_mup = neox_args.use_mup (which is false), and that causes everything else to be wrong (multipliers not used, 1/d attention not used, ...)
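The three steps above can be reproduced in a few lines. This is a minimal sketch, not gpt-neox code: `Layer` and `output_multiplier` are hypothetical, and only the pattern `self.use_mup = neox_args.use_mup` is taken from the comment above.

```python
from types import SimpleNamespace


class Layer:
    """Toy module that copies the flag once, at construction time."""

    def __init__(self, neox_args):
        self.use_mup = neox_args.use_mup  # snapshot; later changes are not seen

    def output_multiplier(self, width_mult):
        # Hypothetical multiplier: only applied when the captured flag is True.
        return 1.0 / width_mult if self.use_mup else 1.0


neox_args = SimpleNamespace(use_mup=True)

neox_args.use_mup = False   # step 1
layer = Layer(neox_args)    # step 2
neox_args.use_mup = True    # step 3

print(neox_args.use_mup, layer.use_mup, layer.output_multiplier(4.0))
```

Restoring the flag on the args object in step 3 does not reach the already-constructed module, so its multiplier stays at 1.0 unless the attribute is also updated on the module or the model is built after the flag is restored.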

marcobellagente93 avatar Aug 31 '23 17:08 marcobellagente93

This behavior is expected--the weights are reinitialized using

Or do you mean this function does not get called?

nsarka avatar Aug 31 '23 17:08 nsarka