pytorch-summary

Fix NoneType error when using MultiHeadAttention

Open GCS-ZHN opened this issue 1 year ago • 0 comments

This PR builds on the previous PR #165 by @cainmagi.

Main changes:

  • Automatically detect and filter out non-array-like elements in a forward output list/tuple/dict. For example, the MultiHeadAttention module returns a tuple that contains a NoneType value as a placeholder for the attention weights.
  • If the filtered output contains no elements, raise a ValueError to notify the user, instead of the original NoneType AttributeError.
  • Replace -1 with batch_size in dict/list/tuple output shapes, because I believe this is more appropriate.
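
The three changes above can be illustrated with a minimal sketch. It is not the code from this PR: the helper names `is_array_like` and `collect_output_shapes` are hypothetical, and "array-like" is assumed here to mean "has a callable `size()` method", which is how a `torch.Tensor` would be recognized without importing torch.

```python
def is_array_like(obj):
    """Assumption: an element counts as array-like if it exposes a callable size()."""
    return callable(getattr(obj, "size", None))


def collect_output_shapes(output, batch_size=-1):
    """Return the shapes of the array-like elements of a forward output.

    Hypothetical helper: non-array-like elements (such as None) are dropped,
    and the first dimension of each shape is replaced with batch_size.
    """
    # Normalize dict / list / tuple / single outputs to a flat list of candidates.
    if isinstance(output, dict):
        candidates = list(output.values())
    elif isinstance(output, (list, tuple)):
        candidates = list(output)
    else:
        candidates = [output]

    # Filter out placeholders such as the None attention weights.
    kept = [item for item in candidates if is_array_like(item)]
    if not kept:
        # Fail with a clear message instead of an AttributeError on NoneType.
        raise ValueError("forward output contains no array-like element")

    # Swap the leading dimension for the requested batch size.
    return [[batch_size] + list(item.size())[1:] for item in kept]
```

With PyTorch installed, calling `torch.nn.MultiheadAttention` with `need_weights=False` returns a tuple whose second element is `None`; passing that tuple to the sketch would keep only the attention output's shape.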

GCS-ZHN, Oct 26 '22 06:10