[SE(3)-Transformer] Output embeddings collapse
**Related to Model/Framework(s)**
- SE(3)-Transformer
**Describe the bug**
I train an SE(3)-Transformer to predict type-0 node embeddings. Across multiple training hyper-parameter setups I hit the same failure: after some number of epochs N, the model collapses and outputs only full-NaN (or sometimes full-zero) features. @milesial, do you have any pointers on what the issue might be?
```python
# After some N training epochs in the training loop:
h = self.se3transformer(data_dgl, node_feats, edge_feats, all_bases)
h = h['0'].squeeze(-1)
if torch.isnan(h).any():
    for i in range(h.shape[0]):
        print(f'h[{i}]', h[i, :])
```
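For completeness, here is a minimal sketch of the check I use to flag the first step where the collapse appears; the `detect_collapse` helper and its labels are my own debugging code, not part of the repo:

```python
from typing import Optional

import torch


def detect_collapse(h: torch.Tensor) -> Optional[str]:
    """Classify the squeezed type-0 output: 'all-nan', 'all-zero', or None if healthy."""
    if torch.isnan(h).all():
        return 'all-nan'
    if (h == 0).all():
        return 'all-zero'
    return None


# Enabling anomaly detection pinpoints the op that first produces NaN gradients:
# torch.autograd.set_detect_anomaly(True)
```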
Complete hyper-parameter setup for the model in Hydra:
```yaml
se3transformer:
  _target_: se3_transformer.model.SE3Transformer
  fiber_in:
    _target_: se3_transformer.model.fiber.Fiber
    _convert_: all
    structure:
      0: 20
      1: 1
  fiber_hidden:
    _target_: se3_transformer.model.fiber.Fiber.create
    num_degrees: 2
    num_channels: 32
  fiber_out:
    _target_: se3_transformer.model.fiber.Fiber
    _convert_: all
    structure:
      0: 512
  num_layers: 7
  num_heads: 8
  channels_div: 2
  norm: true
  use_layer_norm: true
```
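Instantiating this config should be equivalent to the following direct construction (a sketch assuming the constructor keywords match the config keys above; `model.yaml` is a hypothetical path):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber

# Load the YAML block above from a hypothetical file and build the model.
cfg = OmegaConf.load('model.yaml')
model = instantiate(cfg.se3transformer)

# ...which should be equivalent to constructing the model directly:
model = SE3Transformer(
    fiber_in=Fiber(structure={0: 20, 1: 1}),
    fiber_hidden=Fiber.create(num_degrees=2, num_channels=32),
    fiber_out=Fiber(structure={0: 512}),
    num_layers=7,
    num_heads=8,
    channels_div=2,
    norm=True,
    use_layer_norm=True,
)
```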
I use the default Adam optimizer with lr=1e-3 within PyTorch Lightning.
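Concretely, the optimizer hookup in the LightningModule is just the following (minimal sketch; the wrapper module name is hypothetical and the rest of the module is elided):

```python
import pytorch_lightning as pl
import torch


class LitSE3Model(pl.LightningModule):  # hypothetical wrapper module
    def configure_optimizers(self):
        # Default Adam; only the learning rate is set explicitly.
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```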
**Expected behavior**
The type-0 output features should never be entirely NaN or entirely zero.
**Environment**
- Custom environment with Python 3.9.0 and torch==1.13.0+cu116
- GPUs in the system: 8x NVIDIA A100-SXM4-40GB
- CUDA driver version: 535.54.03 (CUDA Version: 12.2)