
[SE(3)-Transformer] Output embeddings collapse

Open anton-bushuiev opened this issue 11 months ago • 0 comments

Related to Model/Framework(s)

  • SE(3)-Transformer

Describe the bug I train SE(3)-Transformer to predict type-0 node embeddings. Across multiple training hyper-parameter setups I hit the same failure: after some N epochs the model collapses and predicts only all-NaN (or sometimes all-zero) features. @milesial, do you have any pointers on what the issue may be?

# After some N training epochs in training loop:
h = self.se3transformer(data_dgl, node_feats, edge_feats, all_bases)
h = h['0'].squeeze(-1)
if torch.isnan(h).any():
    for i in range(h.shape[0]):
        print(f'h[{i}]', h[i, :])
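For context, the way I try to localize where the NaNs first appear is PyTorch's anomaly-detection mode, which raises as soon as a backward op produces NaNs and names the offending function. A minimal, self-contained sketch (a toy `sqrt` example, not the SE(3)-Transformer itself):

```python
import torch

# Hypothetical minimal example: sqrt of a negative input produces NaN,
# and anomaly mode flags the backward op that first returns NaN values.
x = torch.tensor([-1.0], requires_grad=True)

nan_source = None
try:
    with torch.autograd.detect_anomaly():
        y = torch.sqrt(x)  # forward already yields NaN here
        y.backward()       # anomaly mode raises during backward
except RuntimeError as err:
    nan_source = str(err)  # error message names the offending function

print(nan_source)
```

Anomaly mode slows training down considerably, so it is only practical for a short debugging run.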

Complete hyper-parameter setup for the model in Hydra:

se3transformer:
    _target_: se3_transformer.model.SE3Transformer
    fiber_in:
      _target_: se3_transformer.model.fiber.Fiber
      _convert_: all
      structure:
        0: 20
        1: 1
    fiber_hidden:
      _target_: se3_transformer.model.fiber.Fiber.create
      num_degrees: 2
      num_channels: 32
    fiber_out:
      _target_: se3_transformer.model.fiber.Fiber
      _convert_: all
      structure:
        0: 512
    num_layers: 7
    num_heads: 8
    channels_div: 2
    norm: true
    use_layer_norm: true

I use default Adam with lr=1e-3 within PyTorch Lightning.
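One generic guard against this kind of blow-up (independent of the SE(3)-Transformer, and not my actual training code) is to clip the global gradient norm and skip optimizer steps whose gradients are non-finite. A sketch with a placeholder linear model:

```python
import torch

# Placeholder stand-in model; the clip-and-skip guard is the point.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()

# Clip the global gradient norm, then refuse to step on non-finite grads.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
grads_finite = all(torch.isfinite(p.grad).all() for p in model.parameters())
if grads_finite:
    optimizer.step()
optimizer.zero_grad()
print("stepped:", grads_finite)
```

In PyTorch Lightning the clipping half of this can also be enabled via the Trainer's gradient-clipping option instead of hand-written loop code.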

Expected behavior Output type-0 features that are not entirely NaN or zero.

Environment

  • Custom environment with Python 3.9.0 and torch==1.13.0+cu116
  • GPUs in the system: 8x NVIDIA A100-SXM4-40GB
  • CUDA driver version: 535.54.03 (CUDA Version: 12.2)

anton-bushuiev avatar Aug 20 '23 15:08 anton-bushuiev