
adding te Linear for fp8 support

vchiley opened this issue · 8 comments

Ran composer train/train.py train/yamls/pretrain/mpt-3b.yaml, then again with model.fc_type=te, and with both model.fc_type=te and precision=amp_fp8. Results:

torch: throughput/device/tokens_per_sec: 23.7k
te: throughput/device/tokens_per_sec: 23.7k
te with fp8: throughput/device/tokens_per_sec: 29.4k
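
For reference, a minimal sketch of how dotted CLI overrides like model.fc_type=te precision=amp_fp8 get merged into the YAML config, mirroring the om.load / om.from_cli / om.merge calls visible in train.py in the tracebacks below; the from_dotlist call stands in for from_cli and is illustrative, not the actual script:

```python
# Minimal sketch of how train.py-style CLI overrides are merged (illustrative only).
from omegaconf import OmegaConf as om

yaml_cfg = om.load('train/yamls/pretrain/mpt-3b.yaml')
# Dotted overrides such as `model.fc_type=te precision=amp_fp8` on the command line
# become a nested config; train.py uses om.from_cli, which reads sys.argv similarly.
cli_cfg = om.from_dotlist(['model.fc_type=te', 'precision=amp_fp8'])
cfg = om.merge(yaml_cfg, cli_cfg)

assert cfg.model.fc_type == 'te'
assert cfg.precision == 'amp_fp8'
```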

Note: there is an error when activation checkpointing is enabled with activation_checkpointing_reentrant: false. If we set activation_checkpointing_reentrant: true, activation checkpointing works fine without amp_fp8, but with amp_fp8 the issue persists. (Previously, circa summer 2022, activation_checkpointing_reentrant: true caused some difficulties, which is why we set it to false; not sure whether that is still necessary.) The activation-checkpointing errors are posted in the comments below. (The error with amp_fp8 might be an issue with Composer's fp8 implementation.)
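
For context, a minimal sketch of what the two settings correspond to, assuming (as in Composer's FSDP integration) that activation_checkpointing_reentrant maps to the use_reentrant argument of torch.utils.checkpoint.checkpoint; the toy module and sizes are illustrative:

```python
# Illustrative only: reentrant vs. non-reentrant activation checkpointing in plain torch.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512))
x = torch.randn(8, 512, requires_grad=True)

# activation_checkpointing_reentrant: true  -> the older, reentrant implementation
y_reentrant = checkpoint(block, x, use_reentrant=True)

# activation_checkpointing_reentrant: false -> the newer, non-reentrant implementation
# (the variant that hits the "No grad accumulator for a saved leaf" assert with te.Linear below)
y_non_reentrant = checkpoint(block, x, use_reentrant=False)

(y_reentrant.sum() + y_non_reentrant.sum()).backward()
```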

vchiley · Jun 02 '23 18:06

Activation checkpointing error without amp_fp8, with activation_checkpointing_reentrant: false:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/llm-foundry/scripts/train/train.py:254 in <module>                                          │
│                                                                                                  │
│   251 │   │   yaml_cfg = om.load(f)                                                              │
│   252 │   cli_cfg = om.from_cli(args_list)                                                       │
│   253 │   cfg = om.merge(yaml_cfg, cli_cfg)                                                      │
│ ❱ 254 │   main(cfg)                                                                              │
│   255                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/scripts/train/train.py:243 in main                                              │
│                                                                                                  │
│   240 │   │   trainer.eval()                                                                     │
│   241 │                                                                                          │
│   242 │   print('Starting training...')                                                          │
│ ❱ 243 │   trainer.fit()                                                                          │
│   244 │                                                                                          │
│   245 │   print('Done.')                                                                         │
│   246                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:1766 in fit       │
│                                                                                                  │
│   1763 │   │   │   self.state.scaler = ClosureGradScaler() if self._use_closures() else GradSca  │
│   1764 │   │                                                                                     │
│   1765 │   │   self.first_batch_complete = False                                                 │
│ ❱ 1766 │   │   self._train_loop()                                                                │
│   1767 │                                                                                         │
│   1768 │   def close(self):                                                                      │
│   1769 │   │   """Shutdown the trainer.                                                          │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:1940 in           │
│ _train_loop                                                                                      │
│                                                                                                  │
│   1937 │   │   │   │   │   │   self.logger.log_metrics({'time/token': self.state.timestamp.toke  │
│   1938 │   │   │   │   │   │   self.logger.log_metrics({'time/token_in_epoch': self.state.times  │
│   1939 │   │   │   │   │                                                                         │
│ ❱ 1940 │   │   │   │   │   total_loss_dict = self._train_batch(use_grad_scaling)                 │
│   1941 │   │   │   │   │                                                                         │
│   1942 │   │   │   │   │   if use_grad_scaling:                                                  │
│   1943 │   │   │   │   │   │   self.state.scaler.update()                                        │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2124 in           │
│ _train_batch                                                                                     │
│                                                                                                  │
│   2121 │   │   │   │   │   │   │   │   │   │   │   │      closure=lambda loss_dict=total_loss_d  │
│   2122 │   │   │   │   │   │   │   │   │   │   │   │      _train_microbatches(microbatches, los  │
│   2123 │   │   │   │   │   │   else:                                                             │
│ ❱ 2124 │   │   │   │   │   │   │   optimizer.step(closure=lambda **kwargs: self._train_microbat  │
│   2125 │   │   │   │   │   │   │   │   microbatches, total_loss_dict, **kwargs).item())          │
│   2126 │   │   │   │   else:                                                                     │
│   2127 │   │   │   │   │   self._train_microbatches(microbatches, total_loss_dict)               │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/optim/lr_scheduler.py:69 in wrapper                         │
│                                                                                                  │
│     66 │   │   │   │   instance = instance_ref()                                                 │
│     67 │   │   │   │   instance._step_count += 1                                                 │
│     68 │   │   │   │   wrapped = func.__get__(instance, cls)                                     │
│ ❱   69 │   │   │   │   return wrapped(*args, **kwargs)                                           │
│     70 │   │   │                                                                                 │
│     71 │   │   │   # Note that the returned function here is no longer a bound method,           │
│     72 │   │   │   # so attributes like `__func__` and `__self__` no longer exist.               │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/optim/optimizer.py:280 in wrapper                           │
│                                                                                                  │
│   277 │   │   │   │   │   │   │   raise RuntimeError(f"{func} must return None or a tuple of (   │
│   278 │   │   │   │   │   │   │   │   │   │   │      f"but got {result}.")                       │
│   279 │   │   │   │                                                                              │
│ ❱ 280 │   │   │   │   out = func(*args, **kwargs)                                                │
│   281 │   │   │   │   self._optimizer_step_code()                                                │
│   282 │   │   │   │                                                                              │
│   283 │   │   │   │   # call optimizer step post hooks                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/utils/_contextlib.py:115 in decorate_context                │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/optim/decoupled_weight_decay.py:288  │
│ in step                                                                                          │
│                                                                                                  │
│   285 │   │   loss = None                                                                        │
│   286 │   │   if closure is not None:                                                            │
│   287 │   │   │   with torch.enable_grad():                                                      │
│ ❱ 288 │   │   │   │   loss = closure()                                                           │
│   289 │   │                                                                                      │
│   290 │   │   for group in self.param_groups:                                                    │
│   291 │   │   │   params_with_grad = []                                                          │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2124 in <lambda>  │
│                                                                                                  │
│   2121 │   │   │   │   │   │   │   │   │   │   │   │      closure=lambda loss_dict=total_loss_d  │
│   2122 │   │   │   │   │   │   │   │   │   │   │   │      _train_microbatches(microbatches, los  │
│   2123 │   │   │   │   │   │   else:                                                             │
│ ❱ 2124 │   │   │   │   │   │   │   optimizer.step(closure=lambda **kwargs: self._train_microbat  │
│   2125 │   │   │   │   │   │   │   │   microbatches, total_loss_dict, **kwargs).item())          │
│   2126 │   │   │   │   else:                                                                     │
│   2127 │   │   │   │   │   self._train_microbatches(microbatches, total_loss_dict)               │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2222 in           │
│ _train_microbatches                                                                              │
│                                                                                                  │
│   2219 │   │   │                                                                                 │
│   2220 │   │   │   for microbatch_idx, self.state.batch in enumerate(microbatches):              │
│   2221 │   │   │   │   is_final_microbatch = microbatch_idx + 1 == len(microbatches)             │
│ ❱ 2222 │   │   │   │   microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_  │
│   2223 │   │   │   │                                                                             │
│   2224 │   │   │   │   # Aggregate each loss in microbatch_loss_dict into total_loss_dict        │
│   2225 │   │   │   │   for k, microbatch_loss in microbatch_loss_dict.items():                   │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2349 in           │
│ _train_microbatch                                                                                │
│                                                                                                  │
│   2346 │   │   │   else:                                                                         │
│   2347 │   │   │   │   # Scale loss based on the number of samples in the microbatch to maintai  │
│   2348 │   │   │   │   microbatch_loss.mul_(microbatch_num_samples / current_batch_size)         │
│ ❱ 2349 │   │   │   │   microbatch_loss.backward(create_graph=self._backwards_create_graph)       │
│   2350 │   │   │                                                                                 │
│   2351 │   │   │   self.engine.run_event(Event.AFTER_BACKWARD)                                   │
│   2352                                                                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/_tensor.py:487 in backward                                  │
│                                                                                                  │
│    484 │   │   │   │   create_graph=create_graph,                                                │
│    485 │   │   │   │   inputs=inputs,                                                            │
│    486 │   │   │   )                                                                             │
│ ❱  487 │   │   torch.autograd.backward(                                                          │
│    488 │   │   │   self, gradient, retain_graph, create_graph, inputs=inputs                     │
│    489 │   │   )                                                                                 │
│    490                                                                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/autograd/__init__.py:200 in backward                        │
│                                                                                                  │
│   197 │   # The reason we repeat same the comment below is that                                  │
│   198 │   # some Python versions print out the first line of a multi-line function               │
│   199 │   # calls in the traceback and some print out the last line                              │
│ ❱ 200 │   Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the bac   │
│   201 │   │   tensors, grad_tensors_, retain_graph, create_graph, inputs,                        │
│   202 │   │   allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to ru   │
│   203                                                                                            │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/autograd/function.py:274 in apply                           │
│                                                                                                  │
│   271 │   │   │   │   │   │   │      "Function is not allowed. You should only implement one "   │
│   272 │   │   │   │   │   │   │      "of them.")                                                 │
│   273 │   │   user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn                    │
│ ❱ 274 │   │   return user_fn(self, *args)                                                        │
│   275 │                                                                                          │
│   276 │   def apply_jvp(self, *args):                                                            │
│   277 │   │   # _forward_cls is defined by derived class                                         │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/module.py:1865 in  │
│ backward                                                                                         │
│                                                                                                  │
│   1862 │   │   │   │   weight,                                                                   │
│   1863 │   │   │   │   weight_t_fp8,                                                             │
│   1864 │   │   │   │   fwd_scale_inverses,                                                       │
│ ❱ 1865 │   │   │   ) = ctx.saved_tensors                                                         │
│   1866 │   │   │                                                                                 │
│   1867 │   │   │   if ctx.ub_split_ag:                                                           │
│   1868 │   │   │   │   tp_world_size = get_distributed_world_size(ctx.tp_group)                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: !grad_accumulator_.expired() INTERNAL ASSERT FAILED at "../torch/csrc/autograd/saved_variable.cpp":226, please report a bug to PyTorch. No 
grad accumulator for a saved leaf

vchiley · Jun 02 '23 18:06

Activation checkpointing error with amp_fp8 (with either activation_checkpointing_reentrant: false or true):

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/llm-foundry/scripts/train/train.py:260 in <module>                                          │
│                                                                                                  │
│   257 │   │   yaml_cfg = om.load(f)                                                              │
│   258 │   cli_cfg = om.from_cli(args_list)                                                       │
│   259 │   cfg = om.merge(yaml_cfg, cli_cfg)                                                      │
│ ❱ 260 │   main(cfg)                                                                              │
│   261                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/scripts/train/train.py:249 in main                                              │
│                                                                                                  │
│   246 │   │   trainer.eval()                                                                     │
│   247 │                                                                                          │
│   248 │   print('Starting training...')                                                          │
│ ❱ 249 │   trainer.fit()                                                                          │
│   250 │                                                                                          │
│   251 │   print('Done.')                                                                         │
│   252                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:1766 in fit       │
│                                                                                                  │
│   1763 │   │   │   self.state.scaler = ClosureGradScaler() if self._use_closures() else GradSca  │
│   1764 │   │                                                                                     │
│   1765 │   │   self.first_batch_complete = False                                                 │
│ ❱ 1766 │   │   self._train_loop()                                                                │
│   1767 │                                                                                         │
│   1768 │   def close(self):                                                                      │
│   1769 │   │   """Shutdown the trainer.                                                          │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:1940 in           │
│ _train_loop                                                                                      │
│                                                                                                  │
│   1937 │   │   │   │   │   │   self.logger.log_metrics({'time/token': self.state.timestamp.toke  │
│   1938 │   │   │   │   │   │   self.logger.log_metrics({'time/token_in_epoch': self.state.times  │
│   1939 │   │   │   │   │                                                                         │
│ ❱ 1940 │   │   │   │   │   total_loss_dict = self._train_batch(use_grad_scaling)                 │
│   1941 │   │   │   │   │                                                                         │
│   1942 │   │   │   │   │   if use_grad_scaling:                                                  │
│   1943 │   │   │   │   │   │   self.state.scaler.update()                                        │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2124 in           │
│ _train_batch                                                                                     │
│                                                                                                  │
│   2121 │   │   │   │   │   │   │   │   │   │   │   │      closure=lambda loss_dict=total_loss_d  │
│   2122 │   │   │   │   │   │   │   │   │   │   │   │      _train_microbatches(microbatches, los  │
│   2123 │   │   │   │   │   │   else:                                                             │
│ ❱ 2124 │   │   │   │   │   │   │   optimizer.step(closure=lambda **kwargs: self._train_microbat  │
│   2125 │   │   │   │   │   │   │   │   microbatches, total_loss_dict, **kwargs).item())          │
│   2126 │   │   │   │   else:                                                                     │
│   2127 │   │   │   │   │   self._train_microbatches(microbatches, total_loss_dict)               │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/optim/lr_scheduler.py:69 in wrapper                         │
│                                                                                                  │
│     66 │   │   │   │   instance = instance_ref()                                                 │
│     67 │   │   │   │   instance._step_count += 1                                                 │
│     68 │   │   │   │   wrapped = func.__get__(instance, cls)                                     │
│ ❱   69 │   │   │   │   return wrapped(*args, **kwargs)                                           │
│     70 │   │   │                                                                                 │
│     71 │   │   │   # Note that the returned function here is no longer a bound method,           │
│     72 │   │   │   # so attributes like `__func__` and `__self__` no longer exist.               │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/optim/optimizer.py:280 in wrapper                           │
│                                                                                                  │
│   277 │   │   │   │   │   │   │   raise RuntimeError(f"{func} must return None or a tuple of (   │
│   278 │   │   │   │   │   │   │   │   │   │   │      f"but got {result}.")                       │
│   279 │   │   │   │                                                                              │
│ ❱ 280 │   │   │   │   out = func(*args, **kwargs)                                                │
│   281 │   │   │   │   self._optimizer_step_code()                                                │
│   282 │   │   │   │                                                                              │
│   283 │   │   │   │   # call optimizer step post hooks                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/utils/_contextlib.py:115 in decorate_context                │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/optim/decoupled_weight_decay.py:288  │
│ in step                                                                                          │
│                                                                                                  │
│   285 │   │   loss = None                                                                        │
│   286 │   │   if closure is not None:                                                            │
│   287 │   │   │   with torch.enable_grad():                                                      │
│ ❱ 288 │   │   │   │   loss = closure()                                                           │
│   289 │   │                                                                                      │
│   290 │   │   for group in self.param_groups:                                                    │
│   291 │   │   │   params_with_grad = []                                                          │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2124 in <lambda>  │
│                                                                                                  │
│   2121 │   │   │   │   │   │   │   │   │   │   │   │      closure=lambda loss_dict=total_loss_d  │
│   2122 │   │   │   │   │   │   │   │   │   │   │   │      _train_microbatches(microbatches, los  │
│   2123 │   │   │   │   │   │   else:                                                             │
│ ❱ 2124 │   │   │   │   │   │   │   optimizer.step(closure=lambda **kwargs: self._train_microbat  │
│   2125 │   │   │   │   │   │   │   │   microbatches, total_loss_dict, **kwargs).item())          │
│   2126 │   │   │   │   else:                                                                     │
│   2127 │   │   │   │   │   self._train_microbatches(microbatches, total_loss_dict)               │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2222 in           │
│ _train_microbatches                                                                              │
│                                                                                                  │
│   2219 │   │   │                                                                                 │
│   2220 │   │   │   for microbatch_idx, self.state.batch in enumerate(microbatches):              │
│   2221 │   │   │   │   is_final_microbatch = microbatch_idx + 1 == len(microbatches)             │
│ ❱ 2222 │   │   │   │   microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_  │
│   2223 │   │   │   │                                                                             │
│   2224 │   │   │   │   # Aggregate each loss in microbatch_loss_dict into total_loss_dict        │
│   2225 │   │   │   │   for k, microbatch_loss in microbatch_loss_dict.items():                   │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2349 in           │
│ _train_microbatch                                                                                │
│                                                                                                  │
│   2346 │   │   │   else:                                                                         │
│   2347 │   │   │   │   # Scale loss based on the number of samples in the microbatch to maintai  │
│   2348 │   │   │   │   microbatch_loss.mul_(microbatch_num_samples / current_batch_size)         │
│ ❱ 2349 │   │   │   │   microbatch_loss.backward(create_graph=self._backwards_create_graph)       │
│   2350 │   │   │                                                                                 │
│   2351 │   │   │   self.engine.run_event(Event.AFTER_BACKWARD)                                   │
│   2352                                                                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/_tensor.py:487 in backward                                  │
│                                                                                                  │
│    484 │   │   │   │   create_graph=create_graph,                                                │
│    485 │   │   │   │   inputs=inputs,                                                            │
│    486 │   │   │   )                                                                             │
│ ❱  487 │   │   torch.autograd.backward(                                                          │
│    488 │   │   │   self, gradient, retain_graph, create_graph, inputs=inputs                     │
│    489 │   │   )                                                                                 │
│    490                                                                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/autograd/__init__.py:200 in backward                        │
│                                                                                                  │
│   197 │   # The reason we repeat same the comment below is that                                  │
│   198 │   # some Python versions print out the first line of a multi-line function               │
│   199 │   # calls in the traceback and some print out the last line                              │
│ ❱ 200 │   Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the bac   │
│   201 │   │   tensors, grad_tensors_, retain_graph, create_graph, inputs,                        │
│   202 │   │   allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to ru   │
│   203                                                                                            │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/autograd/function.py:274 in apply                           │
│                                                                                                  │
│   271 │   │   │   │   │   │   │      "Function is not allowed. You should only implement one "   │
│   272 │   │   │   │   │   │   │      "of them.")                                                 │
│   273 │   │   user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn                    │
│ ❱ 274 │   │   return user_fn(self, *args)                                                        │
│   275 │                                                                                          │
│   276 │   def apply_jvp(self, *args):                                                            │
│   277 │   │   # _forward_cls is defined by derived class                                         │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/module.py:1865 in  │
│ backward                                                                                         │
│                                                                                                  │
│   1862 │   │   │   │   weight,                                                                   │
│   1863 │   │   │   │   weight_t_fp8,                                                             │
│   1864 │   │   │   │   fwd_scale_inverses,                                                       │
│ ❱ 1865 │   │   │   ) = ctx.saved_tensors                                                         │
│   1866 │   │   │                                                                                 │
│   1867 │   │   │   if ctx.ub_split_ag:                                                           │
│   1868 │   │   │   │   tp_world_size = get_distributed_world_size(ctx.tp_group)                  │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/utils/checkpoint.py:420 in unpack                           │
│                                                                                                  │
│   417 │   │   │   │   │    torch.cuda.amp.autocast(**gpu_autocast_kwargs), \                     │
│   418 │   │   │   │   │    torch.cpu.amp.autocast(**cpu_autocast_kwargs), \                      │
│   419 │   │   │   │   │    torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):   │
│ ❱ 420 │   │   │   │   │   _unused = function(*args, **kwargs)                                    │
│   421 │   │                                                                                      │
│   422 │   │   if x not in storage:                                                               │
│   423 │   │   │   raise RuntimeError(                                                            │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1501 in _call_impl                     │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /mnt/llm-foundry/llmfoundry/models/layers/blocks.py:110 in forward                               │
│                                                                                                  │
│   107 │   │   is_causal: bool = True,                                                            │
│   108 │   ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:                               │
│   109 │   │   a = self.norm_1(x)                                                                 │
│ ❱ 110 │   │   b, attn_weights, past_key_value = self.attn(                                       │
│   111 │   │   │   a,                                                                             │
│   112 │   │   │   past_key_value=past_key_value,                                                 │
│   113 │   │   │   attn_bias=attn_bias,                                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1501 in _call_impl                     │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /mnt/llm-foundry/llmfoundry/models/layers/attention.py:423 in forward                            │
│                                                                                                  │
│   420 │   │   is_causal=True,                                                                    │
│   421 │   │   needs_weights=False,                                                               │
│   422 │   ):                                                                                     │
│ ❱ 423 │   │   qkv = self.Wqkv(x)                                                                 │
│   424 │   │                                                                                      │
│   425 │   │   if self.clip_qkv:                                                                  │
│   426 │   │   │   qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)                              │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1501 in _call_impl                     │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/module.py:2267 in  │
│ forward                                                                                          │
│                                                                                                  │
│   2264 │   │   │   │   │   │   │      produced)                                                  │
│   2265 │   │   """                                                                               │
│   2266 │   │                                                                                     │
│ ❱ 2267 │   │   with self.prepare_forward(inp, is_first_microbatch) as inp:                       │
│   2268 │   │   │   bias_tensor = (                                                               │
│   2269 │   │   │   │   bias if bias is not None                                                  │
│   2270 │   │   │   │   else self.bias if self.parameters_split is None                           │
│                                                                                                  │
│ /usr/lib/python3.10/contextlib.py:135 in __enter__                                               │
│                                                                                                  │
│   132 │   │   # they are only needed for recreation, which is not possible anymore               │
│   133 │   │   del self.args, self.kwds, self.func                                                │
│   134 │   │   try:                                                                               │
│ ❱ 135 │   │   │   return next(self.gen)                                                          │
│   136 │   │   except StopIteration:                                                              │
│   137 │   │   │   raise RuntimeError("generator didn't yield") from None                         │
│   138                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/module.py:632 in   │
│ prepare_forward                                                                                  │
│                                                                                                  │
│    629 │   │   │   │   │   self.fp8_meta["autocast_id_fwd_stack"].append(                        │
│    630 │   │   │   │   │   │   self.fp8_meta["autocast_id_fwd"]                                  │
│    631 │   │   │   │   │   )                                                                     │
│ ❱  632 │   │   │   │   │   add_amax_to_global_buffer(self.fp8_meta, forward=True)                │
│    633 │   │   │   │   self.fp8_meta["update_amax_and_scale_fwd"] = True                         │
│    634 │   │   │   else:                                                                         │
│    635 │   │   │   │   self.fp8_meta["update_amax_and_scale_fwd"] = False                        │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/fp8.py:134 in      │
│ add_amax_to_global_buffer                                                                        │
│                                                                                                  │
│   131 │   │   fp8_meta[buffer_position_key] = len(_global_fp8_buffer[buffer_key]) - 1            │
│   132 │                                                                                          │
│   133 │   # Catch incorrect fp8_autocast usage.                                                  │
│ ❱ 134 │   assert fp8_meta[buffer_position_key] == len(_global_fp8_buffer[buffer_key]) - 1, \     │
│   135 │   │   "Same module is being invoked more than once inside an `fp8_autocast` region whe   │
│   136 │   │   "FP8 with amax reduction. This behavior is currently unsupported. For more detai   │
│   137 │   │   "correct usage, please see https://github.com/NVIDIA/TransformerEngine/pull/93."   │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AssertionError: Same module is being invoked more than once inside an `fp8_autocast` region when using FP8 with amax reduction. This behavior is currently 
unsupported. For more details and correct usage, please see https://github.com/NVIDIA/TransformerEngine/pull/93.
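
For reference, a minimal sketch of the fp8_autocast pattern this assertion is policing, based on TE's public API (te.Linear, te.fp8_autocast, DelayedScaling); sizes and recipe fields are illustrative, exact recipe arguments vary across TE versions, and this needs an FP8-capable GPU. Each TE module is meant to run once per autocast region, and the activation-checkpointing recompute appears to count as a second invocation here:

```python
# Illustrative sketch of a TE layer under fp8_autocast (not the llm-foundry integration).
# Requires an FP8-capable GPU; recipe arguments vary by TE version.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
linear = te.Linear(768, 768, bias=True).cuda()
x = torch.randn(16, 768, device='cuda')

# One fp8_autocast region per forward pass, with each TE module invoked exactly once
# inside it; invoking the same module twice in one region is what the assert above
# rejects when amax reduction is enabled.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = linear(x)
y.sum().backward()
```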

vchiley · Jun 02 '23 18:06

Regarding activation_checkpointing_reentrant: false, I believe this is the new style of activation checkpointing that torch recommends, and it may even become the default going forward. It also enabled us to do layer freezing, if I remember correctly. But it isn't necessary, and if there's some torch 2 / fp8 bug that requires activation_checkpointing_reentrant: true, I think that's fine.
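
On the layer-freezing point, a minimal sketch with plain torch.utils.checkpoint (not the Composer/FSDP wrapper; the modules and sizes are arbitrary). With the reentrant path, a checkpointed block whose explicit tensor inputs do not require grad gets no parameter gradients at all, which is the kind of case the non-reentrant path handles:

```python
# Illustrative: freezing a block while activation checkpointing with use_reentrant=False.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

frozen = nn.Linear(256, 256).requires_grad_(False)   # frozen block
trainable = nn.Linear(256, 256)                      # trainable block
x = torch.randn(4, 256)                              # plain data, does not require grad

h = checkpoint(frozen, x, use_reentrant=False)
out = checkpoint(trainable, h, use_reentrant=False)  # input h does not require grad either
out.sum().backward()

# The trainable block still gets gradients; with use_reentrant=True this pattern
# instead warns that none of the inputs require grad and leaves the gradients as None.
assert frozen.weight.grad is None
assert trainable.weight.grad is not None
```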

Edit: I actually re-read your comments, and I guess we are still blocked on activation checkpointing + amp_fp8. Will read the logs carefully now...

abhi-mosaic · Jun 02 '23 19:06

The bug referenced in this PR (https://github.com/NVIDIA/TransformerEngine/pull/93) appears to have been fixed in this PR (https://github.com/NVIDIA/TransformerEngine/pull/187); the fix is available on main but not in any release.

Could you install TE @ main and try? 🙏

abhi-mosaic · Jun 02 '23 19:06

Note: transformer_engine has its own activation-checkpointing util, which might need to be integrated into Composer for fp8 to work with activation checkpointing.

vchiley · Jun 02 '23 19:06

TE @ main requires flash-attn==1.0.6, and flash-attn==1.0.6 has this issue.

Solution: add --no-build-isolation to pip install

vchiley · Jun 02 '23 20:06

Installing TE from main (as suggested by @abhi-mosaic) makes our integrated activation checkpointing work; there is no need to integrate TE's activation checkpointing. activation_checkpointing_reentrant: false is still broken.
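
As a point of reference, a minimal sketch of the pattern that now works: plain torch (reentrant) activation checkpointing wrapped around a TE layer under fp8_autocast, assuming TE installed from main and an FP8-capable GPU; this is illustrative, not the actual Composer/FSDP checkpoint wrapper:

```python
# Illustrative: reentrant torch activation checkpointing around a TE layer with fp8.
# Assumes TE @ main and an FP8-capable GPU; recipe arguments vary by TE version.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from torch.utils.checkpoint import checkpoint

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
linear = te.Linear(768, 768, bias=True).cuda()
x = torch.randn(16, 768, device='cuda', requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    # activation_checkpointing_reentrant: true corresponds to use_reentrant=True here;
    # the non-reentrant variant is the one still hitting the saved-leaf assert above.
    y = checkpoint(linear, x, use_reentrant=True)
y.sum().backward()
```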

vchiley · Jun 02 '23 20:06

[Screenshot: CE per param]

[Screenshot: TFLOPS per param]

Note: the 3B model uses activation checkpointing, so its model TFLOPS is multiplied by 0.75.
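
A short sketch of where that 0.75 factor comes from, under the usual MFU accounting assumption that the backward pass costs roughly twice the forward pass and that activation checkpointing adds one extra forward recompute:

```python
# Rough accounting for the 0.75 activation-checkpointing factor (assumes bwd ~= 2x fwd).
fwd = 1.0
bwd = 2.0 * fwd

model_flops = fwd + bwd              # flops the model "needs" per step: 3x fwd
hardware_flops = fwd + fwd + bwd     # with act ckpt, forward is recomputed: 4x fwd

print(model_flops / hardware_flops)  # 0.75, the factor applied to the 3B model's TFLOPS
```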

Slightly more detail here.

vchiley · Jun 16 '23 21:06