llm-foundry
Adding te.Linear for FP8 support
Ran `composer train/train.py train/yamls/pretrain/mpt-3b.yaml`, and also with `model.fc_type=te` and `precision=amp_fp8`.
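From the description above, the three runs behind the numbers below were presumably:

```bash
# baseline (torch Linear)
composer train/train.py train/yamls/pretrain/mpt-3b.yaml

# te.Linear at default precision
composer train/train.py train/yamls/pretrain/mpt-3b.yaml model.fc_type=te

# te.Linear with FP8
composer train/train.py train/yamls/pretrain/mpt-3b.yaml model.fc_type=te precision=amp_fp8
```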
Results (`throughput/device/tokens_per_sec`):
- torch: 23.7k
- te: 23.7k
- te with fp8: 29.4k
Note: there is an error when activation checkpointing is enabled with `activation_checkpointing_reentrant: false`. If we set `activation_checkpointing_reentrant: true`, activation checkpointing works fine without `amp_fp8`; with `amp_fp8` the issue still persists.
(Previously, circa summer 2022, `activation_checkpointing_reentrant: true` caused some difficulties, which is why we set it to `false`; not sure if that is still necessary...)
The activation-checkpointing errors will be added to the comments.
(The error with `amp_fp8` might be an issue in Composer's FP8 implementation.)
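For reference, the knobs involved live in the usual foundry pretrain YAML locations; a sketch showing the failing combination (key placement per the standard mpt-3b.yaml layout):

```yaml
precision: amp_fp8            # error persists here even with reentrant: true
model:
  fc_type: te                 # transformer_engine Linear layers
fsdp_config:
  activation_checkpointing: true
  activation_checkpointing_reentrant: false   # true fixes the non-fp8 case
```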
Activation-checkpointing error without `amp_fp8` (`activation_checkpointing_reentrant: false`):
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/llm-foundry/scripts/train/train.py:254 in <module> │
│ │
│ 251 │ │ yaml_cfg = om.load(f) │
│ 252 │ cli_cfg = om.from_cli(args_list) │
│ 253 │ cfg = om.merge(yaml_cfg, cli_cfg) │
│ ❱ 254 │ main(cfg) │
│ 255 │
│ │
│ /mnt/llm-foundry/scripts/train/train.py:243 in main │
│ │
│ 240 │ │ trainer.eval() │
│ 241 │ │
│ 242 │ print('Starting training...') │
│ ❱ 243 │ trainer.fit() │
│ 244 │ │
│ 245 │ print('Done.') │
│ 246 │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:1766 in fit │
│ │
│ 1763 │ │ │ self.state.scaler = ClosureGradScaler() if self._use_closures() else GradSca │
│ 1764 │ │ │
│ 1765 │ │ self.first_batch_complete = False │
│ ❱ 1766 │ │ self._train_loop() │
│ 1767 │ │
│ 1768 │ def close(self): │
│ 1769 │ │ """Shutdown the trainer. │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:1940 in │
│ _train_loop │
│ │
│ 1937 │ │ │ │ │ │ self.logger.log_metrics({'time/token': self.state.timestamp.toke │
│ 1938 │ │ │ │ │ │ self.logger.log_metrics({'time/token_in_epoch': self.state.times │
│ 1939 │ │ │ │ │ │
│ ❱ 1940 │ │ │ │ │ total_loss_dict = self._train_batch(use_grad_scaling) │
│ 1941 │ │ │ │ │ │
│ 1942 │ │ │ │ │ if use_grad_scaling: │
│ 1943 │ │ │ │ │ │ self.state.scaler.update() │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2124 in │
│ _train_batch │
│ │
│ 2121 │ │ │ │ │ │ │ │ │ │ │ │ closure=lambda loss_dict=total_loss_d │
│ 2122 │ │ │ │ │ │ │ │ │ │ │ │ _train_microbatches(microbatches, los │
│ 2123 │ │ │ │ │ │ else: │
│ ❱ 2124 │ │ │ │ │ │ │ optimizer.step(closure=lambda **kwargs: self._train_microbat │
│ 2125 │ │ │ │ │ │ │ │ microbatches, total_loss_dict, **kwargs).item()) │
│ 2126 │ │ │ │ else: │
│ 2127 │ │ │ │ │ self._train_microbatches(microbatches, total_loss_dict) │
│ │
│ /usr/lib/python3/dist-packages/torch/optim/lr_scheduler.py:69 in wrapper │
│ │
│ 66 │ │ │ │ instance = instance_ref() │
│ 67 │ │ │ │ instance._step_count += 1 │
│ 68 │ │ │ │ wrapped = func.__get__(instance, cls) │
│ ❱ 69 │ │ │ │ return wrapped(*args, **kwargs) │
│ 70 │ │ │ │
│ 71 │ │ │ # Note that the returned function here is no longer a bound method, │
│ 72 │ │ │ # so attributes like `__func__` and `__self__` no longer exist. │
│ │
│ /usr/lib/python3/dist-packages/torch/optim/optimizer.py:280 in wrapper │
│ │
│ 277 │ │ │ │ │ │ │ raise RuntimeError(f"{func} must return None or a tuple of ( │
│ 278 │ │ │ │ │ │ │ │ │ │ │ f"but got {result}.") │
│ 279 │ │ │ │ │
│ ❱ 280 │ │ │ │ out = func(*args, **kwargs) │
│ 281 │ │ │ │ self._optimizer_step_code() │
│ 282 │ │ │ │ │
│ 283 │ │ │ │ # call optimizer step post hooks │
│ │
│ /usr/lib/python3/dist-packages/torch/utils/_contextlib.py:115 in decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/optim/decoupled_weight_decay.py:288 │
│ in step │
│ │
│ 285 │ │ loss = None │
│ 286 │ │ if closure is not None: │
│ 287 │ │ │ with torch.enable_grad(): │
│ ❱ 288 │ │ │ │ loss = closure() │
│ 289 │ │ │
│ 290 │ │ for group in self.param_groups: │
│ 291 │ │ │ params_with_grad = [] │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2124 in <lambda> │
│ │
│ 2121 │ │ │ │ │ │ │ │ │ │ │ │ closure=lambda loss_dict=total_loss_d │
│ 2122 │ │ │ │ │ │ │ │ │ │ │ │ _train_microbatches(microbatches, los │
│ 2123 │ │ │ │ │ │ else: │
│ ❱ 2124 │ │ │ │ │ │ │ optimizer.step(closure=lambda **kwargs: self._train_microbat │
│ 2125 │ │ │ │ │ │ │ │ microbatches, total_loss_dict, **kwargs).item()) │
│ 2126 │ │ │ │ else: │
│ 2127 │ │ │ │ │ self._train_microbatches(microbatches, total_loss_dict) │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2222 in │
│ _train_microbatches │
│ │
│ 2219 │ │ │ │
│ 2220 │ │ │ for microbatch_idx, self.state.batch in enumerate(microbatches): │
│ 2221 │ │ │ │ is_final_microbatch = microbatch_idx + 1 == len(microbatches) │
│ ❱ 2222 │ │ │ │ microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_ │
│ 2223 │ │ │ │ │
│ 2224 │ │ │ │ # Aggregate each loss in microbatch_loss_dict into total_loss_dict │
│ 2225 │ │ │ │ for k, microbatch_loss in microbatch_loss_dict.items(): │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2349 in │
│ _train_microbatch │
│ │
│ 2346 │ │ │ else: │
│ 2347 │ │ │ │ # Scale loss based on the number of samples in the microbatch to maintai │
│ 2348 │ │ │ │ microbatch_loss.mul_(microbatch_num_samples / current_batch_size) │
│ ❱ 2349 │ │ │ │ microbatch_loss.backward(create_graph=self._backwards_create_graph) │
│ 2350 │ │ │ │
│ 2351 │ │ │ self.engine.run_event(Event.AFTER_BACKWARD) │
│ 2352 │
│ │
│ /usr/lib/python3/dist-packages/torch/_tensor.py:487 in backward │
│ │
│ 484 │ │ │ │ create_graph=create_graph, │
│ 485 │ │ │ │ inputs=inputs, │
│ 486 │ │ │ ) │
│ ❱ 487 │ │ torch.autograd.backward( │
│ 488 │ │ │ self, gradient, retain_graph, create_graph, inputs=inputs │
│ 489 │ │ ) │
│ 490 │
│ │
│ /usr/lib/python3/dist-packages/torch/autograd/__init__.py:200 in backward │
│ │
│ 197 │ # The reason we repeat same the comment below is that │
│ 198 │ # some Python versions print out the first line of a multi-line function │
│ 199 │ # calls in the traceback and some print out the last line │
│ ❱ 200 │ Variable._execution_engine.run_backward( # Calls into the C++ engine to run the bac │
│ 201 │ │ tensors, grad_tensors_, retain_graph, create_graph, inputs, │
│ 202 │ │ allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to ru │
│ 203 │
│ │
│ /usr/lib/python3/dist-packages/torch/autograd/function.py:274 in apply │
│ │
│ 271 │ │ │ │ │ │ │ "Function is not allowed. You should only implement one " │
│ 272 │ │ │ │ │ │ │ "of them.") │
│ 273 │ │ user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn │
│ ❱ 274 │ │ return user_fn(self, *args) │
│ 275 │ │
│ 276 │ def apply_jvp(self, *args): │
│ 277 │ │ # _forward_cls is defined by derived class │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/module.py:1865 in │
│ backward │
│ │
│ 1862 │ │ │ │ weight, │
│ 1863 │ │ │ │ weight_t_fp8, │
│ 1864 │ │ │ │ fwd_scale_inverses, │
│ ❱ 1865 │ │ │ ) = ctx.saved_tensors │
│ 1866 │ │ │ │
│ 1867 │ │ │ if ctx.ub_split_ag: │
│ 1868 │ │ │ │ tp_world_size = get_distributed_world_size(ctx.tp_group) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: !grad_accumulator_.expired() INTERNAL ASSERT FAILED at "../torch/csrc/autograd/saved_variable.cpp":226, please report a bug to PyTorch. No
grad accumulator for a saved leaf
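For context, `activation_checkpointing_reentrant` maps onto the `use_reentrant` flag of `torch.utils.checkpoint.checkpoint`; a minimal sketch of the two modes (toy module and shapes, not foundry code):

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Linear(8, 8)  # stand-in for an MPT block
x = torch.randn(2, 8, requires_grad=True)

# activation_checkpointing_reentrant: true -> recompute runs inside a
# custom autograd.Function.
y_true = checkpoint(block, x, use_reentrant=True)

# activation_checkpointing_reentrant: false -> recompute is driven by
# saved-tensor hooks, the path interacting badly with TE's
# ctx.saved_tensors unpacking in the traceback above.
y_false = checkpoint(block, x, use_reentrant=False)

(y_true.sum() + y_false.sum()).backward()
```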
Activation-checkpointing error with `amp_fp8` (with `activation_checkpointing_reentrant: false` or `true`):
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/llm-foundry/scripts/train/train.py:260 in <module> │
│ │
│ 257 │ │ yaml_cfg = om.load(f) │
│ 258 │ cli_cfg = om.from_cli(args_list) │
│ 259 │ cfg = om.merge(yaml_cfg, cli_cfg) │
│ ❱ 260 │ main(cfg) │
│ 261 │
│ │
│ /mnt/llm-foundry/scripts/train/train.py:249 in main │
│ │
│ 246 │ │ trainer.eval() │
│ 247 │ │
│ 248 │ print('Starting training...') │
│ ❱ 249 │ trainer.fit() │
│ 250 │ │
│ 251 │ print('Done.') │
│ 252 │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:1766 in fit │
│ │
│ 1763 │ │ │ self.state.scaler = ClosureGradScaler() if self._use_closures() else GradSca │
│ 1764 │ │ │
│ 1765 │ │ self.first_batch_complete = False │
│ ❱ 1766 │ │ self._train_loop() │
│ 1767 │ │
│ 1768 │ def close(self): │
│ 1769 │ │ """Shutdown the trainer. │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:1940 in │
│ _train_loop │
│ │
│ 1937 │ │ │ │ │ │ self.logger.log_metrics({'time/token': self.state.timestamp.toke │
│ 1938 │ │ │ │ │ │ self.logger.log_metrics({'time/token_in_epoch': self.state.times │
│ 1939 │ │ │ │ │ │
│ ❱ 1940 │ │ │ │ │ total_loss_dict = self._train_batch(use_grad_scaling) │
│ 1941 │ │ │ │ │ │
│ 1942 │ │ │ │ │ if use_grad_scaling: │
│ 1943 │ │ │ │ │ │ self.state.scaler.update() │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2124 in │
│ _train_batch │
│ │
│ 2121 │ │ │ │ │ │ │ │ │ │ │ │ closure=lambda loss_dict=total_loss_d │
│ 2122 │ │ │ │ │ │ │ │ │ │ │ │ _train_microbatches(microbatches, los │
│ 2123 │ │ │ │ │ │ else: │
│ ❱ 2124 │ │ │ │ │ │ │ optimizer.step(closure=lambda **kwargs: self._train_microbat │
│ 2125 │ │ │ │ │ │ │ │ microbatches, total_loss_dict, **kwargs).item()) │
│ 2126 │ │ │ │ else: │
│ 2127 │ │ │ │ │ self._train_microbatches(microbatches, total_loss_dict) │
│ │
│ /usr/lib/python3/dist-packages/torch/optim/lr_scheduler.py:69 in wrapper │
│ │
│ 66 │ │ │ │ instance = instance_ref() │
│ 67 │ │ │ │ instance._step_count += 1 │
│ 68 │ │ │ │ wrapped = func.__get__(instance, cls) │
│ ❱ 69 │ │ │ │ return wrapped(*args, **kwargs) │
│ 70 │ │ │ │
│ 71 │ │ │ # Note that the returned function here is no longer a bound method, │
│ 72 │ │ │ # so attributes like `__func__` and `__self__` no longer exist. │
│ │
│ /usr/lib/python3/dist-packages/torch/optim/optimizer.py:280 in wrapper │
│ │
│ 277 │ │ │ │ │ │ │ raise RuntimeError(f"{func} must return None or a tuple of ( │
│ 278 │ │ │ │ │ │ │ │ │ │ │ f"but got {result}.") │
│ 279 │ │ │ │ │
│ ❱ 280 │ │ │ │ out = func(*args, **kwargs) │
│ 281 │ │ │ │ self._optimizer_step_code() │
│ 282 │ │ │ │ │
│ 283 │ │ │ │ # call optimizer step post hooks │
│ │
│ /usr/lib/python3/dist-packages/torch/utils/_contextlib.py:115 in decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/optim/decoupled_weight_decay.py:288 │
│ in step │
│ │
│ 285 │ │ loss = None │
│ 286 │ │ if closure is not None: │
│ 287 │ │ │ with torch.enable_grad(): │
│ ❱ 288 │ │ │ │ loss = closure() │
│ 289 │ │ │
│ 290 │ │ for group in self.param_groups: │
│ 291 │ │ │ params_with_grad = [] │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2124 in <lambda> │
│ │
│ 2121 │ │ │ │ │ │ │ │ │ │ │ │ closure=lambda loss_dict=total_loss_d │
│ 2122 │ │ │ │ │ │ │ │ │ │ │ │ _train_microbatches(microbatches, los │
│ 2123 │ │ │ │ │ │ else: │
│ ❱ 2124 │ │ │ │ │ │ │ optimizer.step(closure=lambda **kwargs: self._train_microbat │
│ 2125 │ │ │ │ │ │ │ │ microbatches, total_loss_dict, **kwargs).item()) │
│ 2126 │ │ │ │ else: │
│ 2127 │ │ │ │ │ self._train_microbatches(microbatches, total_loss_dict) │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2222 in │
│ _train_microbatches │
│ │
│ 2219 │ │ │ │
│ 2220 │ │ │ for microbatch_idx, self.state.batch in enumerate(microbatches): │
│ 2221 │ │ │ │ is_final_microbatch = microbatch_idx + 1 == len(microbatches) │
│ ❱ 2222 │ │ │ │ microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_ │
│ 2223 │ │ │ │ │
│ 2224 │ │ │ │ # Aggregate each loss in microbatch_loss_dict into total_loss_dict │
│ 2225 │ │ │ │ for k, microbatch_loss in microbatch_loss_dict.items(): │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2349 in │
│ _train_microbatch │
│ │
│ 2346 │ │ │ else: │
│ 2347 │ │ │ │ # Scale loss based on the number of samples in the microbatch to maintai │
│ 2348 │ │ │ │ microbatch_loss.mul_(microbatch_num_samples / current_batch_size) │
│ ❱ 2349 │ │ │ │ microbatch_loss.backward(create_graph=self._backwards_create_graph) │
│ 2350 │ │ │ │
│ 2351 │ │ │ self.engine.run_event(Event.AFTER_BACKWARD) │
│ 2352 │
│ │
│ /usr/lib/python3/dist-packages/torch/_tensor.py:487 in backward │
│ │
│ 484 │ │ │ │ create_graph=create_graph, │
│ 485 │ │ │ │ inputs=inputs, │
│ 486 │ │ │ ) │
│ ❱ 487 │ │ torch.autograd.backward( │
│ 488 │ │ │ self, gradient, retain_graph, create_graph, inputs=inputs │
│ 489 │ │ ) │
│ 490 │
│ │
│ /usr/lib/python3/dist-packages/torch/autograd/__init__.py:200 in backward │
│ │
│ 197 │ # The reason we repeat same the comment below is that │
│ 198 │ # some Python versions print out the first line of a multi-line function │
│ 199 │ # calls in the traceback and some print out the last line │
│ ❱ 200 │ Variable._execution_engine.run_backward( # Calls into the C++ engine to run the bac │
│ 201 │ │ tensors, grad_tensors_, retain_graph, create_graph, inputs, │
│ 202 │ │ allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to ru │
│ 203 │
│ │
│ /usr/lib/python3/dist-packages/torch/autograd/function.py:274 in apply │
│ │
│ 271 │ │ │ │ │ │ │ "Function is not allowed. You should only implement one " │
│ 272 │ │ │ │ │ │ │ "of them.") │
│ 273 │ │ user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn │
│ ❱ 274 │ │ return user_fn(self, *args) │
│ 275 │ │
│ 276 │ def apply_jvp(self, *args): │
│ 277 │ │ # _forward_cls is defined by derived class │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/module.py:1865 in │
│ backward │
│ │
│ 1862 │ │ │ │ weight, │
│ 1863 │ │ │ │ weight_t_fp8, │
│ 1864 │ │ │ │ fwd_scale_inverses, │
│ ❱ 1865 │ │ │ ) = ctx.saved_tensors │
│ 1866 │ │ │ │
│ 1867 │ │ │ if ctx.ub_split_ag: │
│ 1868 │ │ │ │ tp_world_size = get_distributed_world_size(ctx.tp_group) │
│ │
│ /usr/lib/python3/dist-packages/torch/utils/checkpoint.py:420 in unpack │
│ │
│ 417 │ │ │ │ │ torch.cuda.amp.autocast(**gpu_autocast_kwargs), \ │
│ 418 │ │ │ │ │ torch.cpu.amp.autocast(**cpu_autocast_kwargs), \ │
│ 419 │ │ │ │ │ torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): │
│ ❱ 420 │ │ │ │ │ _unused = function(*args, **kwargs) │
│ 421 │ │ │
│ 422 │ │ if x not in storage: │
│ 423 │ │ │ raise RuntimeError( │
│ │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /mnt/llm-foundry/llmfoundry/models/layers/blocks.py:110 in forward │
│ │
│ 107 │ │ is_causal: bool = True, │
│ 108 │ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: │
│ 109 │ │ a = self.norm_1(x) │
│ ❱ 110 │ │ b, attn_weights, past_key_value = self.attn( │
│ 111 │ │ │ a, │
│ 112 │ │ │ past_key_value=past_key_value, │
│ 113 │ │ │ attn_bias=attn_bias, │
│ │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /mnt/llm-foundry/llmfoundry/models/layers/attention.py:423 in forward │
│ │
│ 420 │ │ is_causal=True, │
│ 421 │ │ needs_weights=False, │
│ 422 │ ): │
│ ❱ 423 │ │ qkv = self.Wqkv(x) │
│ 424 │ │ │
│ 425 │ │ if self.clip_qkv: │
│ 426 │ │ │ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) │
│ │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/module.py:2267 in │
│ forward │
│ │
│ 2264 │ │ │ │ │ │ │ produced) │
│ 2265 │ │ """ │
│ 2266 │ │ │
│ ❱ 2267 │ │ with self.prepare_forward(inp, is_first_microbatch) as inp: │
│ 2268 │ │ │ bias_tensor = ( │
│ 2269 │ │ │ │ bias if bias is not None │
│ 2270 │ │ │ │ else self.bias if self.parameters_split is None │
│ │
│ /usr/lib/python3.10/contextlib.py:135 in __enter__ │
│ │
│ 132 │ │ # they are only needed for recreation, which is not possible anymore │
│ 133 │ │ del self.args, self.kwds, self.func │
│ 134 │ │ try: │
│ ❱ 135 │ │ │ return next(self.gen) │
│ 136 │ │ except StopIteration: │
│ 137 │ │ │ raise RuntimeError("generator didn't yield") from None │
│ 138 │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/module.py:632 in │
│ prepare_forward │
│ │
│ 629 │ │ │ │ │ self.fp8_meta["autocast_id_fwd_stack"].append( │
│ 630 │ │ │ │ │ │ self.fp8_meta["autocast_id_fwd"] │
│ 631 │ │ │ │ │ ) │
│ ❱ 632 │ │ │ │ │ add_amax_to_global_buffer(self.fp8_meta, forward=True) │
│ 633 │ │ │ │ self.fp8_meta["update_amax_and_scale_fwd"] = True │
│ 634 │ │ │ else: │
│ 635 │ │ │ │ self.fp8_meta["update_amax_and_scale_fwd"] = False │
│ │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/fp8.py:134 in │
│ add_amax_to_global_buffer │
│ │
│ 131 │ │ fp8_meta[buffer_position_key] = len(_global_fp8_buffer[buffer_key]) - 1 │
│ 132 │ │
│ 133 │ # Catch incorrect fp8_autocast usage. │
│ ❱ 134 │ assert fp8_meta[buffer_position_key] == len(_global_fp8_buffer[buffer_key]) - 1, \ │
│ 135 │ │ "Same module is being invoked more than once inside an `fp8_autocast` region whe │
│ 136 │ │ "FP8 with amax reduction. This behavior is currently unsupported. For more detai │
│ 137 │ │ "correct usage, please see https://github.com/NVIDIA/TransformerEngine/pull/93." │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AssertionError: Same module is being invoked more than once inside an `fp8_autocast` region when using FP8 with amax reduction. This behavior is currently
unsupported. For more details and correct usage, please see https://github.com/NVIDIA/TransformerEngine/pull/93.
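The assertion polices re-invoking the same TE module inside one `fp8_autocast` region, which is exactly what checkpoint recompute does during backward. A minimal sketch of the pattern (illustrative module and shapes; in a single-process run without amax reduction the assert may not actually trip):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

linear = te.Linear(64, 64)              # TE module carrying FP8 metadata
x = torch.randn(16, 64, device='cuda')

with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    y = linear(x)
    # Second invocation of the same module inside the same region: this
    # is the usage the pre-#187 amax-reduction bookkeeping asserts on.
    y = linear(y)
```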
Re: `activation_checkpointing_reentrant: false` — I believe this is the new style of activation checkpointing that torch recommends, and it may even become the default going forward. It also enabled us to do layer freezing, if I remember correctly. But it's not necessary, and if there's some torch 2 / FP8 bug that requires `activation_checkpointing_reentrant: true`, I think that's fine.
Edit: actually re-read your comments, and I guess we are still blocked on activation checkpointing + `amp_fp8`. Will read the logs carefully now...
There seems to be a bug referenced in this PR (https://github.com/NVIDIA/TransformerEngine/pull/93) that was fixed in this PR (https://github.com/NVIDIA/TransformerEngine/pull/187); the fix is available on main but not in any release.
Could you install TE @ main and try? 🙏
Note: transformer_engine ships its own activation-checkpoint util, which might need to be integrated into Composer for FP8 to work with activation checkpointing (unconfirmed).
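If that integration were needed, the change would roughly amount to swapping the checkpoint wrapper. Very rough sketch only: `te.pytorch` does export a `checkpoint` function, but its signature has varied across TE versions (older ones take distributed/RNG-tracker arguments), so the call shape below is an assumption, not a drop-in recipe:

```python
import torch.utils.checkpoint
import transformer_engine.pytorch as te

def run_block(block, *inputs):
    # hypothetical per-block forward used by the wrapper
    return block(*inputs)

# current: torch-native checkpointing
#   out = torch.utils.checkpoint.checkpoint(run_block, block, x)
# hypothetical TE-aware variant (signature varies by TE version):
#   out = te.checkpoint(run_block, block, x)
```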
TE @ main requires flash-attn==1.0.6, and flash-attn==1.0.6 has this issue.
Solution: add `--no-build-isolation` to `pip install`.
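i.e., presumably:

```bash
# flash-attn 1.0.6 fails to build under PEP 517 build isolation
pip install flash-attn==1.0.6 --no-build-isolation

# TE @ main (the NVIDIA/TransformerEngine#187 fix is unreleased)
pip install git+https://github.com/NVIDIA/TransformerEngine.git@main
```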
Installing TE from main (as suggested by @abhi-mosaic) makes our integrated activation checkpointing work; no need to integrate TE's own checkpoint util.
`activation_checkpointing_reentrant: false` is still broken, though.
[Plot: CE per param]
[Plot: TFLOPS per param]
Note: the 3B model uses activation checkpointing, so its model TFLOPS is multiplied by 0.75.
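(The 0.75 is the usual recompute accounting; a quick sanity check, assuming a full forward recompute in backward:)

```python
# fwd + bwd + recomputed fwd = 1F + 2F + 1F = 4F hardware FLOPs per step,
# versus 3F without checkpointing; model-FLOPs-based TFLOPS counts only 3F,
# so model FLOPs are 3/4 of hardware FLOPs.
fwd = 1.0
model_flops = 3 * fwd      # fwd + bwd (~2x fwd)
hardware_flops = 4 * fwd   # adds the recomputed forward
print(model_flops / hardware_flops)  # -> 0.75
```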
slightly more here