TransformerEngine

'Parameter' object has no attribute 'main_grad'

Open · rahul003 opened this issue 2 years ago · 3 comments

Any idea what could be going wrong with fuse_wgrad_accumulation?

```
/opt/conda/lib/python3.10/site-packages/transformer_engine/pytorch/module/layernorm_mlp.py:640 in backward

   637                         grad=True,
   638                         use_bias=ctx.use_fc2_bias,
   639                         accumulate=accumulate_wgrad_into_param_main_grad,
 ❱ 640                         out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else Non
   641                     )
   642
   643                 if ctx.bias_gelu_nvfusion:

AttributeError: 'Parameter' object has no attribute 'main_grad'
```

rahul003 · Jun 19 '23 20:06

We're trying to use a TransformerEngine layer in our own model.

rahul003 · Jun 19 '23 20:06

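With fuse_wgrad_accumulation=True, every weight must already carry a main_grad attribute: a buffer pre-allocated outside Transformer Engine, into which TE accumulates the weight gradient directly (the gradient accumulation fusion optimization), used instead of the regular .grad. Transformer Engine does not create this buffer for you. The documentation of the fuse_wgrad_accumulation parameter of transformer_engine.pytorch.Linear explains this.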
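A minimal sketch of that contract in plain PyTorch, so it runs without TE or a GPU. `attach_main_grad` and `ToyFusedWgradLinear` are hypothetical names made up for illustration, not TE APIs. The toy autograd function only imitates the fused behavior (add the weight gradient into `weight.main_grad` and return no `.grad` for the weight). It is not TE's kernel, and it fails with the same AttributeError when the buffer was never attached. With TE itself, the assumption is that you would run a helper like `attach_main_grad` on the model built with `fuse_wgrad_accumulation=True` before the first backward pass.

```python
import torch


def attach_main_grad(module: torch.nn.Module, dtype: torch.dtype = torch.float32) -> None:
    """Hypothetical helper: give every trainable parameter a zeroed `main_grad` buffer.

    The buffer lives outside the layer; fused-wgrad code only writes into it.
    Existing buffers are left alone so repeated calls do not wipe accumulated grads.
    """
    for param in module.parameters():
        if param.requires_grad and not hasattr(param, "main_grad"):
            param.main_grad = torch.zeros_like(param, dtype=dtype)


class ToyFusedWgradLinear(torch.autograd.Function):
    """Toy stand-in for a linear layer with fused weight-gradient accumulation (NOT TE's code)."""

    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp)
        ctx.weight = weight  # keep the Parameter object itself so its main_grad is reachable
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        (inp,) = ctx.saved_tensors
        weight = ctx.weight
        grad_in = grad_out @ weight
        # Accumulate dL/dW straight into the external buffer. This line raises
        # AttributeError if nobody attached main_grad, just like in the issue.
        weight.main_grad.add_(grad_out.t() @ inp)
        return grad_in, None  # weight.grad stays None; main_grad holds the gradient


torch.manual_seed(0)
layer = torch.nn.Linear(4, 3, bias=False)
attach_main_grad(layer)
x = torch.randn(5, 4, requires_grad=True)

# Two micro-batches accumulate into the same pre-allocated buffer.
for _ in range(2):
    ToyFusedWgradLinear.apply(x, layer.weight).sum().backward()

# Without the buffer, backward fails the same way as in the traceback.
fresh = torch.nn.Linear(4, 3, bias=False)
try:
    ToyFusedWgradLinear.apply(x, fresh.weight).sum().backward()
    error_message = None
except AttributeError as err:
    error_message = str(err)

print(layer.weight.grad, error_message)
```

The sketch never clears `main_grad`. In a real training loop, zeroing it after each optimizer step, and making the optimizer read `main_grad` instead of `.grad`, is left to the surrounding framework.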

ksivaman · Jun 20 '23 17:06

More concrete docs for this are available here.

crclark · Dec 22 '23 15:12